1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for parameter_server_strategy_v2.py."""
16
17import contextlib
18import functools
19import os
20
21from absl.testing import parameterized
22import numpy as np
23
24from tensorflow.core.protobuf import saved_model_pb2
25from tensorflow.python.checkpoint import checkpoint as tracking_util
26from tensorflow.python.compat import v2_compat
27from tensorflow.python.data.ops import dataset_ops
28from tensorflow.python.distribute import distribution_strategy_context
29from tensorflow.python.distribute import multi_process_runner
30from tensorflow.python.distribute import multi_worker_test_base
31from tensorflow.python.distribute import parameter_server_strategy_v2
32from tensorflow.python.distribute import ps_values
33from tensorflow.python.distribute import sharded_variable
34from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
35from tensorflow.python.eager import context
36from tensorflow.python.eager import def_function
37from tensorflow.python.eager import test
38from tensorflow.python.framework import constant_op
39from tensorflow.python.framework import dtypes
40from tensorflow.python.framework import ops
41from tensorflow.python.framework import tensor_spec
42from tensorflow.python.framework import test_util
43from tensorflow.python.module import module
44from tensorflow.python.ops import array_ops
45from tensorflow.python.ops import embedding_ops
46from tensorflow.python.ops import init_ops_v2
47from tensorflow.python.ops import linalg_ops_impl
48from tensorflow.python.ops import math_ops
49from tensorflow.python.ops import variable_scope
50from tensorflow.python.ops import variables
51from tensorflow.python.platform import gfile
52from tensorflow.python.saved_model import save as tf_save
53from tensorflow.python.trackable import autotrackable
54from tensorflow.python.training.server_lib import ClusterSpec
55
56
57class ParameterServerStrategyV2Test(test.TestCase):
58
59  @classmethod
60  def setUpClass(cls):
61    super(ParameterServerStrategyV2Test, cls).setUpClass()
62    cls.cluster = multi_worker_test_base.create_multi_process_cluster(
63        num_workers=2, num_ps=3, rpc_layer="grpc")
64    cls.cluster_resolver = cls.cluster.cluster_resolver
65
66  @classmethod
67  def tearDownClass(cls):
68    super(ParameterServerStrategyV2Test, cls).tearDownClass()
69    cls.cluster.stop()
70
71  def testVariablePlacement(self):
72
73    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
74        self.cluster_resolver)
75    v1 = variables.Variable(initial_value=0.0)
76    with strategy.scope():
77      v2 = variables.Variable(initial_value=1.0)
78      v3 = variables.Variable(initial_value=2.0)
79      v4 = variables.Variable(initial_value=3.0)
80      v5 = variables.Variable(initial_value=4.0)
81    # v1 was created outside scope so should be on client.
82    gpu_devices = context.num_gpus()
83    if gpu_devices:
84      # For tests with GPUs
85      self.assertEqual(v1.device, "/job:chief/replica:0/task:0/device:GPU:0")
86    else:
87      self.assertEqual(v1.device, "/job:chief/replica:0/task:0/device:CPU:0")
88    # v2 through v5 are created in scope and in a round-robin manner.
89    self.assertEqual(v2.device, "/job:ps/replica:0/task:0/device:CPU:0")
90    self.assertEqual(v3.device, "/job:ps/replica:0/task:1/device:CPU:0")
91    self.assertEqual(v4.device, "/job:ps/replica:0/task:2/device:CPU:0")
92    self.assertEqual(v5.device, "/job:ps/replica:0/task:0/device:CPU:0")
93
94  def testInteractionWithDeviceScope(self):
95    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
96        self.cluster_resolver)
97
98    # The strategy scope always wins.
99    with strategy.scope():
100      with ops.device("/job:ps/replica:0/task:1"):
101        v0 = variables.Variable(initial_value=0.0)
102      self.assertEqual(v0.device, "/job:ps/replica:0/task:0/device:CPU:0")
103
104      with ops.device("/job:ps/replica:0/task:0"):
105        v1 = variables.Variable(initial_value=0.0)
106      self.assertEqual(v1.device, "/job:ps/replica:0/task:1/device:CPU:0")
107
108    with ops.device("/job:ps/replica:0/task:1"):
109      with strategy.scope():
110        v2 = variables.Variable(initial_value=0.0)
111        self.assertEqual(v2.device, "/job:ps/replica:0/task:2/device:CPU:0")
112
113        v3 = variables.Variable(initial_value=0.0)
114        self.assertEqual(v3.device, "/job:ps/replica:0/task:0/device:CPU:0")
115
116  def testInteractionWithVariableCreatorScope(self):
117
118    def var_creator(next_creator, **kwargs):
119      if "colocate_with" in kwargs:
120        with ops.device(None):
121          with ops.colocate_with(kwargs["colocate_with"]):
122            return next_creator(**kwargs)
123
124      self.assertIn("ps1", kwargs["name"])
125      with ops.device("/job:ps/task:1"):
126        return next_creator(**kwargs)
127
128    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
129        self.cluster_resolver)
130
131    # variable_creator_scope itself will work.
132    with variable_scope.variable_creator_scope(var_creator):
133      v0 = variables.Variable(initial_value=0.0, name="ps1_0")
134    self.assertEqual(v0.device, "/job:ps/replica:0/task:1/device:CPU:0")
135
136    # variable_creator_scope inside strategy.scope will not work.
137    with strategy.scope():
138      with variable_scope.variable_creator_scope(var_creator):
139        v1 = variables.Variable(initial_value=0.0, name="ps1_1")
140    self.assertEqual(v1.device, "/job:ps/replica:0/task:0/device:CPU:0")
141
142    # strategy.scope still assigns variables in a round robin fashion.
143    with strategy.scope():
144      v2 = variables.Variable(initial_value=0.0, name="ps1_2")
145    self.assertEqual(v2.device, "/job:ps/replica:0/task:1/device:CPU:0")
146
147    with strategy.scope():
148      v3 = variables.Variable(initial_value=0.0, name="ps1_3")
149    self.assertEqual(v3.device, "/job:ps/replica:0/task:2/device:CPU:0")
150
151    # variable_creator_scope outside strategy.scope will work.
152    with variable_scope.variable_creator_scope(var_creator):
153      with strategy.scope():
154        v4 = variables.Variable(initial_value=0.0, name="ps1_4")
155    self.assertEqual(v4.device, "/job:ps/replica:0/task:1/device:CPU:0")
156
157    with variable_scope.variable_creator_scope(var_creator):
158      with strategy.scope():
159        v5 = variables.Variable(initial_value=0.0, name="ps1_5")
160    self.assertEqual(v5.device, "/job:ps/replica:0/task:1/device:CPU:0")
161
162    # variable_creator_scope can be made to respect "colocate_with" as well.
163    with variable_scope.variable_creator_scope(var_creator):
164      with strategy.scope():
165        with strategy.extended.colocate_vars_with(v1):
166          v6 = variables.Variable(initial_value=0.0, name="ps1_6")
167    self.assertEqual(v6.device, "/job:ps/replica:0/task:0/device:CPU:0")
168
169  @contextlib.contextmanager
170  def _assertRaisesUsageWarningWithSchedule(self):
171    with self.assertLogs(level="WARNING") as logs:
172      yield
173
174    self.assertIn(
175        "A `tf.distribute.experimental.ParameterServerStrategy` method is "
176        "invoked without using `ClusterCoordinator.schedule`. If you are not "
177        "tracing a tf.function, this method is possibly executed on the "
178        "coordinator, which can be slow. To properly dispatch functions to "
179        "run on workers, methods like `run` or `reduce` should be used "
180        "within a function passed to `tf.distribute.experimental.coordinator."
181        "ClusterCoordinator.schedule`.", "".join(logs.output))
182
183  def testRunNotUsedWithClusterCoordinator(self):
184    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
185        self.cluster_resolver)
186    dataset = dataset_ops.DatasetV2.range(8)
187    with strategy.scope():
188      v = variables.Variable(1, dtype=dtypes.int64)
189
190    def step_fn(iterator):
191      return next(iterator) + v
192
193    with self._assertRaisesUsageWarningWithSchedule():
194      strategy.run(step_fn, args=(iter(dataset),))
195
196  def testRunUsedWithTestOnlyMode(self):
197    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
198        self.cluster_resolver)
199    strategy.extended._allow_run_without_coordinator = True
200    dataset = dataset_ops.DatasetV2.range(15)
201    with strategy.scope():
202      v = variables.Variable(1, dtype=dtypes.int64)
203
204    def step_fn(iterator):
205      return next(iterator) + v
206
207    strategy.run(step_fn, args=(iter(dataset),))
208
209  def testReduceNotUsedWithClusterCoordinator(self):
210    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
211        self.cluster_resolver)
212    with self._assertRaisesUsageWarningWithSchedule():
213      strategy.reduce("SUM", None, axis=None)
214
215  def testDistributeDatasetUsedDirectly(self):
216    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
217        self.cluster_resolver)
218    dataset = dataset_ops.DatasetV2.range(3)
219    distributed_dataset = strategy.experimental_distribute_dataset(dataset)
220    with self.assertRaises(ValueError):
221      iter(distributed_dataset)
222
223    distributed_dataset = strategy.distribute_datasets_from_function(
224        lambda: dataset)
225    with self.assertRaises(ValueError):
226      iter(distributed_dataset)
227
228  def testSparselyReadForEmbeddingLookup(self):
229    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
230        self.cluster_resolver)
231
232    class FakeModel(module.Module):
233
234      def __init__(self):
235        self._var0 = variables.Variable([1.0, 2.0, 3.0, 4.0])
236        self._var1 = variables.Variable([5.0, 6.0, 7.0, 8.0])
237
238      @def_function.function(input_signature=[
239          tensor_spec.TensorSpec(shape=[2], dtype=dtypes.int32, name="inputs")
240      ])
241      def func(self, x):
242        return embedding_ops.embedding_lookup([self._var0, self._var1], x)
243
244    with strategy.scope():
245      model = FakeModel()
246
247    # Assert that ResourceGather op exists instead of Gather in training
248    # function.
249    found_resource_gather = False
250    found_gather = False
251
252    for n in model.func.get_concrete_function().graph.as_graph_def().node:
253      if n.op == "ResourceGather":
254        found_resource_gather = True
255      elif n.op == "Gather":
256        found_gather = True
257    self.assertTrue(found_resource_gather)
258    self.assertFalse(found_gather)
259
260    # Assert that ResourceGather op exists instead of Gather in saved_model.
261    found_resource_gather = False
262    found_gather = False
263
264    tmp_dir = self.get_temp_dir()
265    tf_save.save(model, tmp_dir, signatures=model.func)
266
267    with gfile.Open("%s/saved_model.pb" % tmp_dir, "rb") as f:
268      saved_model_proto = saved_model_pb2.SavedModel().FromString(f.read())
269
270    for function in saved_model_proto.meta_graphs[0].graph_def.library.function:
271      for n in function.node_def:
272        if n.op == "ResourceGather":
273          found_resource_gather = True
274          resource_gather_device = n.device
275        elif n.op == "Gather":
276          found_gather = True
277    self.assertTrue(found_resource_gather)
278    self.assertFalse(found_gather)
279
280    # We also assert that the colocate_with in embedding_ops will not result in
281    # a hard-coded device string.
282    self.assertEmpty(resource_gather_device)
283
284
285class PartitionAwareIdentity(object):
286
287  def __call__(self, shape, dtype, **kwargs):
288    value = linalg_ops_impl.eye(*shape, dtype=dtype)
289    if "partition_shape" in kwargs and "partition_offset" in kwargs:
290      return array_ops.slice(value, kwargs["partition_offset"],
291                             kwargs["partition_shape"])
292    raise AssertionError("PartitionAwareIdentity do not support "
293                         "non-partitioned initialization")
294
295
296class VariablePartitioningTest(test.TestCase, parameterized.TestCase):
297
298  @classmethod
299  def setUpClass(cls):
300    super(VariablePartitioningTest, cls).setUpClass()
301    cls.cluster = multi_worker_test_base.create_multi_process_cluster(
302        num_workers=2, num_ps=2, rpc_layer="grpc")
303    cls.cluster_resolver = cls.cluster.cluster_resolver
304
305  @classmethod
306  def tearDownClass(cls):
307    super(VariablePartitioningTest, cls).tearDownClass()
308    cls.cluster.stop()
309
310  def testDefaultNoPartition(self):
311    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
312        self.cluster_resolver)
313    with strategy.scope():
314      v = variables.Variable([0, 1, 2, 3])
315
316    self.assertIsInstance(v, variables.Variable)
317
318  def testBasic(self):
319    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
320        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
321    with strategy.scope():
322      init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
323      v1 = variables.Variable(
324          initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
325          shape=(5, 2),
326          dtype=dtypes.int64)
327
328      init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
329      v2 = variables.Variable(
330          initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
331          shape=(6, 1),
332          dtype=dtypes.int64)
333
334    self.assertIsInstance(v1, sharded_variable.ShardedVariable)
335    self.assertLen(v1.variables, 2)
336    self.assertRegex(v1.variables[0].device, "/job:ps/replica:0/task:0")
337    self.assertRegex(v1.variables[1].device, "/job:ps/replica:0/task:1")
338    self.assertAllEqual(v1.variables[0], [[0, 1], [2, 3], [4, 5]])
339    self.assertAllEqual(v1.variables[1], [[6, 7], [8, 9]])
340
341    self.assertIsInstance(v2, sharded_variable.ShardedVariable)
342    self.assertLen(v2.variables, 2)
343    self.assertRegex(v2.variables[0].device, "/job:ps/replica:0/task:0")
344    self.assertRegex(v2.variables[1].device, "/job:ps/replica:0/task:1")
345    self.assertAllEqual(v2.variables[0], [[0], [1], [2]])
346    self.assertAllEqual(v2.variables[1], [[3], [4], [5]])
347
348  def testBasicVariableWithAggregation(self):
349    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
350        self.cluster_resolver)
351    strategy.extended._allow_run_without_coordinator = True
352    with strategy.scope():
353      v = variables.Variable(
354          initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
355          dtype=dtypes.float32,
356          aggregation=variable_scope.VariableAggregation.SUM)
357
358    if strategy.num_replicas_in_sync > 1:
359      self.assertIsInstance(v, ps_values.AggregatingVariable)
360    else:
361      self.assertIsInstance(v, variables.Variable)
362
363    def replica_fn():
364      replica_id = distribution_strategy_context.get_replica_context(
365      ).replica_id_in_sync_group
366      val = array_ops.reshape(
367          math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
368      v.assign(
369          array_ops.concat(
370              [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))
371
372    strategy.run(replica_fn)
373
374    expected_result = np.arange(8.) * strategy.num_replicas_in_sync
375    for i in range(strategy.num_replicas_in_sync):
376      expected_result[0] = expected_result[0] + i + 10
377    self.assertAllEqual(v, expected_result)
378
379  def testBasicShardedVariableWithAggregation(self):
380    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
381        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
382    strategy.extended._allow_run_without_coordinator = True
383    with strategy.scope():
384      v = variables.Variable(
385          initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
386          dtype=dtypes.float32,
387          aggregation=variable_scope.VariableAggregation.SUM)
388
389    self.assertIsInstance(v, sharded_variable.ShardedVariable)
390    self.assertLen(v.variables, 2)
391    if strategy.num_replicas_in_sync > 1:
392      self.assertIsInstance(v.variables[0], ps_values.AggregatingVariable)
393    else:
394      self.assertIsInstance(v.variables[0], variables.Variable)
395
396    def replica_fn():
397      replica_id = distribution_strategy_context.get_replica_context(
398      ).replica_id_in_sync_group
399      val = array_ops.reshape(
400          math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
401      v.assign(
402          array_ops.concat(
403              [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))
404
405    strategy.run(replica_fn)
406
407    expected_result = np.arange(8.) * strategy.num_replicas_in_sync
408    for i in range(strategy.num_replicas_in_sync):
409      expected_result[0] = expected_result[0] + i + 10
410    expected_result = np.array_split(expected_result, 2)
411    self.assertAllEqual(expected_result[0], v.variables[0])
412    self.assertAllEqual(expected_result[1], v.variables[1])
413
414  def testNonCallableInitialValue(self):
415    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
416        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
417    with strategy.scope():
418      v = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
419
420    self.assertIsInstance(v, sharded_variable.ShardedVariable)
421    self.assertLen(v.variables, 4)
422    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
423    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
424    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
425    self.assertRegex(v.variables[3].device, "/job:ps/replica:0/task:1")
426    self.assertAllEqual(v.variables[0], [0, 1, 2])
427    self.assertAllEqual(v.variables[1], [3, 4, 5])
428    self.assertAllEqual(v.variables[2], [6, 7])
429    self.assertAllEqual(v.variables[3], [8, 9])
430
431  def testNumPartitionsLargerThanSize(self):
432    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
433        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
434    with strategy.scope():
435      v = variables.Variable([0, 1, 2])
436
437    self.assertIsInstance(v, sharded_variable.ShardedVariable)
438    self.assertLen(v.variables, 3)
439    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
440    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
441    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
442    self.assertAllEqual(v.variables[0], [0])
443    self.assertAllEqual(v.variables[1], [1])
444    self.assertAllEqual(v.variables[2], [2])
445
446  def testPartitionToOne(self):
447    # For small variables there is only one partition.
448    variable_partitioner = sharded_variable.MinSizePartitioner(
449        min_shard_bytes=64 << 20, max_shards=2)
450    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
451        self.cluster_resolver, variable_partitioner)
452    with strategy.scope():
453      initializer = init_ops_v2.Constant([0] * 10)
454      v1 = variables.Variable(
455          initial_value=lambda: initializer(shape=(10,), dtype=dtypes.int64),
456          shape=(10,),
457          dtype=dtypes.int64)
458
459      v2 = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
460
461    self.assertIsInstance(v1, variables.Variable)
462    self.assertNotIsInstance(v1, sharded_variable.ShardedVariable)
463    self.assertRegex(v1.device, "/job:ps/replica:0/task:0")
464    self.assertAllEqual(v1, [0] * 10)
465
466    self.assertIsInstance(v2, variables.Variable)
467    self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
468    self.assertRegex(v2.device, "/job:ps/replica:0/task:1")
469    self.assertAllEqual(v2, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
470
471  def testColocateWith(self):
472    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
473        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
474    with strategy.scope():
475      v1 = variables.Variable([0, 1, 2, 3])
476
477      with strategy.extended.colocate_vars_with(v1.variables[0]):
478        v2 = variables.Variable([4, 5])
479
480    self.assertIsInstance(v1, sharded_variable.ShardedVariable)
481
482    self.assertIsInstance(v2, variables.Variable)
483    self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
484    self.assertEqual(v2.device, v1.variables[0].device)
485    self.assertAllEqual(v2, [4, 5])
486
487  def testCustomPartitionAwareInitializer(self):
488    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
489        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
490    with strategy.scope():
491      initializer = PartitionAwareIdentity()
492      initial_value = functools.partial(
493          initializer, shape=(4, 4), dtype=dtypes.int64)
494      v = variables.Variable(
495          initial_value=initial_value, shape=(4, 4), dtype=dtypes.int64)
496
497    self.assertIsInstance(v, sharded_variable.ShardedVariable)
498    self.assertLen(v.variables, 2)
499    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
500    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
501    self.assertAllEqual(v.variables[0], [[1, 0, 0, 0], [0, 1, 0, 0]])
502    self.assertAllEqual(v.variables[1], [[0, 0, 1, 0], [0, 0, 0, 1]])
503
504  def testPartitionWhenLackOfInfo(self):
505    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
506        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
507    with strategy.scope():
508      initializer = init_ops_v2.Constant([0, 1, 2, 3])
509      # Shape is not explicitly specified.
510      v1 = variables.Variable(
511          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
512          dtype=dtypes.int64)
513      # Dtype is not explicitly specified.
514      v2 = variables.Variable(
515          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
516          shape=(4,))
517      # Neither shape nor dtype is explicitly specified.
518      v3 = variables.Variable(
519          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64))
520
521    for v in [v1, v2, v3]:
522      self.assertIsInstance(v, sharded_variable.ShardedVariable)
523      self.assertLen(v.variables, 2)
524      self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
525      self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
526      self.assertAllEqual(v.variables[0], [0, 1])
527      self.assertAllEqual(v.variables[1], [2, 3])
528
529  def testInvalidPartitioner(self):
530    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
531        self.cluster_resolver, lambda shape, dtype: None)
532    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
533      with strategy.scope():
534        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])
535
536    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
537        self.cluster_resolver, lambda shape, dtype: [])
538    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
539      with strategy.scope():
540        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])
541
542    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
543        self.cluster_resolver, lambda shape, dtype: [0, 1, 1])
544    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
545      with strategy.scope():
546        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])
547
548    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
549        self.cluster_resolver, lambda shape, dtype: [2, 2, 1])
550    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
551      with strategy.scope():
552        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])
553
554  def testCreateInsideTFFunction(self):
555    if test_util.is_xla_enabled():
556      self.skipTest("TODO(b/202760274): Would raise an error that is to be "
557                    "investigated.")
558
559    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
560        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
561
562    collection = []
563
564    @def_function.function
565    def create_vars():
566      if not collection:
567        identity = init_ops_v2.Identity()
568        v1 = variables.Variable([[1., 0.], [0., 1.]], dtype=dtypes.float32)
569        v2 = variables.Variable(lambda: identity((2, 2), dtypes.float32))
570        v3 = variables.Variable(
571            lambda: identity((2, 2), dtypes.float32),
572            dtype=dtypes.float32,
573            shape=(2, 2))
574        collection.extend([v1, v2, v3])
575
576    with strategy.scope():
577      create_vars()
578      for v in collection:
579        self.assertIsInstance(v, sharded_variable.ShardedVariable)
580        self.assertLen(v.variables, 2)
581        self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
582        self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
583        self.assertAllEqual(v.variables[0], [[1., 0.]])
584        self.assertAllEqual(v.variables[1], [[0., 1.]])
585
586  @parameterized.named_parameters(
587      ("Restore", False, 2),
588      ("RestoreDiffShards", False, 4),
589      ("DelayedRestore", True, 2),
590      ("DelayedRestoreDiffShards", True, 4),
591  )
592  def testCheckpoint(self, delayed, restore_shards):
593
594    if test_util.is_xla_enabled() and not delayed and restore_shards == 4:
595      self.skipTest("TODO(b/202760274): Would raise an error that is to be "
596                    "investigated.")
597
598    def make_variable(name, shape, dtype, initializer):
599      initial_value = functools.partial(initializer, shape, dtype=dtype)
600      return variables.Variable(
601          name=name, initial_value=initial_value, shape=shape, dtype=dtype)
602
603    class Model(autotrackable.AutoTrackable):
604
605      def build(self):
606        self.w = self._add_variable_with_custom_getter(
607            "w",
608            shape=(4,),
609            initializer=init_ops_v2.Ones(),
610            getter=make_variable)
611
612    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
613        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
614    ckpt_dir = os.path.join(self.get_temp_dir(), "checkpoint")
615
616    with strategy.scope():
617      model1 = Model()
618      model1.build()
619      self.assertIsInstance(model1.w, sharded_variable.ShardedVariable)
620      self.assertLen(model1.w.variables, 2)
621      model1.w.assign([1., 2., 3., 4.])
622
623      cp1 = tracking_util.Checkpoint(model=model1)
624      cp1.write(ckpt_dir)
625
626    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
627        self.cluster_resolver,
628        sharded_variable.FixedShardsPartitioner(restore_shards))
629
630    with strategy.scope():
631      model2 = Model()
632      cp2 = tracking_util.Checkpoint(model=model2)
633      if delayed:
634        cp2.restore(ckpt_dir)
635        model2.build()
636      else:
637        model2.build()
638        cp2.restore(ckpt_dir)
639      self.assertIsInstance(model2.w, sharded_variable.ShardedVariable)
640      self.assertLen(model2.w.variables, restore_shards)
641      if restore_shards == 2:
642        self.assertAllEqual(model2.w.variables[0], [1., 2.])
643        self.assertAllEqual(model2.w.variables[1], [3., 4.])
644      elif restore_shards == 4:
645        self.assertAllEqual(model2.w.variables[0], [1.])
646        self.assertAllEqual(model2.w.variables[1], [2.])
647        self.assertAllEqual(model2.w.variables[2], [3.])
648        self.assertAllEqual(model2.w.variables[3], [4.])
649
650
651class ClusterTypeNameTest(test.TestCase):
652
653  def testArbitraryJobName(self):
654    cluster_def = multi_worker_test_base.create_cluster_spec(
655        num_workers=1, num_ps=1, has_chief=True)
656    cluster_def["some_arbitrary_name"] = [
657        "localhost:%d" % multi_worker_test_base.pick_unused_port()
658    ]
659    cluster_resolver = SimpleClusterResolver(
660        ClusterSpec(cluster_def), rpc_layer="grpc")
661    with self.assertRaisesRegexp(ValueError, "Disallowed task type found in"):
662      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
663
664  def testArbitraryCurrentTaskType(self):
665    cluster_def = multi_worker_test_base.create_cluster_spec(
666        num_workers=1, num_ps=1, has_chief=True)
667    cluster_resolver = SimpleClusterResolver(
668        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
669    with self.assertRaisesRegexp(ValueError, "Unrecognized task_type: foobar"):
670      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
671
672  def testMoreThanOneChief(self):
673    cluster_def = multi_worker_test_base.create_cluster_spec(
674        num_workers=1, num_ps=1)
675    chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)]
676    cluster_def["chief"] = ["localhost:%s" % port for port in chief_ports]
677    cluster_resolver = SimpleClusterResolver(
678        ClusterSpec(cluster_def),
679        rpc_layer="grpc",
680        task_type="chief",
681        task_id=1)
682    with self.assertRaisesRegexp(ValueError,
683                                 "There must be at most one 'chief' job."):
684      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
685
686  def testLessThanOneWorker(self):
687    cluster_def = multi_worker_test_base.create_cluster_spec(
688        num_workers=0, num_ps=1, has_chief=True)
689    cluster_resolver = SimpleClusterResolver(
690        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
691    with self.assertRaisesRegexp(ValueError,
692                                 "There must be at least one worker."):
693      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
694
695  def testLessThanOnePs(self):
696    cluster_def = multi_worker_test_base.create_cluster_spec(
697        num_workers=1, num_ps=0, has_chief=True)
698    cluster_resolver = SimpleClusterResolver(
699        ClusterSpec(cluster_def),
700        rpc_layer="grpc",
701        task_type="worker",
702        task_id=0)
703    with self.assertRaisesRegexp(ValueError, "There must be at least one ps."):
704      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
705
706
707if __name__ == "__main__":
708  v2_compat.enable_v2_behavior()
709  multi_process_runner.test_main()
710