# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUStrategy with model parallelism."""

import os

from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util

FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")


def get_tpu_cluster_resolver():
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  return resolver


def get_tpu_strategy(enable_spmd=False):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  num_replicas = resolver.get_tpu_system_metadata().num_cores // 2
  device_assignment = device_assignment_lib.DeviceAssignment.build(
      topology, num_replicas=num_replicas, computation_shape=[1, 1, 1, 2])
  strategy = tpu_lib.TPUStrategyV2(
      resolver,
      experimental_device_assignment=device_assignment,
      experimental_spmd_xla_partitioning=enable_spmd)
  return strategy, num_replicas
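

# Each replica in the strategy built above spans two logical TPU cores
# (computation_shape=[1, 1, 1, 2]), which is what lets the tests below place
# variables and per-step computations on logical device 0 or 1 of a replica.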
class TPUStrategyModelParallelismTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase):

  def test_logical_device_assignment(self):
    if test_util.is_mlir_bridge_enabled():
      self.skipTest("TODO(b/238811067): fix MLIR bridge")
    strategy, num_replicas = get_tpu_strategy()
    with strategy.scope():
      v = variables.Variable(2.)
      with strategy.extended.experimental_logical_device(1):
        w = variables.Variable(3.)

    self.assertLen(strategy.experimental_local_results(v), num_replicas)
    self.assertLen(strategy.experimental_local_results(w), num_replicas)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                     strategy.experimental_local_results(v)[0].device)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                     strategy.experimental_local_results(w)[0].device)

    logical_devices = []

    @def_function.function
    def f(x):
      replica_ctx = distribution_strategy_context.get_replica_context()
      with replica_ctx.experimental_logical_device(0):
        y = v * x
      with replica_ctx.experimental_logical_device(1):
        z = w * y
      logical_devices.append((y.device, z.device))
      return z

    result = strategy.run(f, args=(5.,))

    self.assertEqual(
        [("/device:TPU_REPLICATED_CORE:0", "/device:TPU_REPLICATED_CORE:1")],
        logical_devices)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(30. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))

  def test_partitioned_model_checkpointing(self):
    if test_util.is_mlir_bridge_enabled():
      self.skipTest("TODO(b/238811067): fix MLIR bridge")

    class PartitionedModel(module.Module):

      def __init__(self, v, w):
        super(PartitionedModel, self).__init__()

        assert distribution_strategy_context.has_strategy()
        strategy = distribution_strategy_context.get_strategy()

        with strategy.extended.experimental_logical_device(0):
          self.v = variables.Variable(v)
        with strategy.extended.experimental_logical_device(1):
          self.w = variables.Variable(w)

      def __call__(self, x):
        replica_ctx = distribution_strategy_context.get_replica_context()
        with replica_ctx.experimental_logical_device(0):
          y = self.v * x
        with replica_ctx.experimental_logical_device(1):
          z = self.w * y
        return z

      def change_weights_op(self, v_new, w_new):
        return control_flow_ops.group(
            [self.v.assign(v_new), self.w.assign(w_new)])

    strategy, num_replicas = get_tpu_strategy()
    with strategy.scope():
      model = PartitionedModel(2., 3.)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      checkpoint.save(file_prefix=checkpoint_prefix)

      self.evaluate(model.change_weights_op(1., 4.))
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(20. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))

      status = checkpoint.restore(
          checkpoint_management.latest_checkpoint(checkpoint_dir))
      status.run_restore_ops(sess)  # must run restore op in non-eager mode.
      status.assert_consumed()
      status.assert_existing_objects_matched()
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(30. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))
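
  # The remaining tests run with XLA SPMD partitioning enabled
  # (enable_spmd=True); SPMD variables expose their per-core components via
  # `w.values[i].variables`.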
  def test_spmd_cannot_assign_tensor_to_logical_device(self):
    strategy, _ = get_tpu_strategy(enable_spmd=True)
    x = constant_op.constant([0, 1])
    with self.assertRaises(ValueError):
      strategy.experimental_assign_to_logical_device(x, 0)

  def test_spmd_variable_created_from_callable(self):
    initializer = lambda: random_ops.random_normal(shape=(16, 16))
    strategy, _ = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(initializer)
    value0 = w.values[0]
    for v in value0.variables:
      self.assertAllEqual(v, value0.variables[0])

  def test_spmd_variable_read(self):
    batch_size = 32
    num_feature_in = 16
    num_feature_out = 8

    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)
    w_init = random_ops.random_uniform((num_feature_in, num_feature_out),
                                       dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(w_init, dtype=dtypes.float32)

    self.assertEqual(w.values[0].variables[0].shape.as_list(),
                     [num_feature_in, num_feature_out])

    self.assertEqual(w.shape.as_list(), [num_feature_in, num_feature_out])

    def step_fn(batch_features):
      predict = math_ops.matmul(batch_features, w)
      return predict

    @def_function.function
    def train_fn(batch_features):
      return strategy.run(step_fn, args=(batch_features,))

    result = train_fn(x)
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None),
        math_ops.matmul(x, w_init) * num_replicas,
        rtol=5e-03,
        atol=5e-03)

  def test_spmd_variable_read_init_scope(self):
    strategy, _ = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      v = variables.Variable(array_ops.ones((4, 4), dtype=dtypes.float32))

    @def_function.function
    def read_v():
      with ops.init_scope():
        return v.read_value()

    result = strategy.reduce("MEAN", strategy.run(read_v), axis=None)
    self.assertAllClose(result, v.read_value())

  def test_spmd_variable_update(self):
    batch_size = 1024
    num_feature_in = 256

    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)
    w_init = random_ops.random_uniform((batch_size, num_feature_in),
                                       dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(w_init, dtype=dtypes.float32)

    self.assertIsInstance(w, tpu_values.TPUMirroredVariable)
    self.assertTrue(w._is_replicated_or_sharded_to_logical_cores())

    def make_strategy_run(fn):

      def run(value):
        return strategy.run(fn, args=(value,))

      return def_function.function(run)

    result = make_strategy_run(w.assign)(x)
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None), x * num_replicas)

    delta = random_ops.random_uniform((batch_size, num_feature_in),
                                      dtype=dtypes.float32)
    result = make_strategy_run(w.assign_sub)(delta)
    x -= delta
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None), x * num_replicas)

    delta = random_ops.random_uniform((batch_size, num_feature_in),
                                      dtype=dtypes.float32)
    result = make_strategy_run(w.assign_add)(delta)
    x += delta
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None), x * num_replicas)
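
  # Assignments made eagerly, outside strategy.run, should behave like plain
  # dense updates and be observable through `w.numpy()`.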
  def test_spmd_variable_eager_update(self):
    batch_size = 32
    num_feature_in = 16

    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)
    w_init = random_ops.random_uniform((batch_size, num_feature_in),
                                       dtype=dtypes.float32)

    strategy, _ = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(w_init, dtype=dtypes.float32)

    w.assign(x)
    result = w.numpy()
    self.assertAllClose(result, x)

    x1 = random_ops.random_uniform((batch_size, num_feature_in),
                                   dtype=dtypes.float32)
    w.assign_sub(x1)
    result = w.numpy()
    self.assertAllClose(result, x - x1)

    x2 = random_ops.random_uniform((batch_size, num_feature_in),
                                   dtype=dtypes.float32)
    w.assign(x)
    w.assign_add(x2)
    result = w.numpy()
    self.assertAllClose(result, x + x2)

  def test_spmd_model_checkpointing(self):

    class LinearModel(module.Module):

      def __init__(self, w):
        super(LinearModel, self).__init__()
        self.w = variables.Variable(w)

      def __call__(self, x):
        return math_ops.matmul(x, self.w)

      def change_weights_op(self, w_new):
        return self.w.assign(w_new)

    batch_size = 32
    num_feature_in = 16
    num_feature_out = 8
    w1 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    w2 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      model = LinearModel(w1)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    @def_function.function
    def step_fn(x):
      x = strategy.experimental_split_to_logical_devices(x, [1, 2])
      return model(x)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      checkpoint.save(file_prefix=checkpoint_prefix)

      self.evaluate(model.change_weights_op(w2))
      result = strategy.run(step_fn, args=(x,))
      self.assertAllClose(
          math_ops.matmul(x, w2) * num_replicas,
          self.evaluate(strategy.reduce("SUM", result, axis=None)),
          rtol=5e-3,
          atol=5e-3)

      status = checkpoint.restore(
          checkpoint_management.latest_checkpoint(checkpoint_dir))
      status.run_restore_ops(sess)  # must run restore op in non-eager mode.
      status.assert_consumed()
      status.assert_existing_objects_matched()
      result = strategy.run(step_fn, args=(x,))
      self.assertAllClose(
          math_ops.matmul(x, w1) * num_replicas,
          self.evaluate(strategy.reduce("SUM", result, axis=None)),
          rtol=5e-3,
          atol=5e-3)
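
  # Summary ops are not supported on TPU cores, so the next test enables soft
  # device placement to let them be placed on the host automatically.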
  def test_spmd_with_summary(self):
    if test_util.is_mlir_bridge_enabled():
      self.skipTest("TODO(b/232580663): fix MLIR bridge")
    original_device_placement = config.get_soft_device_placement()
    config.set_soft_device_placement(True)

    strategy, _ = get_tpu_strategy(enable_spmd=True)
    summary_dir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(summary_dir)

    with strategy.scope():
      step = variables.Variable(0, dtype=dtypes.int64)

    @def_function.function
    def run():
      with writer.as_default():
        summary_ops.scalar("result", step * 2, step=step)
        step.assign_add(1)

    for _ in range(10):
      strategy.run(run, args=())

    for val in step.values:
      for var in val.variables:
        self.assertAllEqual(10, var)

    config.set_soft_device_placement(original_device_placement)

  def test_spmd_with_outside_comp(self):
    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)

    def host_inc(x):
      return x + 1

    @def_function.function
    def fn(x):
      y = x + 1
      z = tpu.outside_compilation(host_inc, y)
      a = z + 1
      return a

    arg = constant_op.constant(0, shape=(), dtype=dtypes.int64)
    result = strategy.run(fn, args=(arg,))
    self.assertEqual(3 * num_replicas,
                     self.evaluate(strategy.reduce("SUM", result, axis=None)))


if __name__ == "__main__":
  test.main()