# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import nest


class PerReplicaTest(test.TestCase, parameterized.TestCase):
  """Tests `values_lib.PerReplica`.

  Covers its type spec, `nest` integration, interaction with `tf.function`
  tracing, and use as a `cond` branch output.
  """

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpec(self):
    """A PerReplica of one scalar float has one scalar float32 value spec."""
    expected_value_specs = (tensor_spec.TensorSpec([], dtypes.float32),)
    per_replica = values_lib.PerReplica((constant_op.constant(1.),))

    self.assertEqual(per_replica._type_spec._value_specs,
                     expected_value_specs)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecRoundTrip(self):
    """Converting to components and back preserves the replica values."""
    original = values_lib.PerReplica((constant_op.constant(1.),))
    spec = original._type_spec

    rebuilt = spec._from_components(spec._to_components(original))

    self.assertAllEqual(original.values, rebuilt.values)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecNest(self):
    """`nest.map_structure` with expand_composites maps every value."""
    per_replica = values_lib.PerReplica(
        (constant_op.constant(1.), constant_op.constant([5., 6.0])))

    # nest.map_structure goes through both nest.flatten and
    # nest.pack_sequence_as, so this checks both directions at once.
    shifted = nest.map_structure(
        lambda tensor: tensor + 10, per_replica, expand_composites=True)

    self.assertLen(shifted.values, 2)
    first, second = shifted.values
    self.assertAllEqual(first, 11.)
    self.assertAllEqual(second, [15., 16.0])

  @test_util.run_in_graph_and_eager_modes
  def testIsGraphTensor(self):
    """Flattened components have a `graph` attribute only when not eager."""
    expect_graph_attr = not context.executing_eagerly()
    per_replica = values_lib.PerReplica((constant_op.constant(1.),))

    for component in nest.flatten(per_replica, expand_composites=True):
      self.assertEqual(hasattr(component, "graph"), expect_graph_attr)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testDoesNotTriggerFunctionTracing(self):
    """Rebuilding a PerReplica from its components reuses the first trace."""
    trace_log = []

    @def_function.function
    def identity(x):
      # Python side effect: executes only while the function is traced.
      trace_log.append(None)
      return x

    per_replica = values_lib.PerReplica((constant_op.constant(1.),))

    # Trace once, then reset the log so later retraces would be visible.
    identity(per_replica)
    self.assertNotEmpty(trace_log)
    trace_log.clear()

    spec = per_replica._type_spec
    for _ in range(5):
      components = spec._to_components(per_replica)
      per_replica = spec._from_components([c * 2 for c in components])

      result = identity(per_replica)
      self.assertIsInstance(result, values_lib.PerReplica)
      self.assertAllEqual(result._values, per_replica._values)
      self.assertEmpty(trace_log)  # `identity` must not have been retraced.

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testFunctionCanReturnPerReplica(self):
    """A tf.function can return a new PerReplica with an equal type spec."""
    passthrough = def_function.function(lambda value: value)
    given = values_lib.PerReplica((constant_op.constant(1.),))

    returned = passthrough(given)

    self.assertIsNot(given, returned)
    nest.map_structure(
        self.assertAllEqual, given, returned, expand_composites=True)
    self.assertEqual(given._type_spec, returned._type_spec)

  @test_util.run_in_graph_and_eager_modes
  def testCondWithTensorValues(self):
    """`cond` can select between PerReplica values that hold tensors."""
    # The branches differ in shape: a scalar vs. a length-2 vector.
    true_value = values_lib.PerReplica((constant_op.constant("a"),))
    false_value = values_lib.PerReplica((constant_op.constant(["b", "c"]),))
    pred = array_ops.placeholder_with_default(True, [])

    selected = control_flow_ops.cond(
        pred, lambda: true_value, lambda: false_value)

    self.assertLen(selected.values, 1)
    self.assertAllEqual(selected.values[0], "a")

  @test_util.run_in_graph_and_eager_modes
  def testCondWithValuesConvertibleToTensor(self):
    """`cond` accepts PerReplica values holding plain Python strings."""
    true_value = values_lib.PerReplica(("a",))
    false_value = values_lib.PerReplica(("b",))
    pred = array_ops.placeholder_with_default(True, [])

    selected = control_flow_ops.cond(
        pred, lambda: true_value, lambda: false_value)

    self.assertLen(selected.values, 1)
    self.assertAllEqual(selected.values[0], "a")

  @test_util.build_as_function_and_v1_graph
  def testCondWithValuesNotConvertibleToTensor(self):
    """`cond` raises TypeError when the values (Python sets) lack a TypeSpec."""
    true_value = values_lib.PerReplica(({"a"},))
    false_value = values_lib.PerReplica(({"b", "c"},))
    pred = array_ops.placeholder(dtypes.bool, [])

    with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
      control_flow_ops.cond(pred, lambda: true_value, lambda: false_value)


if __name__ == "__main__":
  ds_test_util.main()