xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/per_replica_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the distributed values library."""
16
17from absl.testing import parameterized
18
19from tensorflow.python.distribute import combinations
20from tensorflow.python.distribute import test_util as ds_test_util
21from tensorflow.python.distribute import values as values_lib
22from tensorflow.python.eager import context
23from tensorflow.python.eager import def_function
24from tensorflow.python.eager import test
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.framework import test_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.util import nest
32
33
class PerReplicaTest(test.TestCase, parameterized.TestCase):
  """Tests for `values_lib.PerReplica` and its `TypeSpec`."""

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpec(self):
    """A scalar float32 value yields a scalar float32 value spec."""
    per_replica = values_lib.PerReplica((constant_op.constant(1.),))
    expected_specs = (tensor_spec.TensorSpec([], dtypes.float32),)
    self.assertEqual(per_replica._type_spec._value_specs, expected_specs)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecRoundTrip(self):
    """Splitting into components and rebuilding keeps the same values."""
    original = values_lib.PerReplica((constant_op.constant(1.),))
    type_spec = original._type_spec

    components = type_spec._to_components(original)
    rebuilt = type_spec._from_components(components)

    self.assertAllEqual(original.values, rebuilt.values)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecNest(self):
    """`nest.map_structure` can rewrite every component of a PerReplica."""
    per_replica = values_lib.PerReplica(
        (constant_op.constant(1.), constant_op.constant([5., 6.0])))

    def add_ten(component):
      return component + 10

    # With expand_composites=True, map_structure goes through both
    # nest.flatten and nest.pack_sequence_as.
    shifted = nest.map_structure(add_ten, per_replica, expand_composites=True)

    self.assertLen(shifted.values, 2)
    first, second = shifted.values
    self.assertAllEqual(first, 11.)
    self.assertAllEqual(second, [15., 16.0])

  @test_util.run_in_graph_and_eager_modes
  def testIsGraphTensor(self):
    """Flattened components carry a `graph` attribute only outside eager."""
    per_replica = values_lib.PerReplica((constant_op.constant(1.),))
    expect_graph_attr = not context.executing_eagerly()
    for component in nest.flatten(per_replica, expand_composites=True):
      self.assertEqual(hasattr(component, "graph"), expect_graph_attr)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testDoesNotTriggerFunctionTracing(self):
    """Rebuilding a PerReplica from its spec does not cause a retrace."""
    trace_log = []

    @def_function.function
    def identity(x):
      trace_log.append(None)  # Python side effect: only runs while tracing.
      return x

    per_replica = values_lib.PerReplica((constant_op.constant(1.),))

    # The very first call has to trace.
    identity(per_replica)
    self.assertNotEmpty(trace_log)
    trace_log.clear()

    spec = per_replica._type_spec
    for _ in range(5):
      doubled = [piece * 2 for piece in spec._to_components(per_replica)]
      per_replica = spec._from_components(doubled)

      result = identity(per_replica)
      self.assertIsInstance(result, values_lib.PerReplica)
      self.assertAllEqual(result._values, per_replica._values)
      self.assertEmpty(trace_log)  # `identity` must not have been retraced.

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testFunctionCanReturnPerReplica(self):
    """A tf.function can take a PerReplica and hand it back."""
    passthrough = def_function.function(lambda value: value)
    inputs = values_lib.PerReplica((constant_op.constant(1.),))
    outputs = passthrough(inputs)

    # The returned object is a new PerReplica with equal parts and spec.
    self.assertIsNot(inputs, outputs)
    nest.map_structure(
        self.assertAllEqual, inputs, outputs, expand_composites=True)
    self.assertEqual(inputs._type_spec, outputs._type_spec)

  def _assertCondReturnsTrueBranch(self, when_true, when_false, expected):
    """Runs `cond` on a default-True predicate and checks `when_true` wins."""
    predicate = array_ops.placeholder_with_default(True, [])
    selected = control_flow_ops.cond(
        predicate, lambda: when_true, lambda: when_false)
    self.assertLen(selected.values, 1)
    self.assertAllEqual(selected.values[0], expected)

  @test_util.run_in_graph_and_eager_modes
  def testCondWithTensorValues(self):
    """`cond` picks between PerReplicas of differently shaped tensors."""
    self._assertCondReturnsTrueBranch(
        values_lib.PerReplica((constant_op.constant("a"),)),
        values_lib.PerReplica((constant_op.constant(["b", "c"]),)),
        "a")

  @test_util.run_in_graph_and_eager_modes
  def testCondWithValuesConvertibleToTensor(self):
    """`cond` accepts PerReplicas holding plain tensor-convertible values."""
    self._assertCondReturnsTrueBranch(
        values_lib.PerReplica(("a",)),
        values_lib.PerReplica(("b",)),
        "a")

  @test_util.build_as_function_and_v1_graph
  def testCondWithValuesNotConvertibleToTensor(self):
    """`cond` rejects PerReplicas whose values cannot produce a TypeSpec."""
    when_true = values_lib.PerReplica(({"a"},))
    when_false = values_lib.PerReplica(({"b", "c"},))
    predicate = array_ops.placeholder(dtypes.bool, [])

    with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
      control_flow_ops.cond(predicate, lambda: when_true, lambda: when_false)
145
if __name__ == "__main__":
  # Run through the distribute package's test entry point (`ds_test_util`)
  # rather than `test.main()`. NOTE(review): presumably this adds the
  # multi-device / multi-process setup the distribute tests rely on --
  # confirm against tensorflow/python/distribute/test_util.py.
  ds_test_util.main()
148