# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for parameter_server_strategy_v2.py."""

import contextlib
import functools
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.checkpoint import checkpoint as tracking_util
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import save as tf_save
from tensorflow.python.trackable import autotrackable
from tensorflow.python.training.server_lib import ClusterSpec


class ParameterServerStrategyV2Test(test.TestCase):

  @classmethod
  def setUpClass(cls):
    super(ParameterServerStrategyV2Test, cls).setUpClass()
    cls.cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=2, num_ps=3, rpc_layer="grpc")
    cls.cluster_resolver = cls.cluster.cluster_resolver

  @classmethod
  def tearDownClass(cls):
    super(ParameterServerStrategyV2Test, cls).tearDownClass()
    cls.cluster.stop()

  def testVariablePlacement(self):

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    v1 = variables.Variable(initial_value=0.0)
    with strategy.scope():
      v2 = variables.Variable(initial_value=1.0)
      v3 = variables.Variable(initial_value=2.0)
      v4 = variables.Variable(initial_value=3.0)
      v5 = variables.Variable(initial_value=4.0)

    # v1 was created outside scope so should be on client.
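    # With a GPU available, the client (chief) places it on its local GPU:0;
    # otherwise it falls back to the client's CPU:0.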
    gpu_devices = context.num_gpus()
    if gpu_devices:
      # For tests with GPUs
      self.assertEqual(v1.device, "/job:chief/replica:0/task:0/device:GPU:0")
    else:
      self.assertEqual(v1.device, "/job:chief/replica:0/task:0/device:CPU:0")

    # v2 through v5 are created in scope and in a round-robin manner.
    self.assertEqual(v2.device, "/job:ps/replica:0/task:0/device:CPU:0")
    self.assertEqual(v3.device, "/job:ps/replica:0/task:1/device:CPU:0")
    self.assertEqual(v4.device, "/job:ps/replica:0/task:2/device:CPU:0")
    self.assertEqual(v5.device, "/job:ps/replica:0/task:0/device:CPU:0")

  def testInteractionWithDeviceScope(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    # The strategy scope always wins.
    with strategy.scope():
      with ops.device("/job:ps/replica:0/task:1"):
        v0 = variables.Variable(initial_value=0.0)
        self.assertEqual(v0.device, "/job:ps/replica:0/task:0/device:CPU:0")

      with ops.device("/job:ps/replica:0/task:0"):
        v1 = variables.Variable(initial_value=0.0)
        self.assertEqual(v1.device, "/job:ps/replica:0/task:1/device:CPU:0")

    with ops.device("/job:ps/replica:0/task:1"):
      with strategy.scope():
        v2 = variables.Variable(initial_value=0.0)
        self.assertEqual(v2.device, "/job:ps/replica:0/task:2/device:CPU:0")

        v3 = variables.Variable(initial_value=0.0)
        self.assertEqual(v3.device, "/job:ps/replica:0/task:0/device:CPU:0")

  def testInteractionWithVariableCreatorScope(self):

    def var_creator(next_creator, **kwargs):
      if "colocate_with" in kwargs:
        with ops.device(None):
          with ops.colocate_with(kwargs["colocate_with"]):
            return next_creator(**kwargs)

      self.assertIn("ps1", kwargs["name"])
      with ops.device("/job:ps/task:1"):
        return next_creator(**kwargs)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    # variable_creator_scope itself will work.
    with variable_scope.variable_creator_scope(var_creator):
      v0 = variables.Variable(initial_value=0.0, name="ps1_0")
    self.assertEqual(v0.device, "/job:ps/replica:0/task:1/device:CPU:0")

    # variable_creator_scope inside strategy.scope will not work.
    with strategy.scope():
      with variable_scope.variable_creator_scope(var_creator):
        v1 = variables.Variable(initial_value=0.0, name="ps1_1")
    self.assertEqual(v1.device, "/job:ps/replica:0/task:0/device:CPU:0")

    # strategy.scope still assigns variables in a round robin fashion.
    with strategy.scope():
      v2 = variables.Variable(initial_value=0.0, name="ps1_2")
    self.assertEqual(v2.device, "/job:ps/replica:0/task:1/device:CPU:0")

    with strategy.scope():
      v3 = variables.Variable(initial_value=0.0, name="ps1_3")
    self.assertEqual(v3.device, "/job:ps/replica:0/task:2/device:CPU:0")

    # variable_creator_scope outside strategy.scope will work.
    with variable_scope.variable_creator_scope(var_creator):
      with strategy.scope():
        v4 = variables.Variable(initial_value=0.0, name="ps1_4")
    self.assertEqual(v4.device, "/job:ps/replica:0/task:1/device:CPU:0")

    with variable_scope.variable_creator_scope(var_creator):
      with strategy.scope():
        v5 = variables.Variable(initial_value=0.0, name="ps1_5")
    self.assertEqual(v5.device, "/job:ps/replica:0/task:1/device:CPU:0")

    # variable_creator_scope can be made to respect "colocate_with" as well.
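    # colocate_vars_with passes a `colocate_with` kwarg down to the creator,
    # which var_creator honors above by clearing the device and entering an
    # explicit colocation scope.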
    with variable_scope.variable_creator_scope(var_creator):
      with strategy.scope():
        with strategy.extended.colocate_vars_with(v1):
          v6 = variables.Variable(initial_value=0.0, name="ps1_6")
    self.assertEqual(v6.device, "/job:ps/replica:0/task:0/device:CPU:0")

  @contextlib.contextmanager
  def _assertRaisesUsageWarningWithSchedule(self):
    with self.assertLogs(level="WARNING") as logs:
      yield

    self.assertIn(
        "A `tf.distribute.experimental.ParameterServerStrategy` method is "
        "invoked without using `ClusterCoordinator.schedule`. If you are not "
        "tracing a tf.function, this method is possibly executed on the "
        "coordinator, which can be slow. To properly dispatch functions to "
        "run on workers, methods like `run` or `reduce` should be used "
        "within a function passed to `tf.distribute.experimental.coordinator."
        "ClusterCoordinator.schedule`.", "".join(logs.output))

  def testRunNotUsedWithClusterCoordinator(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    dataset = dataset_ops.DatasetV2.range(8)
    with strategy.scope():
      v = variables.Variable(1, dtype=dtypes.int64)

    def step_fn(iterator):
      return next(iterator) + v

    with self._assertRaisesUsageWarningWithSchedule():
      strategy.run(step_fn, args=(iter(dataset),))

  def testRunUsedWithTestOnlyMode(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    strategy.extended._allow_run_without_coordinator = True
    dataset = dataset_ops.DatasetV2.range(15)
    with strategy.scope():
      v = variables.Variable(1, dtype=dtypes.int64)

    def step_fn(iterator):
      return next(iterator) + v

    strategy.run(step_fn, args=(iter(dataset),))

  def testReduceNotUsedWithClusterCoordinator(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    with self._assertRaisesUsageWarningWithSchedule():
      strategy.reduce("SUM", None, axis=None)

  def testDistributeDatasetUsedDirectly(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    dataset = dataset_ops.DatasetV2.range(3)
    distributed_dataset = strategy.experimental_distribute_dataset(dataset)
    with self.assertRaises(ValueError):
      iter(distributed_dataset)

    distributed_dataset = strategy.distribute_datasets_from_function(
        lambda: dataset)
    with self.assertRaises(ValueError):
      iter(distributed_dataset)

  def testSparselyReadForEmbeddingLookup(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    class FakeModel(module.Module):

      def __init__(self):
        self._var0 = variables.Variable([1.0, 2.0, 3.0, 4.0])
        self._var1 = variables.Variable([5.0, 6.0, 7.0, 8.0])

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[2], dtype=dtypes.int32, name="inputs")
      ])
      def func(self, x):
        return embedding_ops.embedding_lookup([self._var0, self._var1], x)

    with strategy.scope():
      model = FakeModel()

    # Assert that ResourceGather op exists instead of Gather in training
    # function.
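    # ResourceGather reads only the looked-up rows of the resource variable,
    # instead of materializing the whole embedding table for a dense Gather.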
    found_resource_gather = False
    found_gather = False

    for n in model.func.get_concrete_function().graph.as_graph_def().node:
      if n.op == "ResourceGather":
        found_resource_gather = True
      elif n.op == "Gather":
        found_gather = True
    self.assertTrue(found_resource_gather)
    self.assertFalse(found_gather)

    # Assert that ResourceGather op exists instead of Gather in saved_model.
    found_resource_gather = False
    found_gather = False

    tmp_dir = self.get_temp_dir()
    tf_save.save(model, tmp_dir, signatures=model.func)

    with gfile.Open("%s/saved_model.pb" % tmp_dir, "rb") as f:
      saved_model_proto = saved_model_pb2.SavedModel().FromString(f.read())

    for function in saved_model_proto.meta_graphs[0].graph_def.library.function:
      for n in function.node_def:
        if n.op == "ResourceGather":
          found_resource_gather = True
          resource_gather_device = n.device
        elif n.op == "Gather":
          found_gather = True
    self.assertTrue(found_resource_gather)
    self.assertFalse(found_gather)

    # We also assert that the colocate_with in embedding_ops will not result in
    # a hard-coded device string.
    self.assertEmpty(resource_gather_device)


class PartitionAwareIdentity(object):

  def __call__(self, shape, dtype, **kwargs):
    value = linalg_ops_impl.eye(*shape, dtype=dtype)
    if "partition_shape" in kwargs and "partition_offset" in kwargs:
      return array_ops.slice(value, kwargs["partition_offset"],
                             kwargs["partition_shape"])
    raise AssertionError(
        "PartitionAwareIdentity does not support non-partitioned "
        "initialization.")


class VariablePartitioningTest(test.TestCase, parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    super(VariablePartitioningTest, cls).setUpClass()
    cls.cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=2, num_ps=2, rpc_layer="grpc")
    cls.cluster_resolver = cls.cluster.cluster_resolver

  @classmethod
  def tearDownClass(cls):
    super(VariablePartitioningTest, cls).tearDownClass()
    cls.cluster.stop()

  def testDefaultNoPartition(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    with strategy.scope():
      v = variables.Variable([0, 1, 2, 3])

    self.assertIsInstance(v, variables.Variable)

  def testBasic(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
      v1 = variables.Variable(
          initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
          shape=(5, 2),
          dtype=dtypes.int64)

      init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
      v2 = variables.Variable(
          initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
          shape=(6, 1),
          dtype=dtypes.int64)

    self.assertIsInstance(v1, sharded_variable.ShardedVariable)
    self.assertLen(v1.variables, 2)
    self.assertRegex(v1.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v1.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v1.variables[0], [[0, 1], [2, 3], [4, 5]])
    self.assertAllEqual(v1.variables[1], [[6, 7], [8, 9]])

    self.assertIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertLen(v2.variables, 2)
    self.assertRegex(v2.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v2.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v2.variables[0], [[0], [1], [2]])
    self.assertAllEqual(v2.variables[1], [[3], [4], [5]])

  def testBasicVariableWithAggregation(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    strategy.extended._allow_run_without_coordinator = True
    with strategy.scope():
      v = variables.Variable(
          initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
          dtype=dtypes.float32,
          aggregation=variable_scope.VariableAggregation.SUM)

    if strategy.num_replicas_in_sync > 1:
      self.assertIsInstance(v, ps_values.AggregatingVariable)
    else:
      self.assertIsInstance(v, variables.Variable)

    def replica_fn():
      replica_id = distribution_strategy_context.get_replica_context(
      ).replica_id_in_sync_group
      val = array_ops.reshape(
          math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
      v.assign(
          array_ops.concat(
              [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))

    strategy.run(replica_fn)

    expected_result = np.arange(8.) * strategy.num_replicas_in_sync
    for i in range(strategy.num_replicas_in_sync):
      expected_result[0] = expected_result[0] + i + 10
    self.assertAllEqual(v, expected_result)

  def testBasicShardedVariableWithAggregation(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    strategy.extended._allow_run_without_coordinator = True
    with strategy.scope():
      v = variables.Variable(
          initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
          dtype=dtypes.float32,
          aggregation=variable_scope.VariableAggregation.SUM)

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 2)
    if strategy.num_replicas_in_sync > 1:
      self.assertIsInstance(v.variables[0], ps_values.AggregatingVariable)
    else:
      self.assertIsInstance(v.variables[0], variables.Variable)

    def replica_fn():
      replica_id = distribution_strategy_context.get_replica_context(
      ).replica_id_in_sync_group
      val = array_ops.reshape(
          math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
      v.assign(
          array_ops.concat(
              [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))

    strategy.run(replica_fn)
    expected_result = np.arange(8.) * strategy.num_replicas_in_sync
    for i in range(strategy.num_replicas_in_sync):
      expected_result[0] = expected_result[0] + i + 10
    expected_result = np.array_split(expected_result, 2)
    self.assertAllEqual(expected_result[0], v.variables[0])
    self.assertAllEqual(expected_result[1], v.variables[1])

  def testNonCallableInitialValue(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
    with strategy.scope():
      v = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 4)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[3].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v.variables[0], [0, 1, 2])
    self.assertAllEqual(v.variables[1], [3, 4, 5])
    self.assertAllEqual(v.variables[2], [6, 7])
    self.assertAllEqual(v.variables[3], [8, 9])

  def testNumPartitionsLargerThanSize(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
    with strategy.scope():
      v = variables.Variable([0, 1, 2])

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 3)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
    self.assertAllEqual(v.variables[0], [0])
    self.assertAllEqual(v.variables[1], [1])
    self.assertAllEqual(v.variables[2], [2])

  def testPartitionToOne(self):
    # For small variables there is only one partition.
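    # With a 64MB minimum shard size, MinSizePartitioner keeps these tiny
    # variables unsharded even though max_shards is 2.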
    variable_partitioner = sharded_variable.MinSizePartitioner(
        min_shard_bytes=64 << 20, max_shards=2)
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, variable_partitioner)
    with strategy.scope():
      initializer = init_ops_v2.Constant([0] * 10)
      v1 = variables.Variable(
          initial_value=lambda: initializer(shape=(10,), dtype=dtypes.int64),
          shape=(10,),
          dtype=dtypes.int64)

      v2 = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    self.assertIsInstance(v1, variables.Variable)
    self.assertNotIsInstance(v1, sharded_variable.ShardedVariable)
    self.assertRegex(v1.device, "/job:ps/replica:0/task:0")
    self.assertAllEqual(v1, [0] * 10)

    self.assertIsInstance(v2, variables.Variable)
    self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertRegex(v2.device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v2, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

  def testColocateWith(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      v1 = variables.Variable([0, 1, 2, 3])

      with strategy.extended.colocate_vars_with(v1.variables[0]):
        v2 = variables.Variable([4, 5])

    self.assertIsInstance(v1, sharded_variable.ShardedVariable)

    self.assertIsInstance(v2, variables.Variable)
    self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertEqual(v2.device, v1.variables[0].device)
    self.assertAllEqual(v2, [4, 5])

  def testCustomPartitionAwareInitializer(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      initializer = PartitionAwareIdentity()
      initial_value = functools.partial(
          initializer, shape=(4, 4), dtype=dtypes.int64)
      v = variables.Variable(
          initial_value=initial_value, shape=(4, 4), dtype=dtypes.int64)

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 2)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v.variables[0], [[1, 0, 0, 0], [0, 1, 0, 0]])
    self.assertAllEqual(v.variables[1], [[0, 0, 1, 0], [0, 0, 0, 1]])

  def testPartitionWhenLackOfInfo(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      initializer = init_ops_v2.Constant([0, 1, 2, 3])
      # Shape is not explicitly specified.
      v1 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
          dtype=dtypes.int64)
      # Dtype is not explicitly specified.
      v2 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
          shape=(4,))
      # Neither shape nor dtype is explicitly specified.
      v3 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64))

    for v in [v1, v2, v3]:
      self.assertIsInstance(v, sharded_variable.ShardedVariable)
      self.assertLen(v.variables, 2)
      self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
      self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
      self.assertAllEqual(v.variables[0], [0, 1])
      self.assertAllEqual(v.variables[1], [2, 3])

  def testInvalidPartitioner(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: None)
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: [])
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: [0, 1, 1])
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: [2, 2, 1])
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

  def testCreateInsideTFFunction(self):
    if test_util.is_xla_enabled():
      self.skipTest("TODO(b/202760274): Would raise an error that is to be "
                    "investigated.")

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))

    collection = []

    @def_function.function
    def create_vars():
      if not collection:
        identity = init_ops_v2.Identity()
        v1 = variables.Variable([[1., 0.], [0., 1.]], dtype=dtypes.float32)
        v2 = variables.Variable(lambda: identity((2, 2), dtypes.float32))
        v3 = variables.Variable(
            lambda: identity((2, 2), dtypes.float32),
            dtype=dtypes.float32,
            shape=(2, 2))
        collection.extend([v1, v2, v3])

    with strategy.scope():
      create_vars()
      for v in collection:
        self.assertIsInstance(v, sharded_variable.ShardedVariable)
        self.assertLen(v.variables, 2)
        self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
        self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
        self.assertAllEqual(v.variables[0], [[1., 0.]])
        self.assertAllEqual(v.variables[1], [[0., 1.]])

  @parameterized.named_parameters(
      ("Restore", False, 2),
      ("RestoreDiffShards", False, 4),
      ("DelayedRestore", True, 2),
      ("DelayedRestoreDiffShards", True, 4),
  )
  def testCheckpoint(self, delayed, restore_shards):

    if test_util.is_xla_enabled() and not delayed and restore_shards == 4:
      self.skipTest("TODO(b/202760274): Would raise an error that is to be "
                    "investigated.")

    def make_variable(name, shape, dtype, initializer):
      initial_value = functools.partial(initializer, shape, dtype=dtype)
      return variables.Variable(
          name=name, initial_value=initial_value, shape=shape, dtype=dtype)

    class Model(autotrackable.AutoTrackable):

      def build(self):
        self.w = self._add_variable_with_custom_getter(
            "w",
            shape=(4,),
            initializer=init_ops_v2.Ones(),
            getter=make_variable)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    ckpt_dir = os.path.join(self.get_temp_dir(), "checkpoint")

    with strategy.scope():
      model1 = Model()
      model1.build()
      self.assertIsInstance(model1.w, sharded_variable.ShardedVariable)
      self.assertLen(model1.w.variables, 2)
      model1.w.assign([1., 2., 3., 4.])

    cp1 = tracking_util.Checkpoint(model=model1)
    cp1.write(ckpt_dir)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver,
        sharded_variable.FixedShardsPartitioner(restore_shards))

    with strategy.scope():
      model2 = Model()
      cp2 = tracking_util.Checkpoint(model=model2)
      if delayed:
        cp2.restore(ckpt_dir)
        model2.build()
      else:
        model2.build()
        cp2.restore(ckpt_dir)
      self.assertIsInstance(model2.w, sharded_variable.ShardedVariable)
      self.assertLen(model2.w.variables, restore_shards)
      if restore_shards == 2:
        self.assertAllEqual(model2.w.variables[0], [1., 2.])
        self.assertAllEqual(model2.w.variables[1], [3., 4.])
      elif restore_shards == 4:
        self.assertAllEqual(model2.w.variables[0], [1.])
        self.assertAllEqual(model2.w.variables[1], [2.])
        self.assertAllEqual(model2.w.variables[2], [3.])
        self.assertAllEqual(model2.w.variables[3], [4.])


class ClusterTypeNameTest(test.TestCase):

  def testArbitraryJobName(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1, has_chief=True)
    cluster_def["some_arbitrary_name"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc")
    with self.assertRaisesRegex(ValueError, "Disallowed task type found in"):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testArbitraryCurrentTaskType(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1, has_chief=True)
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
    with self.assertRaisesRegex(ValueError, "Unrecognized task_type: foobar"):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testMoreThanOneChief(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1)
    chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)]
    cluster_def["chief"] = ["localhost:%s" % port for port in chief_ports]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def),
        rpc_layer="grpc",
        task_type="chief",
        task_id=1)
    with self.assertRaisesRegex(ValueError,
                                "There must be at most one 'chief' job."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testLessThanOneWorker(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=0, num_ps=1, has_chief=True)
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
    with self.assertRaisesRegex(ValueError,
                                "There must be at least one worker."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testLessThanOnePs(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=0, has_chief=True)
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def),
        rpc_layer="grpc",
        task_type="worker",
        task_id=0)
    with self.assertRaisesRegex(ValueError, "There must be at least one ps."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  multi_process_runner.test_main()