# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utility functions in distribute_utils."""

import collections
import collections.abc

from absl.testing import parameterized
import wrapt

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.saved_model.model_utils import mode_keys


def _nested_value(d):
  # Nested (tuple/list/dict) structure whose leaves carry the suffix `d`, so
  # tests can tell which replica each value came from.
  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)


class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):

  def _is_per_replica(self, result, expected, klass=values.PerReplica):
    self.assertIsInstance(result, klass)
    for i, exp in enumerate(expected):
      self.assertEqual(exp, result.values[i])

  def testNested(self):
    result = distribute_utils.regroup((_nested_value("1"), _nested_value("2")))
    self.assertIsInstance(result, tuple)
    self.assertLen(result, 3)
    self._is_per_replica(result[0], ["a1", "a2"])
    self._is_per_replica(result[2], ["h1", "h2"])

    self.assertIsInstance(result[1], list)
    self.assertLen(result[1], 3)
    self._is_per_replica(result[1][0], ["b1", "b2"])
    self._is_per_replica(result[1][2], ["g1", "g2"])

    self.assertIsInstance(result[1][1], dict)
    self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
    self._is_per_replica(result[1][1]["c"], ["d1", "d2"])
    self._is_per_replica(result[1][1]["e"], ["f1", "f2"])

    # Also test that we can undo the merge using select_replica().
    self.assertEqual(_nested_value("1"),
                     distribute_utils.select_replica(0, result))
    self.assertEqual(_nested_value("2"),
                     distribute_utils.select_replica(1, result))
    # select_replica_mirrored() should fail due to non-mirrored values.
    with self.assertRaises(TypeError):
      distribute_utils.select_replica_mirrored(0, result)
    with self.assertRaises(TypeError):
      distribute_utils.select_replica_mirrored(1, result)

  def testRegroupKeepsDictBasedClass(self):
    class DictBasedClass(dict):
      """Dummy class inherited from a dict."""

    result = distribute_utils.regroup(
        (DictBasedClass(a="a1", b="b1"), DictBasedClass(a="a2", b="b2")))
    self.assertIsInstance(result, DictBasedClass)
    self._is_per_replica(result["a"], ["a1", "a2"])
    self._is_per_replica(result["b"], ["b1", "b2"])
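
  def testRegroupDictRoundTrip(self):
    # Illustrative sketch, not part of the original suite: shows the
    # regroup()/select_replica() round trip on flat dicts, mirroring what
    # testNested verifies for nested structures. The method name and the
    # example inputs here are assumptions chosen for illustration only.
    result = distribute_utils.regroup(({"x": "x1"}, {"x": "x2"}))
    self._is_per_replica(result["x"], ["x1", "x2"])
    self.assertEqual({"x": "x1"}, distribute_utils.select_replica(0, result))
    self.assertEqual({"x": "x2"}, distribute_utils.select_replica(1, result))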

  def testRegroupCollectionsMapping(self):

    class CollectionsMappingBasedClass(collections.abc.Mapping):
      """Class inherited from collections.abc.Mapping."""

      def __init__(self, *args, **kwargs):
        self._d = dict(*args, **kwargs)

      def __getitem__(self, key):
        return self._d.__getitem__(key)

      def __iter__(self):
        return iter(self._d)

      def __len__(self):
        return len(self._d)

    result = distribute_utils.regroup(
        (CollectionsMappingBasedClass(a="a1", b="b1"),
         CollectionsMappingBasedClass(a="a2", b="b2")))
    self.assertIsInstance(result, CollectionsMappingBasedClass)
    self._is_per_replica(result["a"], ["a1", "a2"])
    self._is_per_replica(result["b"], ["b1", "b2"])

  def testWrapClass(self):
    # Normally a mirrored value would be the same across devices, but
    # for a test it is convenient to be able to tell the values apart.
    result = distribute_utils.regroup((_nested_value("1"), _nested_value("2")),
                                      values.Mirrored)
    self.assertIsInstance(result, tuple)
    self.assertLen(result, 3)
    self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
    self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)

    self.assertIsInstance(result[1], list)
    self.assertLen(result[1], 3)
    self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
    self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)

    self.assertIsInstance(result[1][1], dict)
    self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
    self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
    self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)

    # Also test that we can undo the merge using select_replica().
    self.assertEqual(_nested_value("1"),
                     distribute_utils.select_replica(0, result))
    self.assertEqual(_nested_value("2"),
                     distribute_utils.select_replica(1, result))
    # Values are marked as mirrored, so select_replica_mirrored() is allowed.
    self.assertEqual(_nested_value("1"),
                     distribute_utils.select_replica_mirrored(0, result))
    self.assertEqual(_nested_value("2"),
                     distribute_utils.select_replica_mirrored(1, result))

  def testWrapAListOfTwoTuples(self):
    result = distribute_utils.regroup([("1", "2"), ("3", "4")])
    self.assertIsInstance(result, tuple)
    self.assertLen(result, 2)
    self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
    self._is_per_replica(result[1], ("2", "4"), values.PerReplica)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.mirrored_strategy_with_one_cpu,
          ],
          mode=["graph", "eager"],
      ))
  def testMirroredContainer(self, distribution):
    with distribution.scope():
      v = variable_scope.variable(
          1., aggregation=variable_scope.VariableAggregation.SUM)
    self.assertTrue(distribute_utils.is_distributed_variable(v))
    self.assertTrue(distribute_utils.is_distributed_variable(
        distribute_utils.regroup(v.values)))

  def testSameId(self):
    foo = object()
    result = distribute_utils.regroup((("a", foo), ("b", foo)))
    self.assertIsInstance(result, tuple)
    self.assertLen(result, 2)
    self._is_per_replica(result[0], ["a", "b"])
    self.assertIs(foo, result[1])

    # Test select_replica(), which should undo the merge done by regroup().
    result_0 = distribute_utils.select_replica(0, result)
    self.assertIsInstance(result_0, tuple)
    self.assertLen(result_0, 2)
    self.assertEqual("a", result_0[0])
    self.assertIs(foo, result_0[1])
    result_1 = distribute_utils.select_replica(1, result)
    self.assertIsInstance(result_1, tuple)
    self.assertLen(result_1, 2)
    self.assertEqual("b", result_1[0])
    self.assertIs(foo, result_1[1])

  def testOneDevice(self):
    result = distribute_utils.regroup((_nested_value("1"),))
    # On one device regroup() and select_replica() are basically identity.
    self.assertEqual(_nested_value("1"), result)
    self.assertEqual(_nested_value("1"),
                     distribute_utils.select_replica(0, result))

  def testNamedTuple(self):

    # We include toy implementations of Scaffold and EstimatorSpec to
    # avoid a dependency on Estimator here.

    class Scaffold(object):
      pass

    class EstimatorSpec(collections.namedtuple(
        "EstimatorSpec", ["mode", "loss", "train_op", "scaffold"])):

      def __new__(cls, mode, loss, train_op, scaffold=None):
        return super(EstimatorSpec, cls).__new__(
            cls, mode=mode, loss=loss, train_op=train_op,
            scaffold=scaffold or Scaffold())

    with context.graph_mode(), ops.Graph().as_default():
      created_estimator_specs = []

      for device_id in range(3):
        spec = EstimatorSpec(
            mode=mode_keys.EstimatorModeKeys.TRAIN,
            loss=constant_op.constant(device_id / 2),
            train_op=array_ops.identity(constant_op.constant(device_id)))
        created_estimator_specs.append(spec)

      merged_estimator_spec = distribute_utils.regroup(created_estimator_specs)

      self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
      self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
                       merged_estimator_spec.mode)
      for device_id in range(3):
        self.assertEqual(created_estimator_specs[device_id].loss,
                         merged_estimator_spec.loss.values[device_id])
        self.assertEqual(created_estimator_specs[device_id].train_op,
                         merged_estimator_spec.train_op.values[device_id])
        # Scaffold is populated by `EstimatorSpec.__new__`.
        self.assertEqual(created_estimator_specs[device_id].scaffold,
                         merged_estimator_spec.scaffold.values[device_id])
        self.assertIsInstance(created_estimator_specs[device_id].scaffold,
                              Scaffold)
        # Also test that we can undo the merge using select_replica().
        self.assertEqual(created_estimator_specs[device_id],
                         distribute_utils.select_replica(
                             device_id, merged_estimator_spec))

  def testWrappedNamedTuple(self):
    Point = collections.namedtuple("Point", ["x", "y"])
    point1 = Point(x=0, y=2)
    point2 = Point(x=1, y=3)
    wrapped1 = wrapt.ObjectProxy(point1)
    wrapped2 = wrapt.ObjectProxy(point2)
    result = distribute_utils.regroup([wrapped1, wrapped2])
    self.assertEqual(result.x.values, (0, 1))
    self.assertEqual(result.y.values, (2, 3))


if __name__ == "__main__":
  test.main()