1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests tf.random.Generator with distribution strategies.""" 16 17import functools 18import os 19 20from absl.testing import parameterized 21from tensorflow.python.checkpoint import checkpoint as tracking_util 22from tensorflow.python.compat import v2_compat 23from tensorflow.python.distribute import combinations as ds_combinations 24from tensorflow.python.distribute import multi_process_runner 25from tensorflow.python.distribute import sharded_variable 26from tensorflow.python.distribute import strategy_combinations 27from tensorflow.python.distribute import values as dist_values 28from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib 29from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy 30from tensorflow.python.eager import def_function 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import test_combinations as combinations 34from tensorflow.python.framework import test_util 35from tensorflow.python.module import module 36from tensorflow.python.ops import array_ops 37from tensorflow.python.ops import stateful_random_ops as rng 38from tensorflow.python.platform import test 39from tensorflow.python.saved_model import load 
from tensorflow.python.saved_model import save
from tensorflow.python.util import deprecation


def get_num_local_replicas(strat, values=None):
  """Returns the number of replicas of `strat` that are local to this process.

  For multi-worker strategies (class name containing "MultiWorker" or
  "CollectiveAllReduceStrategy") the count is the number of local results of
  a distributed value. For every other strategy it is
  `strat.num_replicas_in_sync`.

  Args:
    strat: the distribution strategy to inspect.
    values: optional distributed value whose local results are counted for
      multi-worker strategies. When None, a replica function returning a
      constant is run on `strat` to produce one.

  Returns:
    The number of local replicas.
  """
  class_name = type(strat).__name__
  multi_worker = ("MultiWorker" in class_name or
                  "CollectiveAllReduceStrategy" in class_name)
  if not multi_worker:
    return strat.num_replicas_in_sync
  if values is None:
    values = strat.run(lambda: constant_op.constant(0))
  return len(strat.experimental_local_results(values))


# Parameter-server strategies: 1 or 3 workers, 2 PS, CPU-only or with 1 GPU.
ps_strategies = [
    strategy_combinations.parameter_server_strategy_3worker_2ps_cpu,
    strategy_combinations.parameter_server_strategy_1worker_2ps_cpu,
    strategy_combinations.parameter_server_strategy_3worker_2ps_1gpu,
    strategy_combinations.parameter_server_strategy_1worker_2ps_1gpu,
]
# Every strategy the generic tests below are parameterized over.
all_strategies = (
    strategy_combinations.all_strategies
    + strategy_combinations.multiworker_strategies
    + ps_strategies)


def run_on_strategy(replica_fn, strat, coord):
  """Runs `replica_fn` on the replicas of `strat` and returns the result.

  Args:
    replica_fn: the function handed to `strat.run`.
    strat: the distribution strategy to run on.
    coord: a `ClusterCoordinator` (used by the parameter-server tests) or
      None. When given, the `strat.run` call is wrapped in a `tf.function`,
      scheduled through the coordinator, and its result fetched.

  Returns:
    The value returned by `strat.run`, fetched from the coordinator if any.
  """
  def distributed_fn():
    return strat.run(replica_fn)

  if coord is None:
    return distributed_fn()
  scheduled = coord.schedule(def_function.function(distributed_fn))
  return scheduled.fetch()


class GeneratorTest(test.TestCase, parameterized.TestCase):
  """Tests `tf.random.Generator` under various distribution strategies."""

  def setUp(self):
    super().setUp()
    v2_compat.enable_v2_behavior()

  def assertAllDifferent(self, tensors):
    """Asserts that no element value repeats anywhere across `tensors`.

    Args:
      tensors: a list of tensors, possibly of different shapes. They are
        flattened and concatenated before duplicates are looked for.
    """
    flat = array_ops.concat(
        [array_ops.reshape(t, shape=[-1]) for t in tensors], axis=0)
    elements = self.evaluate(flat).tolist()
    self.assertAllEqual(len(elements), len(set(elements)))

  @test_util.run_v2_only
  def testCreateOutsideMirroredStrat(self):
    """Tests RNG/MirroredStrategy interaction #1.

    A generator created outside any distribution-strategy scope is a single
    object shared by every replica, and the replicas' accesses to it are
    serialized.
    """
    gen = rng.Generator.from_seed(1234)
    strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
    shape, dtype = [3, 4], dtypes.int32
    with strat.scope():

      def draw_pair():
        first = gen.uniform_full_int(shape=shape, dtype=dtype)
        second = gen.uniform_full_int(shape=shape, dtype=dtype)
        return array_ops.stack([first, second])

      per_replica = strat.extended.call_for_each_replica(fn=draw_pair)
      replica_values = per_replica.values
      self.assertAllEqual(2, len(replica_values))
      self.assertAllDifferent(replica_values)

  @test_util.run_v2_only
  def testMirroredStratParaAsync(self):
    """Tests RNG/MirroredStrategy interaction #2.

    Outside `strategy.scope()` the user can split off one independent
    generator per replica (n generators for n replicas) and hand each
    replica its own, so every replica draws from a different random-number
    stream.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    split_gens = rng.get_global_generator().split(count=2)
    strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
    # `PerReplica` routes the i-th generator to the i-th replica.
    per_replica_args = dist_values.PerReplica([[g] for g in split_gens])
    with strat.scope():

      def draw_pair(gen):
        first = gen.uniform_full_int(shape=shape, dtype=dtype)
        second = gen.uniform_full_int(shape=shape, dtype=dtype)
        return array_ops.stack([first, second])

      per_replica = strat.extended.call_for_each_replica(
          fn=draw_pair, args=per_replica_args)
      local_results = strat.experimental_local_results(per_replica)
      self.assertAllEqual(2, len(local_results))
      self.assertAllDifferent(local_results)

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"]))
  def testCrossReplica(self, strat):
    """Tests that RNG can be properly advanced in cross-replica context."""

    def local_states(dist_var):
      return [v.read_value()
              for v in strat.experimental_local_results(dist_var)]

    with strat.scope():
      g = rng.Generator.from_seed(1)
      before = local_states(g.state)
      g.normal([3])
      g.skip(4)
      after = local_states(g.state)
    self.assertNotAllEqual(before[0], after[0])
    self.assertEqual(len(before), len(after))
    # Every local copy of the state must agree with the first copy.
    for state_before, state_after in zip(before[1:], after[1:]):
      self.assertAllEqual(before[0], state_before)
      self.assertAllEqual(after[0], state_after)

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          jit_replica_fn=[False, True],
          seeded=[True, False],))
  def testDistStrat(self, strat, jit_replica_fn, seeded):
    """Tests RNG with distribution strategies.

    A generator (seeded or non-deterministic) is created in `strat.scope()`;
    each replica draws two int32 [3, 4] tensors from it, and all values
    across the local replicas must be distinct.
    """
    strat_name = type(strat).__name__
    if "TPU" in strat_name and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")
    coord = (coordinator_lib.ClusterCoordinator(strat)
             if "ParameterServer" in strat_name else None)
    creator = {
        True: functools.partial(rng.Generator.from_seed, 1234),
        False: rng.Generator.from_non_deterministic_state,
    }[seeded]
    shape = [3, 4]
    dtype = dtypes.int32
    with strat.scope():
      gen = creator()

      def draw_pair():
        first = gen.uniform_full_int(shape=shape, dtype=dtype)
        second = gen.uniform_full_int(shape=shape, dtype=dtype)
        return array_ops.stack([first, second])

      replica_fn = (def_function.function(draw_pair) if jit_replica_fn
                    else draw_pair)
      results = run_on_strategy(replica_fn, strat, coord)
      local_values = strat.experimental_local_results(results)
      num_local = get_num_local_replicas(strat, local_values)
      self.assertAllEqual(num_local, len(local_values))
      self.assertAllDifferent(local_values)

  @ds_combinations.generate(
      combinations.combine(
          strat=[
              strategy_combinations.parameter_server_strategy_fn(
                  "ParameterServer1Worker2PSCPUFixedShards",
                  num_workers=1, num_ps=2,
                  variable_partitioner=(
                      sharded_variable.FixedShardsPartitioner(2)))
          ],
          mode=["eager"]))
  def testShardedError(self, strat):
    """Tests that a generator whose state would be sharded is rejected."""
    with strat.scope(), self.assertRaisesRegex(
        ValueError, "state is sharded, which is not allowed"):
      rng.Generator.from_seed(1234)

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          jit_replica_fn=[False, True]))
  def testDistVarAsTFFunArg(self, strat, jit_replica_fn):
    """Tests that RNG with dist variables can be used as tf.function's arg.

    The replica function forwards the generator to a `tf.function` and is
    run twice; each run must give distinct values across local replicas.
    """
    strat_name = type(strat).__name__
    if "CentralStorage" in strat_name:
      self.skipTest(
          "CentralStorageStrategy wraps variable updates in merge_call which "
          "can't be called inside a tf.function that doesn't cover the entire "
          "replica function (the function passed to strategy.run).")
    if "TPU" in strat_name and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")
    coord = (coordinator_lib.ClusterCoordinator(strat)
             if "ParameterServer" in strat_name else None)
    shape = [3, 4]
    dtype = dtypes.int32
    with strat.scope():
      gen = rng.Generator.from_seed(1234)

      # Passing the generator as a tf.function argument is what is under test.
      @def_function.function
      def draw_pair(gen):
        first = gen.uniform_full_int(shape=shape, dtype=dtype)
        second = gen.uniform_full_int(shape=shape, dtype=dtype)
        return array_ops.stack([first, second])

      def forward():
        return draw_pair(gen)

      replica_fn = (def_function.function(forward) if jit_replica_fn
                    else forward)
      for _ in range(2):
        results = run_on_strategy(replica_fn, strat, coord)
        local_values = strat.experimental_local_results(results)
        num_local = get_num_local_replicas(strat, local_values)
        self.assertAllEqual(num_local, len(local_values))
        self.assertAllDifferent(local_values)

  @ds_combinations.generate(
      combinations.combine(
          strat1=strategy_combinations.all_strategies,
          strat2=strategy_combinations.all_strategies,
          jit_replica_fn=[False, True],
          mode=["eager"]) +
      combinations.combine(
          strat1=strategy_combinations.multiworker_strategies + ps_strategies,
          strat2=[None],
          jit_replica_fn=[False, True],
          mode=["eager"]))
  def testDistStratRestore(self, strat1, strat2, jit_replica_fn):
    """Tests checkpointing and restoring (to possibly different #replicas).

    A generator built under `strat1` is checkpointed and restored into one
    built under `strat2` (the same strategy when `strat2` is None). Draws on
    the replicas both strategies have in common must then match.
    """
    strat2 = strat1 if strat2 is None else strat2
    names = (type(strat1).__name__, type(strat2).__name__)
    if any("Default" in name for name in names):
      self.skipTest(
          "We don't guarantee consistency between strategy and no-strategy.")
    if any("TPU" in name for name in names) and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")

    def make_coordinator(strategy):
      if "ParameterServer" in type(strategy).__name__:
        return coordinator_lib.ClusterCoordinator(strategy)
      return None

    coord1 = make_coordinator(strat1)
    coord2 = make_coordinator(strat2)
    fname = os.path.join(self.get_temp_dir(), "checkpoint")

    def draw(strategy, coordinator, gen):
      # One int32 vector of length 3 per replica; returns the local results.
      def replica_body():
        return gen.uniform_full_int([3], dtype=dtypes.int32)

      replica_fn = (def_function.function(replica_body) if jit_replica_fn
                    else replica_body)
      return strategy.experimental_local_results(
          run_on_strategy(replica_fn, strategy, coordinator))

    with strat1.scope():
      g1 = rng.Generator.from_seed(1)
    with strat2.scope():
      g2 = rng.Generator.from_seed(10)
    cp1 = tracking_util.Checkpoint(g=g1)
    cp2 = tracking_util.Checkpoint(g=g2)

    def write_restore_compare():
      cp1.write(fname)
      r1 = draw(strat1, coord1, g1)
      cp2.restore(fname)
      r2 = draw(strat2, coord2, g2)
      # Only replicas present under both strategies can be compared.
      n = min(get_num_local_replicas(strat1), get_num_local_replicas(strat2))
      self.assertAllEqual(r1[:n], r2[:n])

    # Repeat so the checkpoint is written from several RNG states (each round
    # advances g1 via `draw`).
    for _ in range(2):
      write_restore_compare()

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          is_save_in_scope=[True, False]))
  def testSavedModel(self, strat, is_save_in_scope):
    """Tests saving and loading a module that owns a generator.

    The module is built (and its generator advanced once) inside
    `strat.scope()`, then saved either inside or outside that scope. The
    loaded generator must have the state seen at save time, and advancing it
    once must reach the same state the original reached after its own
    post-save advance.
    """

    class CustomModule(module.Module):

      def __init__(self):
        super().__init__()
        self.g = rng.Generator.from_seed(0)

      @def_function.function
      def __call__(self):
        return self.g.state

      @def_function.function
      def mutate(self):
        self.g.normal([])

    path = os.path.join(self.get_temp_dir(), "saved_model")
    with strat.scope():
      m = CustomModule()
      m.mutate()
      state_before = m()
    if is_save_in_scope:
      with strat.scope():
        save.save(m, path)
    else:
      save.save(m, path)
    with strat.scope():
      m.mutate()
      state_before_2 = m()

    imported = load.load(path)
    self.assertAllEqual(state_before, imported())
    imported.mutate()
    self.assertAllEqual(state_before_2, imported())


if __name__ == "__main__":
  with deprecation.silence():
    multi_process_runner.test_main()