# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for common methods in strategy classes."""

from absl.testing import parameterized

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.platform import test
from tensorflow.python.util import nest


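# Each test receives a (strategy, pure_eager) pair. With pure_eager=True the
# test body runs op by op in eager mode; with pure_eager=False it is first
# wrapped in a def_function.function. TPU strategies are only combined with
# pure_eager=False, presumably because collective ops on TPU must run inside
# a compiled function.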
@tf_test_util.with_eager_op_as_function
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.default_strategy,
            strategy_combinations.one_device_strategy,
            strategy_combinations.one_device_strategy_gpu,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_one_gpu,
            strategy_combinations.mirrored_strategy_with_two_gpus,
            strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        mode=['eager'],
        pure_eager=[True, False]) + combinations.combine(
            strategy=[
                strategy_combinations.tpu_strategy,
                strategy_combinations.tpu_strategy_packed_var,
                strategy_combinations.tpu_strategy_one_step,
                strategy_combinations.cloud_tpu_strategy,
            ],
            mode=['eager'],
            pure_eager=[False]))
class GatherTest(test.TestCase, parameterized.TestCase):

  def _gather_same_shape_and_verify(self, value_on_replica, axis, pure_eager,
                                    strategy):
    distributed_values = strategy.experimental_distribute_values_from_function(
        lambda _: array_ops.identity(value_on_replica))
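    # Every replica holds an identical copy of `value_on_replica`, so the
    # gathered result should be num_replicas_in_sync copies concatenated
    # along `axis`.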

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    all_results = [
        value_on_replica for _ in range(strategy.num_replicas_in_sync)
    ]
    expected_result = array_ops.concat(all_results, axis=axis)
    self.assertAllEqual(expected_result, run().numpy())

  def testGatherPerReplicaDense1D0Axis(self, strategy, pure_eager):
    """A DistributedValues object with two tensors of shape [3] on each replica gathers to a tensor of [6]."""
    single_value = constant_op.constant([1, 2, 3])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense2D0Axis(self, strategy, pure_eager):
    """A DistributedValues object with two tensors of [1, 3] on each replica gathers along 0th dim to a tensor of [2, 3]."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense2D1Axis(self, strategy, pure_eager):
    """A DistributedValues object with two tensors of [1, 3] on each replica gathers along 1st dim to a tensor of [1, 6]."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 1
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D0Axis(self, strategy, pure_eager):
    """A DistributedValues object with two tensors of [1, 2, 2] on each replica gathers along 0th dim to a tensor of [2, 2, 2]."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D1Axis(self, strategy, pure_eager):
112    """A DistributedValues object with two tensors of [1, 2, 2] on each replica gathers along 1nd dimension to a tensor of [1, 4, 2]."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 1
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D2Axis(self, strategy, pure_eager):
    """A DistributedValues object with two tensors of [1, 2, 2] on each replica gathers along 2nd dimension to a tensor of [1, 2, 4]."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 2
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

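  # The tests below check that gather also handles per-replica tensors whose
  # shapes differ along the gather axis: replica i contributes a tensor whose
  # `axis`-th dimension is i + 1.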
  def testGatherDiffShapeAtAxis0(self, strategy, pure_eager):
    """Different `Axis`-th (0) dimension: shape [1, 1], [2, 1] -> [3, 1]."""

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(ctx.replica_id_in_sync_group + 1, 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

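    # Replica i contributes i + 1 rows, so the gathered tensor has
    # 1 + 2 + ... + n = sum(range(n + 1)) rows for n replicas.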
    expected_result = constant_op.constant(
        1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1))

    self.assertAllEqual(expected_result, run().numpy())

  def testGatherDiffShapeAtAxis1(self, strategy, pure_eager):
    """Different `Axis`-th (non-0) dimension: shape [1, 1], [1, 2] -> [1, 3]."""

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 1

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))

    self.assertAllEqual(expected_result, run().numpy())

  def testGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
    """Different at non-`axis`-th dimension: [1, 1], [1, 2], 0th -> raise error."""
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

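    # The expected error depends on the strategy: collective-based strategies
    # report 'Shape mismatch', while the others surface a dimension-equality
    # error from shape checking.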
    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        run()
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      with self.assertRaisesRegex((errors.InvalidArgumentError, ValueError),
                                  r'Dimension \d in both shapes must be equal'):
        run()

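  # gather only supports dense tensors; IndexedSlices inputs are expected to
  # raise NotImplementedError.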
  def testGatherRaiseSparse(self, strategy, pure_eager):
    dense_shape = [5, 2]
    t0 = _make_indexed_slices(
        values=[[1., 2.]], indices=[2], dense_shape=dense_shape)

    def run(value):
      return strategy.gather(value, axis=0)

    with self.assertRaisesRegex(
        NotImplementedError,
        r'gather does not support IndexedSlices'):
      if pure_eager:
        run(t0)
      else:
        def_function.function(run)(t0)

  def testGatherRaiseDifferentRank(self, strategy, pure_eager):
    """Different rank: [1,], [1, 2] -> raise error."""
    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')
    def value_fn(ctx):
      return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        run()
    elif isinstance(
        strategy,
        (mirrored_strategy.MirroredStrategy,
         central_storage_strategy.CentralStorageStrategy)):
      if pure_eager:
        with self.assertRaises(errors.InvalidArgumentError) as e:
          run()
        # Different error message depending on whether collective ops are used.
        self.assertRegexMatch(
            str(e.exception),
            ['Ranks of all input tensors should match', 'Shape mismatch'])
      else:
        with self.assertRaises((errors.InvalidArgumentError, ValueError)) as e:
          run()
        self.assertRegexMatch(
            str(e.exception),
            [r'Shape must be rank \d but is rank \d', 'Shape mismatch'])
    elif _is_tpu_strategy(strategy) and pure_eager:
      with self.assertRaisesRegex(ValueError,
                                  r'Dimension \d in both shapes must be equal'):
        run()
    else:
      with self.assertRaisesRegex(ValueError,
                                  r'Shape must be rank \d but is rank \d'):
        run()

  # Ideally the all_gather tests below would live in a separate test class,
  # AllGatherTest. But doing that makes two initialize_tpu_system() calls and
  # one of them times out on Kokoro; keeping both groups in one class avoids
  # that.
  def _all_gather_same_shape_and_verify(self, value_on_replica, axis,
                                        pure_eager, strategy):
    per_replica_value = strategy.experimental_distribute_values_from_function(
        lambda _: array_ops.identity(value_on_replica))

    def replica_fn(per_replica_value):
      ctx = ds_context.get_replica_context()
      local_value = array_ops.identity(per_replica_value)
      return ctx.all_gather(local_value, axis=axis)

    if not pure_eager:
      replica_fn = def_function.function(replica_fn)

    result = strategy.experimental_local_results(
        strategy.run(replica_fn, args=(per_replica_value,)))

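    # all_gather returns the full gathered tensor on every replica, so each
    # local result (one per replica on this client) should equal the
    # concatenation of all per-replica values.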
    all_value = [value_on_replica for _ in range(strategy.num_replicas_in_sync)]
    expect = array_ops.concat(all_value, axis=axis)
    expected_result = [expect] * _get_num_replicas_per_client(strategy)

    self.assertAllClose(expected_result, result)

  def testAllGatherPerReplicaDense1D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0,...) a DistributedValues with a Tensor of shape (3,) on two replicas returns a PerReplica of tensor(s) with shape (6,)."""
    single_value = constant_op.constant([1, 2, 3], dtype=dtypes.float32)
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense2D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0,...) a DistributedValues with a Tensor of shape (1,3) on two replicas returns PerReplica of tensor(s) with shape (2,3)."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense2D1Axis(self, strategy, pure_eager):
    """all_gather(..., axis=1,...) a DistributedValues with a Tensor of shape (1,3) on two replicas returns PerReplica of tensor(s) with shape (1,6)."""
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 1
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0,...) a DistributedValues with a Tensor of shape (1,2,2) on two replicas returns PerReplica of tensor(s) with shape (2,2,2)."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D1Axis(self, strategy, pure_eager):
    """all_gather(..., axis=1,...) a DistributedValues with a Tensor of shape (1,2,2) on two replicas returns PerReplica of tensor(s) with shape (1,4,2)."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 1
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D2Axis(self, strategy, pure_eager):
    """all_gather(..., axis=2,...) a DistributedValues with a Tensor of shape (1,2,2) on two replicas returns PerReplica of tensor(s) with shape (1,2,4)."""
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 2
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherDiffValueTPU(self, strategy, pure_eager):
    # Test for TPU only, since this case can't be covered by
    # testAllGatherDiffShape*.
    if not _is_tpu_strategy(strategy):
      self.skipTest('Test for TPU only; for other strategies this case is'
                    ' already covered by other tests.')

    data = [[1], [2], [3], [4], [5], [6], [7], [8]]

    axis = 0
    dataset = dataset_ops.DatasetV2.from_tensor_slices(data).batch(8)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))
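    # The global batch of 8 elements is split across replicas, so each
    # replica sees a different slice; all_gather should reassemble the full
    # batch on every replica.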

    @def_function.function
    def replica_fn(per_replica_value):
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(per_replica_value), axis=axis)

    result = strategy.experimental_local_results(
        strategy.run(replica_fn, args=(next(input_iterator),)))

    expected_result = [data] * _get_num_replicas_per_client(strategy)
    self.assertAllClose(expected_result, result)

  def testAllGatherDiffShapeAtAxis0(self, strategy, pure_eager):
    """Different `Axis==0`-th dimension: shape [1, 1], [2, 1] -> [3, 1]."""

    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather with different shapes')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(ctx.replica_id_in_sync_group + 1, 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    expect = constant_op.constant(
        1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1))

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = [expect] * _get_num_replicas_per_client(strategy)
    result = strategy.experimental_local_results(
        strategy.run(run, args=(per_replica_value,)))
    self.assertAllEqual(expected_result, result)

  def testAllGatherDiffShapeAtAxis1(self, strategy, pure_eager):
    """Different `Axis`-th (not 0th) dimension: shape [1, 1], [1, 2] -> [1, 3]."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather with different shapes')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    expect = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=1)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = [expect] * _get_num_replicas_per_client(strategy)
    result = strategy.experimental_local_results(
        strategy.run(run, args=(per_replica_value,)))
    self.assertAllEqual(expected_result, result)

  def testAllGatherNest(self, strategy, pure_eager):
    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather with different shapes')

    axis = 1

    def value_fn(ctx):
      value = constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))
      return value
    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    expect_1 = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))

    expected_per_replica_1 = [expect_1] * _get_num_replicas_per_client(strategy)

    value_2 = constant_op.constant([[[1, 2], [1, 2]]])

    expect_2 = array_ops.concat(
        [value_2 for _ in range(strategy.num_replicas_in_sync)], axis=axis)

    expected_per_replica_2 = [expect_2] * _get_num_replicas_per_client(strategy)

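    # all_gather on a nest (here a list of two tensors) gathers each leaf
    # independently and returns the results in the same structure.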
    def run(value):
      value_1 = array_ops.identity(value)
      value_3 = array_ops.identity(value_2)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather([value_1, value_3], axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    result = strategy.run(run, args=(per_replica_value,))
    self.assertAllEqual(expected_per_replica_1,
                        strategy.experimental_local_results(result[0]))
    self.assertAllEqual(expected_per_replica_2,
                        strategy.experimental_local_results(result[1]))

  def testAllGatherNest1D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0,...) a nest of DistributedValues."""
    single_value = constant_op.constant([1, 2, 3])
    axis = 0

    def run():
      value_identity = array_ops.identity(single_value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather([value_identity, value_identity], axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    all_value = [single_value for _ in range(strategy.num_replicas_in_sync)]
    expect = array_ops.concat(all_value, axis=axis)
    expected_per_replica = [expect] * _get_num_replicas_per_client(strategy)

    result = strategy.run(run)
    for gathered_result in result:
      self.assertAllEqual(expected_per_replica,
                          strategy.experimental_local_results(gathered_result))

  def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
    """Different at non-`axis`-th dimension: [1, 1], [1, 2], all_gather(...axis=0...) -> raise error."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TODO(b/169108777): raise a clear error message in xla.')

    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        strategy.run(run, args=(per_replica_value,))
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      with self.assertRaisesRegex((errors.InvalidArgumentError, ValueError),
                                  r'Dimension \d in both shapes must be equal'):
        strategy.run(run, args=(per_replica_value,))

  def testAllGatherRaiseSparse(self, strategy, pure_eager):
    dense_shape = [5, 2]
    t0 = _make_indexed_slices(
        values=[[1., 2.]], indices=[2], dense_shape=dense_shape)

    def replica_fn(value):
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value, axis=0)

    with self.assertRaisesRegex(
        NotImplementedError,
        r'all_gather does not support IndexedSlices'):
      if not pure_eager:
        strategy.run(def_function.function(replica_fn), args=(t0,))
      else:
        strategy.run(replica_fn, args=(t0,))

  def testAllGatherRaiseDifferentRank(self, strategy, pure_eager):
    """Different rank: [1,], [1, 2] -> raise error."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TODO(b/169108777): raise a clear error message in xla.')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')
    def value_fn(ctx):
      return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        strategy.run(run, args=(per_replica_value,))
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      if pure_eager:
        with self.assertRaises(errors.InvalidArgumentError) as e:
          strategy.run(run, args=(per_replica_value,))
        # Different error message depending on whether collective ops are used.
        self.assertRegexMatch(
            str(e.exception),
            ['Ranks of all input tensors should match', 'Shape mismatch'])
      else:
        with self.assertRaises((errors.InvalidArgumentError, ValueError)) as e:
          strategy.run(run, args=(per_replica_value,))
        self.assertRegexMatch(
            str(e.exception),
            [r'Shape must be rank \d but is rank \d', 'Shape mismatch'])
    else:
      with self.assertRaisesRegex(ValueError,
                                  r'Dimension \d in both shapes must be equal'):
        strategy.run(run, args=(per_replica_value,))

  def testAllGatherGradient(self, strategy, pure_eager):
    if pure_eager:
      self.skipTest('`tf.gradients` is not supported with eager execution '
                    'without using tf.function.')

    def all_gather_fn(value):
      axis = 1
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(value), axis)

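    # The gradient of all_gather routes each replica's slice of the upstream
    # gradient back to that replica and sums the contributions. Here the
    # upstream gradient on replica j is c = j + 1, so every element of x
    # should receive 1 + 2 + ... + num_replicas_in_sync.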
    gradient_comp = sum(range(1, strategy.num_replicas_in_sync + 1))
    gradient = [[gradient_comp], [gradient_comp]]
    grads_for_all_replicas = [gradient] * _get_num_replicas_per_client(strategy)

    @def_function.function
    def step(c):
      x = constant_op.constant([[3.], [5.]])
      mid = all_gather_fn(x)
      y = mid * c
      return gradients_impl.gradients_v2(y, [x])[0]

    def value_fn(ctx):
      x = [1., 2., 3., 4., 5., 6., 7., 8.]
      return array_ops.constant([x[ctx.replica_id_in_sync_group]])

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)
    result = strategy.experimental_local_results(
        strategy.run(step, args=(per_replica_value,)))

    self.assertAllEqual(grads_for_all_replicas, result)

  def testAllGatherGradientNest(self, strategy, pure_eager):
    if pure_eager:
      self.skipTest('`tf.gradients` is not supported with eager execution '
                    'without using tf.function.')

    def all_gather_fn(value):
      axis = 1
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(value), axis)

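    # Gathering the nest [x, y] gathers each leaf, but only the gradient with
    # respect to x is requested below, so the expected value matches the
    # non-nested case above.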
    gradient_comp = sum(range(1, strategy.num_replicas_in_sync + 1))
    gradient = [[gradient_comp], [gradient_comp]]
    grads_for_all_replicas = [gradient] * _get_num_replicas_per_client(strategy)

    @def_function.function
    def step(c):
      x = constant_op.constant([[3.], [5.]])
      y = constant_op.constant([[2.], [4.]])
      mid = all_gather_fn([x, y])
      y = mid * c
      return gradients_impl.gradients_v2(y, [x])[0]

    def value_fn(ctx):
      x = [1., 2., 3., 4., 5., 6., 7., 8.]
      return array_ops.constant([x[ctx.replica_id_in_sync_group]])

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)
    result = strategy.experimental_local_results(
        strategy.run(step, args=(per_replica_value,)))

    self.assertAllEqual(grads_for_all_replicas, result)



def _make_indexed_slices(values, indices, dense_shape):
  tensor = indexed_slices.IndexedSlices(
      values=constant_op.constant(values),
      indices=constant_op.constant(indices),
      dense_shape=constant_op.constant(dense_shape))
  return tensor


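# With CollectiveAllReduceStrategy each client (worker) only runs the replicas
# backed by its own accelerators (or a single CPU replica), so its local
# results have that many entries; single-client strategies run all replicas
# locally.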
def _get_num_replicas_per_client(strategy):
  if isinstance(strategy, CollectiveAllReduceStrategy):
    resolver = strategy.cluster_resolver
    return max(nest.flatten(resolver.num_accelerators())[0], 1)
  else:
    return strategy.num_replicas_in_sync


def _is_tpu_strategy(strategy):
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))


if __name__ == '__main__':
  test_util.main()