xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/tpu_replicated_variable_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for TPUReplicatedVariable."""
16from __future__ import division
17from __future__ import print_function
18
19from absl.testing import parameterized
20import numpy as np
21
22from tensorflow.python.distribute import tpu_replicated_variable
23from tensorflow.python.eager import test
24from tensorflow.python.framework import combinations
25from tensorflow.python.framework import dtypes
26from tensorflow.python.ops import variables as variables_lib
27
28
class TPUReplicatedVariableTest(test.TestCase, parameterized.TestCase):
  """Tests for `tpu_replicated_variable.TPUReplicatedVariable`.

  Each test is generated for both graph and eager mode via
  `combinations.generate`.
  """

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_tpu_replicated_variable_simple(self):
    """Checks component access, shape, dtype and replica consistency."""
    v0 = variables_lib.Variable([0], name='v0')
    v1 = variables_lib.Variable([0], name='v1')
    r = tpu_replicated_variable.TPUReplicatedVariable([v0, v1])
    self.evaluate(variables_lib.global_variables_initializer())
    # Use identity checks rather than `assertEqual`: in eager mode
    # `Variable.__eq__` compares values element-wise, and v0 and v1 hold the
    # same value, so `assertEqual` could not tell them apart (a swapped or
    # duplicated component would still pass).
    self.assertLen(r.variables, 2)
    self.assertIs(r.variables[0], v0)
    self.assertIs(r.variables[1], v1)
    self.assertEqual(r.shape.as_list(), [1])
    self.assertEqual(r.dtype, v0.dtype)
    self.check_replicated_variables_all_the_same(r)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_tpu_replicated_variable_update(self):
    """Checks that assign / assign_sub / assign_add update every replica."""
    batch_size = 32
    num_feature_in = 16

    x = np.random.rand(batch_size, num_feature_in).astype(np.float32)
    w_init = np.random.rand(batch_size, num_feature_in).astype(np.float32)

    w0 = variables_lib.Variable(w_init, dtype=dtypes.float32, name='w0')
    w1 = variables_lib.Variable(w_init, dtype=dtypes.float32, name='w1')
    self.evaluate(variables_lib.global_variables_initializer())
    w = tpu_replicated_variable.TPUReplicatedVariable([w0, w1])

    # Make a copy of x so that `w` and `x` do not share the same buffer.
    # See b/195972684.
    self.evaluate(w.assign(x.copy()))
    result = self.evaluate(w.read_value())
    self.assertAllClose(result, x)
    self.check_replicated_variables_all_the_same(w)

    x1 = np.random.rand(batch_size, num_feature_in).astype(np.float32)
    self.evaluate(w.assign_sub(x1))
    result = self.evaluate(w.read_value())
    self.assertAllClose(result, np.subtract(x, x1))
    self.check_replicated_variables_all_the_same(w)

    x2 = np.random.rand(batch_size, num_feature_in).astype(np.float32)
    # Reset to x (again via a copy) before testing assign_add.
    self.evaluate(w.assign(x.copy()))
    self.evaluate(w.assign_add(x2))
    result = self.evaluate(w.read_value())
    self.assertAllClose(result, np.add(x, x2))
    self.check_replicated_variables_all_the_same(w)

  def check_replicated_variables_all_the_same(self, rv):
    """Asserts every component variable of `rv` equals its first component.

    Args:
      rv: A `TPUReplicatedVariable` whose `variables` are compared.
    """
    # Evaluate the reference value once. Re-reading it inside the loop would
    # build a new read op per iteration in graph mode, and comparing the
    # first component with itself checks nothing.
    expected = self.evaluate(rv.variables[0].read_value())
    for v in rv.variables[1:]:
      self.assertAllEqual(expected, self.evaluate(v))
81
82
if __name__ == '__main__':
  # Only run when executed directly as a script: hand control to the
  # TensorFlow test runner, which runs the test cases defined in this module.
  test.main()
85