# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for PackedDistributedVariable and PackedVarAndDevice."""

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import packed_distributed_variable
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test


class PackedDistributedVariableTest(test.TestCase):
  """Tests packing per-device ResourceVariables into one distributed variable."""

  def setUp(self):
    super(PackedDistributedVariableTest, self).setUp()
    cpus = config.list_physical_devices('CPU')
    # Split the first physical CPU into two logical CPUs so the tests can
    # place component variables on '/cpu:0' and '/cpu:1'.
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])

  def testPackedVariable(self):
    """Reads and updates a packed variable eagerly and inside a tf.function."""
    with ops.device('/cpu:0'):
      v0 = resource_variable_ops.ResourceVariable(1.0, name='var0')
    with ops.device('/cpu:1'):
      v1 = resource_variable_ops.ResourceVariable(2.0, name='var1')

    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
    # Outside of a tf.function the handle is expected not to be packed.
    self.assertFalse(packed_var.handle.is_packed)
    # `is_initialized` is a method: asserting on the bound method itself is
    # always truthy, so call it to actually check initialization.
    self.assertTrue(packed_var.is_initialized())

    # Under each device scope the packed variable resolves to the component
    # variable placed on that device.
    with ops.device('/cpu:0'):
      self.assertAllEqual(packed_var.get_var_on_current_device(), v0)
      val0 = packed_var.assign(2.0).assign_add(1.0)
      self.assertAllEqual(val0, 3.0)

    with ops.device('/cpu:1'):
      self.assertAllEqual(packed_var.get_var_on_current_device(), v1)
      val1 = packed_var.assign(2.0).assign_add(1.0)
      self.assertAllEqual(val1, 3.0)

    @def_function.function
    def update_var():
      # Inside a tf.function the handle is expected to be packed.
      self.assertTrue(packed_var.handle.is_packed)
      with ops.device('/cpu:0'):
        packed_var.assign_add(3.0).assign_sub(1.0)  # 3 + 3 - 1 = 5
        read0 = packed_var.value()
      with ops.device('/cpu:1'):
        packed_var.assign_sub(4.0).assign_sub(2.0)  # 3 - 4 - 2 = -3
        read1 = packed_var.value()

      return read0, read1

    self.assertAllEqual(update_var(), (5.0, -3.0))

  def testPackedVarAndDevice(self):
    """Checks PackedVarAndDevice wrappers that bind a packed var to a device."""
    device0 = device_util.canonicalize('/cpu:0')
    device1 = device_util.canonicalize('/cpu:1')

    with ops.device(device0):
      v0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device(device1):
      v1 = resource_variable_ops.ResourceVariable(2.0)

    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])

    packed_var0 = packed_distributed_variable.PackedVarAndDevice(
        packed_var, device0)
    self.assertFalse(packed_var0.handle.is_packed)
    # The wrapper can be used directly as a tensor operand: v0 (1.0) * 2.0.
    self.assertAllEqual(math_ops.mul(packed_var0, 2.0), 2.0)

    packed_var1 = packed_distributed_variable.PackedVarAndDevice(
        packed_var, device1)
    # Overwrites the component on device1 (was 2.0).
    self.assertAllEqual(packed_var1.assign(3.0), 3.0)

    @def_function.function
    def func():
      # Inside a tf.function the handle is expected to be packed.
      self.assertTrue(packed_var.handle.is_packed)
      var0 = packed_distributed_variable.PackedVarAndDevice(packed_var, device0)
      var0.assign_add(3.0)  # 1 + 3 = 4
      var1 = packed_distributed_variable.PackedVarAndDevice(packed_var, device1)
      return var0.value(), math_ops.add(var1, 2.0)  # (4, 3 + 2)

    self.assertAllEqual(func(), (4.0, 5.0))

  @test_util.assert_no_garbage_created
  def testNoGarbage(self):
    """Attribute lookups on a per-device wrapper must not leave garbage."""
    device0 = device_util.canonicalize('/cpu:0')
    device1 = device_util.canonicalize('/cpu:1')

    with ops.device(device0):
      v0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device(device1):
      v1 = resource_variable_ops.ResourceVariable(2.0)

    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
    # Probing a nonexistent attribute exercises a path that needs a
    # workaround in the implementation to avoid creating reference cycles;
    # the decorator fails the test if any garbage is created.
    hasattr(packed_var.on_device('/cpu:0'), 'nonexist')


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()