# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests tf.random.Generator with distribution strategies."""

import functools
import os

from absl.testing import parameterized
from tensorflow.python.checkpoint import checkpoint as tracking_util
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values as dist_values
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import stateful_random_ops as rng
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.util import deprecation


def get_num_local_replicas(strat, values=None):
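  """Returns the number of replicas local to the current worker.

  For multi-worker strategies, `num_replicas_in_sync` counts replicas across
  all workers, while `experimental_local_results` returns only the current
  worker's values, so the local replicas are counted by running a trivial
  replica function and counting its local results.

  Args:
    strat: a `tf.distribute.Strategy`.
    values: optional per-replica values to count; if None, they are obtained
      by running a trivial replica function.

  Returns:
    The number of local replicas.
  """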
  strat_name = type(strat).__name__
  if "MultiWorker" in strat_name or "CollectiveAllReduceStrategy" in strat_name:
    if values is None:
      values = strat.run(lambda: constant_op.constant(0))
      values = strat.experimental_local_results(values)
    return len(values)
  else:
    return strat.num_replicas_in_sync


ps_strategies = [
    strategy_combinations.parameter_server_strategy_3worker_2ps_cpu,
    strategy_combinations.parameter_server_strategy_1worker_2ps_cpu,
    strategy_combinations.parameter_server_strategy_3worker_2ps_1gpu,
    strategy_combinations.parameter_server_strategy_1worker_2ps_1gpu,
]
all_strategies = (strategy_combinations.all_strategies +
                  strategy_combinations.multiworker_strategies +
                  ps_strategies)


def run_on_strategy(replica_fn, strat, coord):
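  """Runs `replica_fn` under `strat`, dispatching through `coord` if given.

  `ParameterServerStrategy` work must be scheduled on a `ClusterCoordinator`;
  `schedule` returns a `RemoteValue` whose result is obtained with `fetch`.
  All other strategies run the function directly.
  """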
  def distributed_fn():
    return strat.run(replica_fn)
  if coord is not None:
    results = coord.schedule(
        def_function.function(distributed_fn)).fetch()
  else:
    results = distributed_fn()
  return results


class GeneratorTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(GeneratorTest, self).setUp()
    v2_compat.enable_v2_behavior()

  def assertAllDifferent(self, tensors):
    """Checks that there are no duplicate elements anywhere among the tensors.

    Args:
      tensors: a list of tensors. They can have different shapes.
    """
    values = [array_ops.reshape(t, shape=[-1]) for t in tensors]
    values = array_ops.concat(values, axis=0)
    values = self.evaluate(values)
    values = values.tolist()
    self.assertAllEqual(len(values), len(set(values)))

  @test_util.run_v2_only
  def testCreateOutsideMirroredStrat(self):
97    """Tests RNG/MirrorStrategy interaction #1.

    If an RNG is created outside a DS scope, all replicas will access the
    same RNG object, and accesses are serialized.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gen = rng.Generator.from_seed(1234)
    strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
    with strat.scope():

      def f():
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      results = strat.extended.call_for_each_replica(fn=f)
      values = results.values
      self.assertAllEqual(2, len(values))
      self.assertAllDifferent(values)

  @test_util.run_v2_only
  def testMirroredStratParaAsync(self):
121    """Tests RNG/MirrorStrategy interaction #2.

    The user can create n independent RNGs outside strategy.scope(), where n
    is the number of replicas, and give one to each replica. The replicas can
    thus get different random-number streams.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gens = rng.get_global_generator().split(count=2)
    devices = ["cpu:0", "cpu:1"]
    strat = MirroredStrategy(devices=devices)
    # Use `PerReplica` to specify which `gen` is sent to which replica; each
    # entry is that replica's argument list.
    gens = dist_values.PerReplica([[g] for g in gens])
    with strat.scope():

      def f(gen):
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      results = strat.extended.call_for_each_replica(fn=f, args=gens)
      local_results = strat.experimental_local_results(results)
      self.assertAllEqual(2, len(local_results))
      self.assertAllDifferent(local_results)

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"]))
  def testCrossReplica(self, strat):
    """Tests that RNG can be properly advanced in cross-replica context."""
    def read_values(dv):
      return [v.read_value() for v in strat.experimental_local_results(dv)]
    with strat.scope():
      g = rng.Generator.from_seed(1)
      s1 = read_values(g.state)
      g.normal([3])
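      # `skip` advances the generator state without drawing samples; per the
      # `tf.random.Generator.skip` docs, the state after `skip(n)` matches the
      # state after drawing `n` samples.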
      g.skip(4)
      s2 = read_values(g.state)
    self.assertNotAllEqual(s1[0], s2[0])
    self.assertEqual(len(s1), len(s2))
    for i in range(1, len(s1)):
      self.assertAllEqual(s1[0], s1[i])
      self.assertAllEqual(s2[0], s2[i])

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          jit_replica_fn=[False, True],
          seeded=[True, False],))
  def testDistStrat(self, strat, jit_replica_fn, seeded):
    """Tests RNG with distribution strategies."""
    strat_name = type(strat).__name__
    if "TPU" in strat_name and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")
    coord = None
    if "ParameterServer" in strat_name:
      coord = coordinator_lib.ClusterCoordinator(strat)
    creators = {
        True: functools.partial(rng.Generator.from_seed, 1234),
        False: rng.Generator.from_non_deterministic_state,
    }
    shape = [3, 4]
    dtype = dtypes.int32
    creator = creators[seeded]
    with strat.scope():
      gen = creator()
      def f():
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t
      replica_fn = def_function.function(f) if jit_replica_fn else f
      results = run_on_strategy(replica_fn, strat, coord)
      values = strat.experimental_local_results(results)
      n = get_num_local_replicas(strat, values)
      self.assertAllEqual(n, len(values))
      self.assertAllDifferent(values)

  @ds_combinations.generate(
      combinations.combine(
          strat=[
              strategy_combinations.parameter_server_strategy_fn(
                  "ParameterServer1Worker2PSCPUFixedShards",
                  num_workers=1, num_ps=2,
                  variable_partitioner=(
                      sharded_variable.FixedShardsPartitioner(2)))
          ],
          mode=["eager"]))
  def testShardedError(self, strat):
215    """Tests error about sharding is raised."""
    with strat.scope():
      with self.assertRaisesRegex(
          ValueError, "state is sharded, which is not allowed"):
        rng.Generator.from_seed(1234)

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          jit_replica_fn=[False, True]))
  def testDistVarAsTFFunArg(self, strat, jit_replica_fn):
227    """Tests that RNG with dist variables can be used as tf.function's arg."""
    strat_name = type(strat).__name__
    if "CentralStorage" in strat_name:
      self.skipTest(
          "CentralStorageStrategy wraps variable updates in merge_call which "
          "can't be called inside a tf.function that doesn't cover the entire "
          "replica function (the function passed to strategy.run).")
    if "TPU" in strat_name and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")
    coord = None
    if "ParameterServer" in strat_name:
      coord = coordinator_lib.ClusterCoordinator(strat)
    shape = [3, 4]
    dtype = dtypes.int32
    with strat.scope():
      gen = rng.Generator.from_seed(1234)
      @def_function.function
      def f(gen):  # Passing the generator as an argument is the test's focus.
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t
      def g():
        return f(gen)
      replica_fn = def_function.function(g) if jit_replica_fn else g
      for _ in range(2):
        results = run_on_strategy(replica_fn, strat, coord)
        values = strat.experimental_local_results(results)
        n = get_num_local_replicas(strat, values)
        self.assertAllEqual(n, len(values))
        self.assertAllDifferent(values)

  @ds_combinations.generate(
      combinations.combine(
          strat1=strategy_combinations.all_strategies,
          strat2=strategy_combinations.all_strategies,
          jit_replica_fn=[False, True],
          mode=["eager"]) +
      combinations.combine(
          strat1=strategy_combinations.multiworker_strategies + ps_strategies,
          strat2=[None],
          jit_replica_fn=[False, True],
          mode=["eager"]))
  def testDistStratRestore(self, strat1, strat2, jit_replica_fn):
273    """Tests checkpointing and restoring (to possibly different #replicas)."""
    if strat2 is None:
      strat2 = strat1
    strat1_name = type(strat1).__name__
    strat2_name = type(strat2).__name__
    if "Default" in strat1_name or "Default" in strat2_name:
      self.skipTest(
          "We don't guarantee consistency between strategy and no-strategy.")
    if ("TPU" in strat1_name or "TPU" in strat2_name) and not jit_replica_fn:
      self.skipTest(
          "TPUStrategy requires the replica function (the function passed to "
          "strategy.run) to be decorated with tf.function")
    coord1 = None
    if "ParameterServer" in strat1_name:
      coord1 = coordinator_lib.ClusterCoordinator(strat1)
    coord2 = None
    if "ParameterServer" in strat2_name:
      coord2 = coordinator_lib.ClusterCoordinator(strat2)
    fname = os.path.join(self.get_temp_dir(), "checkpoint")
    def uniform(strat, coord, g):
      def f():
        return g.uniform_full_int([3], dtype=dtypes.int32)
      replica_fn = def_function.function(f) if jit_replica_fn else f
      result = run_on_strategy(replica_fn, strat, coord)
      return strat.experimental_local_results(result)
    with strat1.scope():
      g1 = rng.Generator.from_seed(1)
    with strat2.scope():
      g2 = rng.Generator.from_seed(10)
    cp1 = tracking_util.Checkpoint(g=g1)
    cp2 = tracking_util.Checkpoint(g=g2)
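    # `Checkpoint` tracks the generator object, so `write`/`restore` saves and
    # restores the underlying RNG state variable.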
    def write_restore_compare():
      cp1.write(fname)
      r1 = uniform(strat1, coord1, g1)
      cp2.restore(fname)
      r2 = uniform(strat2, coord2, g2)
      # Tests that overlapping replicas are properly restored.
      n1 = get_num_local_replicas(strat1)
      n2 = get_num_local_replicas(strat2)
      n = min(n1, n2)
      self.assertAllEqual(r1[:n], r2[:n])
    # Run multiple times so that cp1.write is called in various RNG states.
    for _ in range(2):
      write_restore_compare()

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          is_save_in_scope=[True, False]))
  def testSavedModel(self, strat, is_save_in_scope):

    class CustomModule(module.Module):
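      """A module that owns a generator and exposes its state and a mutator."""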

      def __init__(self):
        super(CustomModule, self).__init__()
        self.g = rng.Generator.from_seed(0)

      @def_function.function
      def __call__(self):
        return self.g.state

      @def_function.function
      def mutate(self):
        self.g.normal([])

    with strat.scope():
      m = CustomModule()
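      # Advance the generator so the state saved below differs from the
      # freshly seeded state.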
      m.mutate()
      state_before = m()
      path = os.path.join(self.get_temp_dir(), "saved_model")
    if is_save_in_scope:
      with strat.scope():
        save.save(m, path)
    else:
      save.save(m, path)
    with strat.scope():
      m.mutate()
      state_before_2 = m()

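    # Loading restores the generator state captured at save time, so the
    # imported module should step through the same states as the original.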
    imported = load.load(path)
    state_after = imported()
    self.assertAllEqual(state_before, state_after)
    imported.mutate()
    state_after_2 = imported()
    self.assertAllEqual(state_before_2, state_after_2)


if __name__ == "__main__":
  with deprecation.silence():
    multi_process_runner.test_main()