# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the mirrored values library."""

import os

from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import saver as saver_lib


def _make_mirrored(distribution=None):
  """Creates a mirrored variable with one component variable per device.

  If `distribution` is given, components are placed on its worker devices
  (and a `TPUMirroredVariable` is built for TPU strategies); otherwise the
  components go on one GPU and the CPU.
  """
  v = []
  if distribution:
    devices = distribution.extended.worker_devices
  else:
    devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(
          variable_scope.get_variable(
              name=n, initializer=init, use_resource=True))

  if distribution is not None and strategy_test_lib.is_tpu_strategy(
      distribution):
    var_cls = tpu_values.TPUMirroredVariable
  else:
    var_cls = values_lib.MirroredVariable
  mirrored = var_cls(distribution, v, variable_scope.VariableAggregation.SUM)
  return mirrored
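
# A minimal usage sketch (comment only; assumes the default two-device case
# above, with at least one GPU available):
#
#   mirrored = _make_mirrored()
#   assert len(mirrored.values) == 2   # one component variable per device
#   assert mirrored.values[0].name == mirrored.name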


def _make_mirrored_val(init_val=5.0):
  """Creates a `Mirrored` value holding one constant per device."""
  v = []
  devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, _ in zip(devices, ["v", "v/replica"]):
    with ops.device(d):
      v.append(constant_op.constant(init_val))
  return values_lib.Mirrored(v)
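
# Note: `values_lib.Mirrored` wraps plain per-device tensors, while
# `values_lib.MirroredVariable` above wraps per-device resource variables;
# `MirroredTest.testAddOp` below does arithmetic on the former.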


def mirrored_and_tpu_strategy_combinations():
  """Returns (distribution, mode) combinations for mirrored/TPU strategies."""
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"])
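
# `combinations.combine` takes the cross product of its arguments, so each
# test decorated with `@combinations.generate(...)` below runs once per
# (distribution, mode) pair.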


class MirroredVariableTest(test.TestCase, parameterized.TestCase):

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    mirrored = _make_mirrored()
    v = mirrored.values[0]
    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testVariableOnAnotherDevice(self):
    v = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    mirrored = values_lib.MirroredVariable(
        None, (v,), variable_scope.VariableAggregation.MEAN)

    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)


class MirroredVariableSaveRestoreTest(test.TestCase, parameterized.TestCase):

  def _assign_mirrored(self, v, new):
    for var, n in zip(v.values, new):
      self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  def _save_mirrored(self, distribution):
    """Saves variables with mirroring; returns the save path."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path = self._save(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])
    return save_path

  def _save_normal(self):
    """Saves variables without mirroring; returns the save path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.))

      # Saves the current value of var, 3.
      save_path = self._save(sess, var)

      # Change the value between save and restore.
      self.evaluate(var.assign(5.))
    return save_path

  def _restore_normal(self, save_path):
    """Restores into a variable without mirroring, in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3. to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3., self.evaluate(var))

  def _restore_mirrored(self, save_path, distribution):
    """Restores into mirrored variables, in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [7., 8.])

      # Restores the saved value of 3. to both variables.
      saver = saver_lib.Saver(var_list=[mirrored])
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))
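
  # Note on checkpoint semantics, as exercised above: saving a
  # MirroredVariable writes a single value (the primary component, v[0]),
  # and restoring broadcasts that value to every component. This is what
  # makes the mirrored <-> normal round trips below line up.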

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreMirroredOneGraph(self, distribution):
    with self.cached_session() as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path, saver = self._save_return_saver(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])

      # Restores the saved value of 3. to both variables.
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))
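
  # The test above saves and restores within a single session; the tests
  # below round-trip through a fresh graph on each side instead, via the
  # `_save_*` and `_restore_*` helpers.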

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without a GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, the Placer will ignore
      # the user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_mirrored(distribution)
    self._restore_mirrored(save_path, distribution)
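
  # A minimal sketch of the soft placement mentioned above (an assumed usage
  # pattern, mirroring the `config` in MirroredVariableTest): with
  # `allow_soft_placement=True`, a variable requested on a missing GPU is
  # silently placed on the CPU instead of raising a placement error.
  #
  #   config = config_pb2.ConfigProto(allow_soft_placement=True)
  #   with ops.Graph().as_default(), self.session(config=config):
  #     ...  # get_variable(...) on "/device:GPU:0" falls back to CPU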

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreNormal(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without a GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, the Placer will ignore
      # the user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_mirrored(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without a GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, the Placer will ignore
      # the user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_normal()
    self._restore_mirrored(save_path, distribution)


class MirroredTest(test.TestCase):

  def testAddOp(self):
    if context.num_gpus() < 1:
      self.skipTest("A GPU is not available for this test.")
    mirrored_val = _make_mirrored_val(init_val=3.)

    self.assertEqual(self.evaluate(constant_op.constant(6.)),
                     self.evaluate(mirrored_val + mirrored_val))
    self.assertEqual(self.evaluate(constant_op.constant(4.)),
                     self.evaluate(mirrored_val + 1))
    self.assertEqual(self.evaluate(mirrored_val + 1),
                     self.evaluate(math_ops.add(mirrored_val, 1)))
    self.assertEqual(type(mirrored_val + 1),
                     type(math_ops.add(mirrored_val, 1)))
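
    # The final assertion checks that the overloaded `+` operator and
    # `math_ops.add` return results of the same type for a `Mirrored`
    # operand, i.e. the operator overload does not change the container
    # type relative to the explicit op.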


if __name__ == "__main__":
  ds_test_util.main()