1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for TPUReplicatedVariable.""" 16from __future__ import division 17from __future__ import print_function 18 19from absl.testing import parameterized 20import numpy as np 21 22from tensorflow.python.distribute import tpu_replicated_variable 23from tensorflow.python.eager import test 24from tensorflow.python.framework import combinations 25from tensorflow.python.framework import dtypes 26from tensorflow.python.ops import variables as variables_lib 27 28 29class TPUReplicatedVariableTest(test.TestCase, parameterized.TestCase): 30 31 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 32 def test_tpu_replicated_variable_simple(self): 33 v0 = variables_lib.Variable([0], name='v0') 34 v1 = variables_lib.Variable([0], name='v1') 35 r = tpu_replicated_variable.TPUReplicatedVariable([v0, v1]) 36 self.evaluate(variables_lib.global_variables_initializer()) 37 self.assertEqual(r.variables[0], v0) 38 self.assertEqual(r.variables[1], v1) 39 self.assertEqual(r.shape.as_list(), [1]) 40 self.assertEqual(r.dtype, v0.dtype) 41 self.check_replicated_variables_all_the_same(r) 42 43 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 44 def test_tpu_replicated_variable_update(self): 45 batch_size = 32 46 num_feature_in = 16 47 48 x = np.random.rand(batch_size, num_feature_in).astype(np.float32) 49 w_init = np.random.rand(batch_size, num_feature_in).astype(np.float32) 50 51 w0 = variables_lib.Variable(w_init, dtype=dtypes.float32, name='w0') 52 w1 = variables_lib.Variable(w_init, dtype=dtypes.float32, name='w1') 53 self.evaluate(variables_lib.global_variables_initializer()) 54 w = tpu_replicated_variable.TPUReplicatedVariable([w0, w1]) 55 56 # Make a copy of x so that `w` and `x` do not share the same buffer. 57 # See b/195972684. 58 self.evaluate(w.assign(x.copy())) 59 result = self.evaluate(w.read_value()) 60 self.assertAllClose(result, x) 61 self.check_replicated_variables_all_the_same(w) 62 63 x1 = np.random.rand(batch_size, num_feature_in).astype(np.float32) 64 self.evaluate(w.assign_sub(x1)) 65 result = self.evaluate(w.read_value()) 66 self.assertAllClose(result, np.subtract(x, x1)) 67 self.check_replicated_variables_all_the_same(w) 68 69 x2 = np.random.rand(batch_size, num_feature_in).astype(np.float32) 70 self.evaluate(w.assign(x.copy())) 71 self.evaluate(w.assign_add(x2)) 72 result = self.evaluate(w.read_value()) 73 self.assertAllClose(result, np.add(x, x2)) 74 self.check_replicated_variables_all_the_same(w) 75 76 def check_replicated_variables_all_the_same(self, rv): 77 for v in rv.variables: 78 self.assertAllEqual( 79 self.evaluate(rv.variables[0].read_value()), 80 self.evaluate(v)) 81 82 83if __name__ == '__main__': 84 test.main() 85