xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/packed_distributed_variable_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from tensorflow.python.distribute import device_util
17from tensorflow.python.distribute import packed_distributed_variable
18from tensorflow.python.eager import context
19from tensorflow.python.eager import def_function
20from tensorflow.python.framework import config
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import resource_variable_ops
25from tensorflow.python.platform import test
26
27
class PackedDistributedVariableTest(test.TestCase):
  """Tests for `PackedDistributedVariable` and `PackedVarAndDevice`."""

  def setUp(self):
    super(PackedDistributedVariableTest, self).setUp()
    cpus = config.list_physical_devices('CPU')
    # Split the first physical CPU into two logical devices so the tests can
    # place one component variable on '/cpu:0' and another on '/cpu:1'.
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])

  def testPackedVariable(self):
    """Reads and updates a packed variable eagerly and inside a function."""
    with ops.device('/cpu:0'):
      v0 = resource_variable_ops.ResourceVariable(1.0, name='var0')
    with ops.device('/cpu:1'):
      v1 = resource_variable_ops.ResourceVariable(2.0, name='var1')

    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
    # In eager mode the handle is not packed; inside a tf.function it is
    # (asserted in `update_var` below).
    self.assertFalse(packed_var.handle.is_packed)
    self.assertTrue(packed_var.is_initialized)

    # Eagerly, the active device scope selects which component is used.
    with ops.device('/cpu:0'):
      self.assertAllEqual(packed_var.get_var_on_current_device(), v0)
      val0 = packed_var.assign(2.0).assign_add(1.0)
      self.assertAllEqual(val0, 3.0)

    with ops.device('/cpu:1'):
      self.assertAllEqual(packed_var.get_var_on_current_device(), v1)
      val1 = packed_var.assign(2.0).assign_add(1.0)
      self.assertAllEqual(val1, 3.0)

    @def_function.function
    def update_var():
      # Plain Python assertion: evaluated while the function is traced.
      self.assertTrue(packed_var.handle.is_packed)
      with ops.device('/cpu:0'):
        packed_var.assign_add(3.0).assign_sub(1.0)  # 3 + 3 - 1 = 5
        read0 = packed_var.value()
      with ops.device('/cpu:1'):
        packed_var.assign_sub(4.0).assign_sub(2.0)  # 3 - 4 - 2 = -3
        read1 = packed_var.value()

      return read0, read1

    self.assertAllEqual(update_var(), (5.0, -3.0))

  def testPackedVarAndDevice(self):
    """Pins a packed variable to a fixed device via `PackedVarAndDevice`."""
    device0 = device_util.canonicalize('/cpu:0')
    device1 = device_util.canonicalize('/cpu:1')

    with ops.device(device0):
      v0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device(device1):
      v1 = resource_variable_ops.ResourceVariable(2.0)

    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])

    packed_var0 = packed_distributed_variable.PackedVarAndDevice(
        packed_var, device0)
    self.assertFalse(packed_var0.handle.is_packed)
    # The wrapper is accepted directly as an op input: 1 * 2 = 2.
    self.assertAllEqual(math_ops.mul(packed_var0, 2.0), 2.0)

    packed_var1 = packed_distributed_variable.PackedVarAndDevice(
        packed_var, device1)
    self.assertAllEqual(packed_var1.assign(3.0), 3.0)

    @def_function.function
    def func():
      # Plain Python assertion: evaluated while the function is traced.
      self.assertTrue(packed_var.handle.is_packed)
      var0 = packed_distributed_variable.PackedVarAndDevice(packed_var, device0)
      var0.assign_add(3.0)  # 1 + 3 = 4
      var1 = packed_distributed_variable.PackedVarAndDevice(packed_var, device1)
      return var0.value(), math_ops.add(var1, 2.0)  # 3 + 2 = 5

    self.assertAllEqual(func(), (4.0, 5.0))

  @test_util.assert_no_garbage_created
  def testNoGarbage(self):
    """Checks that a failed attribute lookup leaves no garbage behind."""
    device0 = device_util.canonicalize('/cpu:0')
    device1 = device_util.canonicalize('/cpu:1')

    with ops.device(device0):
      v0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device(device1):
      v1 = resource_variable_ops.ResourceVariable(2.0)

    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
    # Looking up a missing attribute on the object returned by `on_device`
    # raises an AttributeError, which `hasattr` swallows. The implementation
    # needs a workaround so that this failure path does not create reference
    # cycles; `assert_no_garbage_created` fails the test if it does.
    hasattr(packed_var.on_device('/cpu:0'), 'nonexist')
117
118
if __name__ == '__main__':
  # Enable eager execution before starting the test runner so the tests run
  # eagerly by default (tf.function bodies inside them are still traced).
  ops.enable_eager_execution()
  test.main()
121  test.main()
122