# xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/values_v2_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for the distributed values library."""
16
17from absl.testing import parameterized
18
19from tensorflow.python.distribute import combinations
20from tensorflow.python.distribute import strategy_combinations
21from tensorflow.python.distribute import test_util
22from tensorflow.python.distribute import values_v2
23from tensorflow.python.eager import def_function
24from tensorflow.python.eager import test
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import indexed_slices
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import resource_variable_ops
30from tensorflow.python.ops import variables as variables_lib
31
32
class _VariableInterfaceTestBase(test.TestCase, parameterized.TestCase):
  # Shared checks that DistributedVariable/AutoSyncVariable conform to the
  # Variable and ResourceVariable interface, i.e. that the methods and
  # properties are all defined. Methods and properties that share the same code
  # path under different replicas/devices are verified here as well. Methods
  # and properties that behave differently under different replicas/devices are
  # out of scope; those should be covered by separate tests.

  def create_variable(self, initial_value=1., **kwargs):
    # Subclasses override this to build the variable type under test.
    raise NotImplementedError

  @property
  def devices(self):
    # Device names available to subclasses when placing component variables.
    device_names = ["CPU:0", "CPU:1"]
    return device_names

  # ==== Begin Variable interface ===
  # Please follow the same order as methods and properties defined in
  # tf.Variable.

  def testStringify(self):
    var = self.create_variable()
    # Both string conversions must be defined and return a `str`.
    self.assertIsInstance(var.__str__(), str)
    self.assertIsInstance(var.__repr__(), str)

  def testDenseRead(self):
    var = self.create_variable(1.)
    current = var.value()
    self.assertEqual(current, 1.)
    snapshot = var.read_value()
    self.assertEqual(snapshot, 1.)

  def testShape(self):
    var = self.create_variable([1.])
    expected_shape = (1,)
    self.assertEqual(var.shape, expected_shape)
    self.assertEqual(var.get_shape(), expected_shape)
    # A compatible shape is accepted; an incompatible one must be rejected.
    var.set_shape(expected_shape)
    with self.assertRaisesRegex(ValueError, "not compatible"):
      var.set_shape((1, 1))

  @combinations.generate(combinations.combine(trainable=[True, False]))
  def testTrainable(self, trainable):
    var = self.create_variable(trainable=trainable)
    self.assertEqual(var.trainable, trainable)

  @combinations.generate(
      combinations.combine(synchronization=[
          variables_lib.VariableSynchronization.ON_READ,
          variables_lib.VariableSynchronization.ON_WRITE,
          variables_lib.VariableSynchronization.AUTO,
          variables_lib.VariableSynchronization.NONE,
      ]))
  def testSynchronization(self, synchronization):
    var = self.create_variable(synchronization=synchronization)
    self.assertEqual(var.synchronization, synchronization)

  @combinations.generate(
      combinations.combine(aggregation=[
          variables_lib.VariableAggregation.MEAN,
          variables_lib.VariableAggregation.SUM,
          variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
          variables_lib.VariableAggregation.NONE,
      ]))
  def testAggregation(self, aggregation):
    var = self.create_variable(aggregation=aggregation)
    self.assertEqual(var.aggregation, aggregation)

  @combinations.generate(combinations.combine(mode="graph"))
  def testEval(self):
    var = self.create_variable(1.)
    with self.cached_session():
      init_op = variables_lib.global_variables_initializer()
      self.evaluate(init_op)
      self.assertEqual(var.eval(), 1.)

  def testInitialValueEager(self):
    var = self.create_variable(1.)
    # Outside the graph-mode combination, reading initial_value must raise.
    with self.assertRaises(RuntimeError):
      var.initial_value  # pylint: disable=pointless-statement

  @combinations.generate(combinations.combine(mode="graph"))
  def testInitialValueGraph(self):
    var = self.create_variable(1.)
    initial = self.evaluate(var.initial_value)
    self.assertEqual(initial, 1.)

  def testConstraint(self):

    def add_one(x):
      return x + 1.

    var = self.create_variable(constraint=add_one)
    self.assertEqual(var.constraint(1.), 2.)

  def testDenseUpdate(self):
    var = self.create_variable(1.)

    # Each update is tried once returning the new value and once returning
    # nothing (read_value=False).
    assigned = var.assign(2., use_locking=True, name="assign", read_value=True)
    self.assertEqual(assigned, 2.)
    self.assertIsNone(var.assign(3., read_value=False))
    self.assertEqual(var, 3.)

    added = var.assign_add(
        1., use_locking=True, name="assign_add", read_value=True)
    self.assertEqual(added, 4.)
    self.assertIsNone(var.assign_add(1., read_value=False))
    self.assertEqual(var, 5.)

    subtracted = var.assign_sub(
        1., use_locking=True, name="assign_sub", read_value=True)
    self.assertEqual(subtracted, 4.)
    self.assertIsNone(var.assign_sub(1., read_value=False))
    self.assertEqual(var, 3.)

    @def_function.function
    def update_inside_function():
      # Inside a traced function, read_value=False yields an Operation.
      self.assertIsInstance(var.assign(1., read_value=False), ops.Operation)
      self.assertIsInstance(
          var.assign_add(1., read_value=False), ops.Operation)
      self.assertIsInstance(
          var.assign_sub(1., read_value=False), ops.Operation)

    update_inside_function()

  def testSparseUpdate(self):
    var = self.create_variable([0., 0., 0.])

    # The updates build on each other, so their order matters.
    after_add = var.scatter_add(
        _make_index_slices(values=[1., 2.], indices=[0, 2]),
        use_locking=True,
        name="add")
    self.assertAllEqual(after_add, [1., 0., 2.])

    after_div = var.scatter_div(
        _make_index_slices(values=[4., 2.], indices=[0, 2]),
        use_locking=True,
        name="div")
    self.assertAllEqual(after_div, [0.25, 0., 1.])

    after_max = var.scatter_max(
        _make_index_slices(values=[1., 0.5], indices=[1, 2]),
        use_locking=True,
        name="max")
    self.assertAllEqual(after_max, [0.25, 1., 1.])

    after_min = var.scatter_min(
        _make_index_slices(values=[1., 0.5], indices=[0, 1]),
        use_locking=True,
        name="min")
    self.assertAllEqual(after_min, [0.25, 0.5, 1.])

    after_mul = var.scatter_mul(
        _make_index_slices(values=[2., 0.5], indices=[0, 1]),
        use_locking=True,
        name="mul")
    self.assertAllEqual(after_mul, [0.5, 0.25, 1.])

    after_sub = var.scatter_sub(
        _make_index_slices(values=[2., 0.5], indices=[0, 1]),
        use_locking=True,
        name="sub")
    self.assertAllEqual(after_sub, [-1.5, -0.25, 1.])

    after_update = var.scatter_update(
        _make_index_slices(values=[2., 0.5], indices=[0, 1]),
        use_locking=True,
        name="update")
    self.assertAllEqual(after_update, [2., 0.5, 1.])

    after_batch_update = var.batch_scatter_update(
        _make_index_slices(values=[1., 1.5], indices=[0, 1]),
        use_locking=True,
        name="update")
    self.assertAllEqual(after_batch_update, [1., 1.5, 1.])

  def testSparseNdUpdate(self):
    var = self.create_variable([0., 0., 0., 0.])

    after_sub = var.scatter_nd_sub([[3], [1]], [1., 2.], name="sub")
    self.assertAllEqual(after_sub, [0., -2., 0., -1.])

    after_add = var.scatter_nd_add([[2], [0]], [1., 2.], name="add")
    self.assertAllEqual(after_add, [2., -2., 1., -1.])

    after_update = var.scatter_nd_update(
        [[1], [3]], [3., 3.], name="update")
    self.assertAllEqual(after_update, [2., 3., 1., 3.])

  def testSparseRead(self):
    var = self.create_variable([[1., 2.], [3., 4.]])

    rows = var.sparse_read([1, 0], name="read")
    self.assertAllEqual(rows, [[3., 4.], [1., 2.]])

    elements = var.gather_nd([[1, 0], [0, 1]], name="gather_nd")
    self.assertAllEqual(elements, [3., 2.])

  def testTensorConversion(self):
    var = self.create_variable([1.])
    as_tensor = ops.convert_to_tensor(var)
    self.assertEqual(as_tensor, [1.])

  def testHash(self):
    first = self.create_variable()
    second = self.create_variable()
    table = {}
    # The variable itself is rejected as a dict key; its ref() is accepted.
    with self.assertRaises(TypeError):
      table[first] = 1
    table[first.ref()] = 1
    # A freshly obtained ref() must find the entry stored under the earlier one.
    self.assertEqual(table[first.ref()], 1)
    self.assertNotIn(second.ref(), table)

  @combinations.generate(combinations.combine(mode="graph"))
  def testHashGraph(self):
    first = self.create_variable()
    second = self.create_variable()
    table = {first: 1}
    self.assertEqual(table[first], 1)
    self.assertNotIn(second, table)

  def testEquality(self):
    one = self.create_variable(1.)
    two = self.create_variable(2.)
    another_one = self.create_variable(1.)
    self.assertEqual(one, another_one)
    self.assertNotEqual(one, two)

  @combinations.generate(combinations.combine(mode="graph"))
  def testEqualityGraph(self):
    # In legacy graph mode, tensor equality is object equality
    first = self.create_variable(1.)
    second = self.create_variable(1.)
    self.assertNotEqual(first, second)
    self.assertEqual(first, first)

  def testIteration(self):
    var = self.create_variable([1.])
    self.assertEqual([1.], list(iter(var)))

  def testProperties(self):
    var = self.create_variable()
    self.assertIsInstance(var.name, str)
    # _shared_name belongs to the interface too: for example, the optimizer
    # uses it to determine the slot variable key.
    self.assertIsInstance(var._shared_name, str)
    self.assertIsNone(var.initializer)
    self.assertIsInstance(var.device, str)
    self.assertEqual(var.dtype, dtypes.float32)
    with self.assertRaises(AttributeError):
      var.op  # pylint: disable=pointless-statement
    with self.assertRaises(AttributeError):
      var.graph  # pylint: disable=pointless-statement

  @combinations.generate(combinations.combine(mode="graph"))
  def testPropertiesGraph(self):
    var = self.create_variable()
    self.assertIsInstance(var.initializer, ops.Operation)
    self.assertIsInstance(var.op, ops.Operation)
    self.assertIsInstance(var.graph, ops.Graph)

  def testProtoConversion(self):
    # to_proto and from_proto are not supported.
    var = self.create_variable([1, 2])
    with self.assertRaises(TypeError):
      var.to_proto()
    with self.assertRaises(TypeError):
      var.from_proto(variable_def=None)

  def testSaveSliceInfo(self):
    var = self.create_variable()
    info = variables_lib.Variable.SaveSliceInfo()
    var._set_save_slice_info(info)
    self.assertIs(var._get_save_slice_info(), info)
    # Some code accesses _save_slice_info directly without using the getter.
    self.assertIs(var._save_slice_info, info)

  def testOperatorOverride(self):
    var = self.create_variable(7)
    # Arithmetic operators, with the variable on either side.
    self.assertEqual(var + 1, 8)
    self.assertEqual(3 + var, 10)
    self.assertEqual(var + var, 14)
    self.assertEqual(var - 2, 5)
    self.assertEqual(13 - var, 6)
    self.assertEqual(var - var, 0)
    self.assertEqual(var * 2, 14)
    self.assertEqual(3 * var, 21)
    self.assertEqual(var * var, 49)
    self.assertEqual(var / 2, 3.5)
    self.assertEqual(14 / var, 2.)
    self.assertEqual(var // 2, 3)
    self.assertEqual(15 // var, 2)
    self.assertEqual(var % 2, 1)
    self.assertEqual(16 % var, 2)
    # Comparison operators.
    # pylint: disable=g-generic-assert
    self.assertTrue(var < 12)
    self.assertTrue(var <= 12)
    self.assertFalse(var > 12)
    self.assertFalse(var >= 12)
    self.assertFalse(12 < var)
    self.assertFalse(12 <= var)
    self.assertTrue(12 > var)
    self.assertTrue(12 >= var)
    # pylint: enable=g-generic-assert
    # Bitwise operators.
    self.assertEqual(var & 3, 3)
    self.assertEqual(11 & var, 3)
    self.assertEqual(var | 8, 15)
    self.assertEqual(16 | var, 23)
    self.assertEqual(var ^ 3, 4)
    self.assertEqual(11 ^ var, 12)
    # Power and unary operators.
    self.assertEqual(pow(var, 3), 343)
    # TODO(b/178748613): pow(v, 3, 10) fails.
    self.assertEqual(pow(2, var), 128)
    self.assertEqual(-var, -7)
    self.assertEqual(~var, ~7)
    self.assertEqual(abs(var), 7)

  def testSlice(self):
    var = self.create_variable([1., 2., 3.])
    self.assertEqual(var[1], 2.)
    # Assigning through a slice must be visible when reading the variable.
    var[2].assign(4.)
    self.assertAllEqual(var, [1., 2., 4.])

  # ==== End Variable interface ===

  # ==== Begin ResourceVariable interface ===
  def testHandle(self):
    var = self.create_variable()
    handle = var.handle
    self.assertIsInstance(handle, ops.Tensor)
    self.assertEqual(var.handle.dtype, dtypes.resource)

  def testInGraphMode(self):
    # This is protected but used in a lot of places internally.
    var = self.create_variable()
    self.assertFalse(var._in_graph_mode)

  def testUniqueId(self):
    # This is used in optimizer as part of slot variable key.
    first = self.create_variable()
    second = self.create_variable()
    self.assertNotEqual(first._unique_id, second._unique_id)

  def testIsResourceVariable(self):
    var = self.create_variable()
    self.assertTrue(resource_variable_ops.is_resource_variable(var))
  # ==== End ResourceVariable interface ===

  @combinations.generate(combinations.combine(mode="graph"))
  def testAsGraphElement(self):
    graph = ops.Graph()
    with graph.as_default():
      var = self.create_variable(1.)
      # Finalize first so that any attempt to add operations below fails.
      graph.finalize()
      self.evaluate(var.initializer)
      # _as_graph_element shouldn't create new operations.
      self.assertEqual(self.evaluate(var._as_graph_element()), 1.)
360
361
class DistributedVariableInterfaceTest(_VariableInterfaceTestBase):
  # Runs the inherited interface checks against values_v2.DistributedVariable.

  def create_variable(self, initial_value=1., **kwargs):
    # Build one component variable per entry in self.devices, each created
    # under that device's scope, and wrap them in a DistributedVariable.
    def _component_on(device_name):
      with ops.device(device_name):
        return variables_lib.Variable(initial_value, **kwargs)

    components = [_component_on(device_name) for device_name in self.devices]
    return values_v2.DistributedVariable(components)
371
372
# Prevent the base class from running: its create_variable raises
# NotImplementedError, so only the subclasses defined above are meant to be
# run. Deleting the module-level name keeps the base class from being picked up
# as a test case of its own.
del _VariableInterfaceTestBase
375
376
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.tpu_strategy,
            strategy_combinations.mirrored_strategy_with_two_cpus,
            strategy_combinations.mirrored_strategy_with_two_gpus,
        ],
        enable_packed_handle=[True, False],
        tf_function=[combinations.tf_function, combinations.no_tf_function]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):
  # Checks per-device reads and writes of values_v2.DistributedVariable for
  # each combination of strategy, packed-handle setting and function wrapper
  # listed in the decorator.

  def create_variable(self, strategy, initial_value, enable_packed_handle,
                      **kwargs):
    # One component variable is created under the scope of each parameter
    # device of `strategy`; extra kwargs are forwarded to every component.
    def _component_on(device_name):
      with ops.device(device_name):
        return variables_lib.Variable(initial_value, **kwargs)

    components = [
        _component_on(device_name)
        for device_name in strategy.extended.parameter_devices
    ]
    return values_v2.DistributedVariable(
        components, enable_packed_handle=enable_packed_handle)

  def assertReplica(self, distributed_var, values):
    # Compares the component variables with `values` position by position.
    for component, expected in zip(distributed_var._variables, values):
      self.assertAllEqual(component, expected)

  def testRead(self, strategy, enable_packed_handle, tf_function):
    dist_var = self.create_variable(strategy, 0., enable_packed_handle)
    param_devices = strategy.extended.parameter_devices

    # Give each of the first two components a distinct value.
    with ops.device(param_devices[0]):
      dist_var.assign(1.)
    with ops.device(param_devices[1]):
      dist_var.assign(2.)

    @tf_function
    def read_on_first_device():
      with ops.device(param_devices[0]):
        return dist_var.read_value(), dist_var.value()

    @tf_function
    def read_on_second_device():
      with ops.device(param_devices[1]):
        return dist_var.read_value(), dist_var.value()

    @tf_function
    def read_on_cpu0():
      with ops.device("CPU:0"):
        return dist_var.read_value(), dist_var.value()

    self.assertAllEqual(read_on_first_device(), [1., 1.])
    self.assertAllEqual(read_on_second_device(), [2., 2.])
    # Reading under "CPU:0" is expected to yield the value that was written
    # under the first parameter device.
    self.assertAllEqual(read_on_cpu0(), [1., 1.])

  def testAssign(self, strategy, enable_packed_handle, tf_function):
    dist_var = self.create_variable(strategy, 0., enable_packed_handle)
    param_devices = strategy.extended.parameter_devices

    @tf_function
    def assign_on_first_device():
      with ops.device(param_devices[0]):
        dist_var.assign(1.)

    @tf_function
    def assign_on_second_device():
      with ops.device(param_devices[1]):
        dist_var.assign(2.)

    assign_on_first_device()
    assign_on_second_device()
    self.assertReplica(dist_var, [1., 2.])

    with ops.device("CPU:0"):
      # Update the primary replica.
      dist_var.assign(3.)
      self.assertReplica(dist_var, [3., 2.])

  def testStrategyRun(self, strategy, enable_packed_handle, tf_function):
    on_tpu = test_util.is_tpu_strategy(strategy)
    runs_eagerly = tf_function is combinations.no_tf_function
    if on_tpu and runs_eagerly:
      self.skipTest("tpu doesn't support eager")
    dist_var = self.create_variable(strategy, 0., enable_packed_handle)

    @tf_function
    def write(per_replica):
      dist_var.assign(per_replica)

    @tf_function
    def read_back():
      return dist_var.read_value()

    per_replica_values = test_util.create_per_replica(strategy, [1., 2.])
    strategy.run(write, args=(per_replica_values,))
    self.assertReplica(dist_var, [1., 2.])

    gathered = test_util.gather(strategy, strategy.run(read_back))
    self.assertAllEqual(gathered, [1., 2.])
469
470
def _make_index_slices(values, indices, dense_shape=None):
  """Builds an `IndexedSlices` for the sparse-update tests.

  Args:
    values: Slice values; passed through `array_ops.identity`.
    indices: Slice indices; passed through `array_ops.identity`.
    dense_shape: Optional dense shape. When it is not `None` it is passed
      through `array_ops.identity` as well; otherwise `None` is forwarded.

  Returns:
    An `indexed_slices.IndexedSlices` built from the arguments above.
  """
  # Test against None explicitly: relying on truthiness would silently skip an
  # empty shape and is ambiguous for multi-element arrays.
  if dense_shape is not None:
    dense_shape = array_ops.identity(dense_shape)
  return indexed_slices.IndexedSlices(
      array_ops.identity(values), array_ops.identity(indices), dense_shape)
476
477
# Run the tests through test_util.main() when this file is executed directly.
if __name__ == "__main__":
  test_util.main()
480