# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for common methods in strategy classes."""

from absl.testing import parameterized

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.platform import test
from tensorflow.python.util import nest


@tf_test_util.with_eager_op_as_function
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.default_strategy,
            strategy_combinations.one_device_strategy,
            strategy_combinations.one_device_strategy_gpu,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_one_gpu,
            strategy_combinations.mirrored_strategy_with_two_gpus,
            strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        mode=['eager'],
        pure_eager=[True, False]) + combinations.combine(
            strategy=[
                strategy_combinations.tpu_strategy,
                strategy_combinations.tpu_strategy_packed_var,
                strategy_combinations.tpu_strategy_one_step,
                strategy_combinations.cloud_tpu_strategy,
            ],
            mode=['eager'],
            pure_eager=[False]))
class GatherTest(test.TestCase, parameterized.TestCase):

  def _gather_same_shape_and_verify(self, value_on_replica, axis, pure_eager,
                                    strategy):
    distributed_values = strategy.experimental_distribute_values_from_function(
        lambda _: array_ops.identity(value_on_replica))

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

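    # `strategy.gather` concatenates the per-replica components along `axis`,
    # so the expected result is the single-replica value repeated
    # `num_replicas_in_sync` times.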
    all_results = [
        value_on_replica for _ in range(strategy.num_replicas_in_sync)
    ]
    expected_result = array_ops.concat(all_results, axis=axis)
    self.assertAllEqual(expected_result, run().numpy())

  def testGatherPerReplicaDense1D0Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [3] on each of two replicas gathers to a tensor of shape [6]."""
    single_value = constant_op.constant([1, 2, 3])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense2D0Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [1, 3] on each of two replicas gathers along the 0th dim to a tensor of shape [2, 3]."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense2D1Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [1, 3] on each of two replicas gathers along the 1st dim to a tensor of shape [1, 6]."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 1
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D0Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [1, 2, 2] on each of two replicas gathers along the 0th dim to a tensor of shape [2, 2, 2]."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D1Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [1, 2, 2] on each of two replicas gathers along the 1st dim to a tensor of shape [1, 4, 2]."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 1
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D2Axis(self, strategy, pure_eager):
    """A DistributedValues object with a tensor of shape [1, 2, 2] on each of two replicas gathers along the 2nd dim to a tensor of shape [1, 2, 4]."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 2
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherDiffShapeAtAxis0(self, strategy, pure_eager):
    """Different `axis`-th (0th) dimension: shapes [1, 1], [2, 1] -> [3, 1]."""

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(ctx.replica_id_in_sync_group + 1, 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = constant_op.constant(
        1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1))

    self.assertAllEqual(expected_result, run().numpy())

  def testGatherDiffShapeAtAxis1(self, strategy, pure_eager):
    """Different `axis`-th (non-0th) dimension: shapes [1, 1], [1, 2] -> [1, 3]."""

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 1

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

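    # Replica i contributes a tensor of shape (1, i + 1), so the gathered
    # width is 1 + 2 + ... + num_replicas = sum(range(num_replicas + 1)).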
    expected_result = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))

    self.assertAllEqual(expected_result, run().numpy())

  def testGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
    """Different non-`axis`-th dimension: [1, 1], [1, 2], axis=0 -> raises an error."""
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        run()
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      with self.assertRaisesRegex((errors.InvalidArgumentError, ValueError),
                                  r'Dimension \d in both shapes must be equal'):
        run()

  def testGatherRaiseSparse(self, strategy, pure_eager):
    dense_shape = [5, 2]
    t0 = _make_indexed_slices(
        values=[[1., 2.]], indices=[2], dense_shape=dense_shape)

    def run(value):
      return strategy.gather(value, axis=0)

    with self.assertRaisesRegex(
        NotImplementedError,
        r'gather does not support IndexedSlices'):
      if pure_eager:
        run(t0)
      else:
        def_function.function(run)(t0)

  def testGatherRaiseDifferentRank(self, strategy, pure_eager):
    """Different ranks: [1,], [1, 2] -> raises an error."""
    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    def value_fn(ctx):
      return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        run()
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      if pure_eager:
        with self.assertRaises(errors.InvalidArgumentError) as e:
          run()
        # The error message differs depending on whether collective ops are
        # used.
        self.assertRegexMatch(
            str(e.exception),
            ['Ranks of all input tensors should match', 'Shape mismatch'])
      else:
        with self.assertRaises((errors.InvalidArgumentError, ValueError)) as e:
          run()
        self.assertRegexMatch(
            str(e.exception),
            [r'Shape must be rank \d but is rank \d', 'Shape mismatch'])
    elif _is_tpu_strategy(strategy) and pure_eager:
      with self.assertRaisesRegex(ValueError,
                                  r'Dimension \d in both shapes must be equal'):
        run()
    else:
      with self.assertRaisesRegex(ValueError,
                                  r'Shape must be rank \d but is rank \d'):
        run()

  # Ideally the all_gather tests below would be split into a separate test
  # class, AllGatherTest. But doing that makes two initialize_tpu_system()
  # calls and one of them times out on Kokoro, so the two are combined into
  # one class.
  def _all_gather_same_shape_and_verify(self, value_on_replica, axis,
                                        pure_eager, strategy):
    per_replica_value = strategy.experimental_distribute_values_from_function(
        lambda _: array_ops.identity(value_on_replica))

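    # `all_gather` must be called in a replica context; every replica gathers
    # the full concatenated value, so each local result equals `expect`.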
    def replica_fn(per_replica_value):
      ctx = ds_context.get_replica_context()
      local_value = array_ops.identity(per_replica_value)
      return ctx.all_gather(local_value, axis=axis)

    if not pure_eager:
      replica_fn = def_function.function(replica_fn)

    result = strategy.experimental_local_results(
        strategy.run(replica_fn, args=(per_replica_value,)))

    all_value = [value_on_replica for _ in range(strategy.num_replicas_in_sync)]
    expect = array_ops.concat(all_value, axis=axis)
    expected_result = [expect] * _get_num_replicas_per_client(strategy)

    self.assertAllClose(expected_result, result)

  def testAllGatherPerReplicaDense1D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0, ...) of a DistributedValues with a tensor of shape (3,) on each of two replicas returns a PerReplica of tensors with shape (6,)."""
    single_value = constant_op.constant([1, 2, 3], dtype=dtypes.float32)
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense2D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0, ...) of a DistributedValues with a tensor of shape (1, 3) on each of two replicas returns a PerReplica of tensors with shape (2, 3)."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense2D1Axis(self, strategy, pure_eager):
    """all_gather(..., axis=1, ...) of a DistributedValues with a tensor of shape (1, 3) on each of two replicas returns a PerReplica of tensors with shape (1, 6)."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 1
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0, ...) of a DistributedValues with a tensor of shape (1, 2, 2) on each of two replicas returns a PerReplica of tensors with shape (2, 2, 2)."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D1Axis(self, strategy, pure_eager):
    """all_gather(..., axis=1, ...) of a DistributedValues with a tensor of shape (1, 2, 2) on each of two replicas returns a PerReplica of tensors with shape (1, 4, 2)."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 1
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D2Axis(self, strategy, pure_eager):
    """all_gather(..., axis=2, ...) of a DistributedValues with a tensor of shape (1, 2, 2) on each of two replicas returns a PerReplica of tensors with shape (1, 2, 4)."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 2
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherDiffValueTPU(self, strategy, pure_eager):
    # Test for TPU only, since this can't be tested via testAllGatherDiffShape*.
    if not _is_tpu_strategy(strategy):
      self.skipTest('Test for TPU only. For other strategies this case is '
                    'already covered in other tests.')

    data = [[1], [2], [3], [4], [5], [6], [7], [8]]

    axis = 0
    dataset = dataset_ops.DatasetV2.from_tensor_slices(data).batch(8)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    @def_function.function
    def replica_fn(per_replica_value):
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(per_replica_value), axis=axis)

    result = strategy.experimental_local_results(
        strategy.run(replica_fn, args=(next(input_iterator),)))

    expected_result = [data] * _get_num_replicas_per_client(strategy)
    self.assertAllClose(expected_result, result)

  def testAllGatherDiffShapeAtAxis0(self, strategy, pure_eager):
    """Different `axis`-th (0th) dimension: shapes [1, 1], [2, 1] -> [3, 1]."""

    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather with different shapes')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(ctx.replica_id_in_sync_group + 1, 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    expect = constant_op.constant(
        1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1))

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = [expect] * _get_num_replicas_per_client(strategy)
    result = strategy.experimental_local_results(
        strategy.run(run, args=(per_replica_value,)))
    self.assertAllEqual(expected_result, result)

  def testAllGatherDiffShapeAtAxis1(self, strategy, pure_eager):
    """Different `axis`-th (non-0th) dimension: shapes [1, 1], [1, 2] -> [1, 3]."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather with different shapes')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    expect = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=1)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = [expect] * _get_num_replicas_per_client(strategy)
    result = strategy.experimental_local_results(
        strategy.run(run, args=(per_replica_value,)))
    self.assertAllEqual(expected_result, result)

  def testAllGatherNest(self, strategy, pure_eager):
    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather with different shapes')

    axis = 1

    def value_fn(ctx):
      value = constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))
      return value

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    expect_1 = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))

    expected_per_replica_1 = [expect_1] * _get_num_replicas_per_client(strategy)

    value_2 = constant_op.constant([[[1, 2], [1, 2]]])

    expect_2 = array_ops.concat(
        [value_2 for _ in range(strategy.num_replicas_in_sync)], axis=axis)

    expected_per_replica_2 = [expect_2] * _get_num_replicas_per_client(strategy)

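    # `all_gather` also accepts a nest of tensors; each element of the nest is
    # gathered independently along `axis`.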
    def run(value):
      value_1 = array_ops.identity(value)
      value_3 = array_ops.identity(value_2)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather([value_1, value_3], axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    result = strategy.run(run, args=(per_replica_value,))
    self.assertAllEqual(expected_per_replica_1,
                        strategy.experimental_local_results(result[0]))
    self.assertAllEqual(expected_per_replica_2,
                        strategy.experimental_local_results(result[1]))

  def testAllGatherNest1D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0, ...) of a nest of DistributedValues."""
    single_value = constant_op.constant([1, 2, 3])
    axis = 0

    def run():
      value_identity = array_ops.identity(single_value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather([value_identity, value_identity], axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    all_value = [single_value for _ in range(strategy.num_replicas_in_sync)]
    expect = array_ops.concat(all_value, axis=axis)
    expected_per_replica = [expect] * _get_num_replicas_per_client(strategy)

    result = strategy.run(run)
    for gathered_result in result:
      self.assertAllEqual(expected_per_replica,
                          strategy.experimental_local_results(gathered_result))

  def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
    """Different non-`axis`-th dimension: [2, 1], [1, 1], all_gather(..., axis=1, ...) -> raises an error."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TODO(b/169108777): raise a clear error message in xla.')

    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        strategy.run(run, args=(per_replica_value,))
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      with self.assertRaisesRegex((errors.InvalidArgumentError, ValueError),
                                  r'Dimension \d in both shapes must be equal'):
        strategy.run(run, args=(per_replica_value,))

  def testAllGatherRaiseSparse(self, strategy, pure_eager):
    dense_shape = [5, 2]
    t0 = _make_indexed_slices(
        values=[[1., 2.]], indices=[2], dense_shape=dense_shape)

    def replica_fn(value):
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value, axis=0)

    with self.assertRaisesRegex(
        NotImplementedError,
        r'all_gather does not support IndexedSlices'):
      if not pure_eager:
        strategy.run(def_function.function(replica_fn), args=(t0,))
      else:
        strategy.run(replica_fn, args=(t0,))

  def testAllGatherRaiseDifferentRank(self, strategy, pure_eager):
    """Different ranks: [1,], [1, 2] -> raises an error."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TODO(b/169108777): raise a clear error message in xla.')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    def value_fn(ctx):
      return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        strategy.run(run, args=(per_replica_value,))
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      if pure_eager:
        with self.assertRaises(errors.InvalidArgumentError) as e:
          strategy.run(run, args=(per_replica_value,))
        # The error message differs depending on whether collective ops are
        # used.
        self.assertRegexMatch(
            str(e.exception),
            ['Ranks of all input tensors should match', 'Shape mismatch'])
      else:
        with self.assertRaises((errors.InvalidArgumentError, ValueError)) as e:
          strategy.run(run, args=(per_replica_value,))
        self.assertRegexMatch(
            str(e.exception),
            [r'Shape must be rank \d but is rank \d', 'Shape mismatch'])
    else:
      with self.assertRaisesRegex(ValueError,
                                  r'Dimension \d in both shapes must be equal'):
        strategy.run(run, args=(per_replica_value,))

  def testAllGatherGradient(self, strategy, pure_eager):
    if pure_eager:
      self.skipTest('`tf.gradients` is not supported with eager execution '
                    'without using tf.functions.')

    def all_gather_fn(value):
      axis = 1
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(value), axis)

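    # Replica i feeds c = i + 1 into `step`, and the gradients from all
    # replicas are summed, so each entry of dy/dx equals
    # 1 + 2 + ... + num_replicas_in_sync.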
    gradient_comp = sum(range(1, strategy.num_replicas_in_sync + 1))
    gradient = [[gradient_comp], [gradient_comp]]
    grads_for_all_replicas = [gradient] * _get_num_replicas_per_client(strategy)

    @def_function.function
    def step(c):
      x = constant_op.constant([[3.], [5.]])
      mid = all_gather_fn(x)
      y = mid * c
      return gradients_impl.gradients_v2(y, [x])[0]

    def value_fn(ctx):
      x = [1., 2., 3., 4., 5., 6., 7., 8.]
      return array_ops.constant([x[ctx.replica_id_in_sync_group]])

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)
    result = strategy.experimental_local_results(
        strategy.run(step, args=(per_replica_value,)))

    self.assertAllEqual(grads_for_all_replicas, result)

  def testAllGatherGradientNest(self, strategy, pure_eager):
    if pure_eager:
      self.skipTest('`tf.gradients` is not supported with eager execution '
                    'without using tf.functions.')

    def all_gather_fn(value):
      axis = 1
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(value), axis)

    gradient_comp = sum(range(1, strategy.num_replicas_in_sync + 1))
    gradient = [[gradient_comp], [gradient_comp]]
    grads_for_all_replicas = [gradient] * _get_num_replicas_per_client(strategy)

    @def_function.function
    def step(c):
      x = constant_op.constant([[3.], [5.]])
      y = constant_op.constant([[2.], [4.]])
      mid = all_gather_fn([x, y])
      y = mid * c
      return gradients_impl.gradients_v2(y, [x])[0]

    def value_fn(ctx):
      x = [1., 2., 3., 4., 5., 6., 7., 8.]
      return array_ops.constant([x[ctx.replica_id_in_sync_group]])

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)
    result = strategy.experimental_local_results(
        strategy.run(step, args=(per_replica_value,)))

    self.assertAllEqual(grads_for_all_replicas, result)


def _make_indexed_slices(values, indices, dense_shape):
  tensor = indexed_slices.IndexedSlices(
      values=constant_op.constant(values),
      indices=constant_op.constant(indices),
      dense_shape=constant_op.constant(dense_shape))
  return tensor


def _get_num_replicas_per_client(strategy):
  if isinstance(strategy, CollectiveAllReduceStrategy):
    resolver = strategy.cluster_resolver
    return max(nest.flatten(resolver.num_accelerators())[0], 1)
  else:
    return strategy.num_replicas_in_sync


def _is_tpu_strategy(strategy):
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))


if __name__ == '__main__':
  test_util.main()