# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for TPUStrategy."""

import os

from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util

FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")


def get_tpu_cluster_resolver():
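  """Returns a TPUClusterResolver built from the command-line flags."""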
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  return resolver


def get_tpu_strategy(enable_spmd=False):
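  """Returns a 2-cores-per-replica TPUStrategyV2 and its replica count."""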
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  # Each replica spans two logical cores (computation_shape=[1, 1, 1, 2]),
  # so there are half as many replicas as physical cores.
  num_replicas = resolver.get_tpu_system_metadata().num_cores // 2
  device_assignment = device_assignment_lib.DeviceAssignment.build(
      topology, num_replicas=num_replicas, computation_shape=[1, 1, 1, 2])
  strategy = tpu_lib.TPUStrategyV2(
      resolver,
      experimental_device_assignment=device_assignment,
      experimental_spmd_xla_partitioning=enable_spmd)
  return strategy, num_replicas


class TPUStrategyModelParallelismTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase):
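  """Tests model parallelism with two logical cores per replica."""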

  def test_logical_device_assignment(self):
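    """Checks variable and op placement on the requested logical devices."""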
    if test_util.is_mlir_bridge_enabled():
      self.skipTest("TODO(b/238811067): fix MLIR bridge")
    strategy, num_replicas = get_tpu_strategy()
    with strategy.scope():
      v = variables.Variable(2.)
      with strategy.extended.experimental_logical_device(1):
        w = variables.Variable(3.)

    self.assertLen(strategy.experimental_local_results(v), num_replicas)
    self.assertLen(strategy.experimental_local_results(w), num_replicas)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                     strategy.experimental_local_results(v)[0].device)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                     strategy.experimental_local_results(w)[0].device)

    logical_devices = []

    @def_function.function
    def f(x):
      replica_ctx = distribution_strategy_context.get_replica_context()
      with replica_ctx.experimental_logical_device(0):
        y = v * x
      with replica_ctx.experimental_logical_device(1):
        z = w * y
      logical_devices.append((y.device, z.device))
      return z

    result = strategy.run(f, args=(5.,))

    self.assertEqual(
        [("/device:TPU_REPLICATED_CORE:0", "/device:TPU_REPLICATED_CORE:1")],
        logical_devices)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(30. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))

  def test_partitioned_model_checkpointing(self):
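    """Checks checkpointing a model whose variables span logical devices."""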
    if test_util.is_mlir_bridge_enabled():
      self.skipTest("TODO(b/238811067): fix MLIR bridge")

    class PartitionedModel(module.Module):

      def __init__(self, v, w):
        super(PartitionedModel, self).__init__()

        assert distribution_strategy_context.has_strategy()
        strategy = distribution_strategy_context.get_strategy()

        with strategy.extended.experimental_logical_device(0):
          self.v = variables.Variable(v)
        with strategy.extended.experimental_logical_device(1):
          self.w = variables.Variable(w)

      def __call__(self, x):
        replica_ctx = distribution_strategy_context.get_replica_context()
        with replica_ctx.experimental_logical_device(0):
          y = self.v * x
        with replica_ctx.experimental_logical_device(1):
          z = self.w * y
        return z

      def change_weights_op(self, v_new, w_new):
        return control_flow_ops.group(
            [self.v.assign(v_new), self.w.assign(w_new)])

    strategy, num_replicas = get_tpu_strategy()
    with strategy.scope():
      model = PartitionedModel(2., 3.)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      checkpoint.save(file_prefix=checkpoint_prefix)

      self.evaluate(model.change_weights_op(1., 4.))
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(20. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))

      status = checkpoint.restore(
          checkpoint_management.latest_checkpoint(checkpoint_dir))
      # Restore ops must be run explicitly in non-eager (graph) mode.
      status.run_restore_ops(sess)
      status.assert_consumed()
      status.assert_existing_objects_matched()
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(30. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))

  def test_spmd_cannot_assign_tensor_to_logical_device(self):
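    """Checks tensors cannot be assigned to logical devices under SPMD."""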
    strategy, _ = get_tpu_strategy(enable_spmd=True)
    x = constant_op.constant([0, 1])
    with self.assertRaises(ValueError):
      strategy.experimental_assign_to_logical_device(x, 0)

  def test_spmd_variable_created_from_callable(self):
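    """Checks per-core values of a callable-initialized variable agree."""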
    initializer = lambda: random_ops.random_normal(shape=(16, 16))
    strategy, _ = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(initializer)
    value0 = w.values[0]
    for v in value0.variables:
      self.assertAllEqual(v, value0.variables[0])

  def test_spmd_variable_read(self):
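    """Checks reading an SPMD variable inside strategy.run."""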
    batch_size = 32
    num_feature_in = 16
    num_feature_out = 8

    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)
    w_init = random_ops.random_uniform((num_feature_in, num_feature_out),
                                       dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(w_init, dtype=dtypes.float32)

    self.assertEqual(w.values[0].variables[0].shape.as_list(),
                     [num_feature_in, num_feature_out])

    self.assertEqual(w.shape.as_list(), [num_feature_in, num_feature_out])

    def step_fn(batch_features):
      predict = math_ops.matmul(batch_features, w)
      return predict

    @def_function.function
    def train_fn(batch_features):
      return strategy.run(step_fn, args=(batch_features,))

    result = train_fn(x)
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None),
        math_ops.matmul(x, w_init) * num_replicas,
        rtol=5e-03,
        atol=5e-03)

  def test_spmd_variable_read_init_scope(self):
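    """Checks reading an SPMD variable under ops.init_scope in a function."""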
    strategy, _ = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      v = variables.Variable(array_ops.ones((4, 4), dtype=dtypes.float32))

    @def_function.function
    def read_v():
      with ops.init_scope():
        return v.read_value()

    result = strategy.reduce("MEAN", strategy.run(read_v), axis=None)
    self.assertAllClose(result, v.read_value())

  def test_spmd_variable_update(self):
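    """Checks assign, assign_sub and assign_add on SPMD variables on device."""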
    batch_size = 1024
    num_feature_in = 256

    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)
    w_init = random_ops.random_uniform((batch_size, num_feature_in),
                                       dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(w_init, dtype=dtypes.float32)

    self.assertIsInstance(w, tpu_values.TPUMirroredVariable)
    self.assertTrue(w._is_replicated_or_sharded_to_logical_cores())

    def make_strategy_run(fn):

      def run(value):
        return strategy.run(fn, args=(value,))

      return def_function.function(run)

    result = make_strategy_run(w.assign)(x)
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None), x * num_replicas)

    delta = random_ops.random_uniform((batch_size, num_feature_in),
                                      dtype=dtypes.float32)
    result = make_strategy_run(w.assign_sub)(delta)
    x -= delta
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None), x * num_replicas)

    delta = random_ops.random_uniform((batch_size, num_feature_in),
                                      dtype=dtypes.float32)
    result = make_strategy_run(w.assign_add)(delta)
    x += delta
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None), x * num_replicas)

  def test_spmd_variable_eager_update(self):
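    """Checks assign, assign_sub and assign_add on SPMD variables in eager."""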
    batch_size = 32
    num_feature_in = 16

    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)
    w_init = random_ops.random_uniform((batch_size, num_feature_in),
                                       dtype=dtypes.float32)

    strategy, _ = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(w_init, dtype=dtypes.float32)

    w.assign(x)
    result = w.numpy()
    self.assertAllClose(result, x)

    x1 = random_ops.random_uniform((batch_size, num_feature_in),
                                   dtype=dtypes.float32)
    w.assign_sub(x1)
    result = w.numpy()
    self.assertAllClose(result, x - x1)

    x2 = random_ops.random_uniform((batch_size, num_feature_in),
                                   dtype=dtypes.float32)
    w.assign(x)
    w.assign_add(x2)
    result = w.numpy()
    self.assertAllClose(result, x + x2)

  def test_spmd_model_checkpointing(self):
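    """Checks checkpointing when inputs are split across logical devices."""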

    class LinearModel(module.Module):

      def __init__(self, w):
        super(LinearModel, self).__init__()
        self.w = variables.Variable(w)

      def __call__(self, x):
        return math_ops.matmul(x, self.w)

      def change_weights_op(self, w_new):
        return self.w.assign(w_new)

    batch_size = 32
    num_feature_in = 16
    num_feature_out = 8
    w1 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    w2 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      model = LinearModel(w1)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    @def_function.function
    def step_fn(x):
      # Split the feature dimension of the batch in two across the logical
      # devices; [1, 2] is the number of partitions along each input axis.
      x = strategy.experimental_split_to_logical_devices(x, [1, 2])
      return model(x)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      checkpoint.save(file_prefix=checkpoint_prefix)

      self.evaluate(model.change_weights_op(w2))
      result = strategy.run(step_fn, args=(x,))
      self.assertAllClose(
          math_ops.matmul(x, w2) * num_replicas,
          self.evaluate(strategy.reduce("SUM", result, axis=None)),
          rtol=5e-3,
          atol=5e-3)

      status = checkpoint.restore(
          checkpoint_management.latest_checkpoint(checkpoint_dir))
      # Restore ops must be run explicitly in non-eager (graph) mode.
      status.run_restore_ops(sess)
      status.assert_consumed()
      status.assert_existing_objects_matched()
      result = strategy.run(step_fn, args=(x,))
      self.assertAllClose(
          math_ops.matmul(x, w1) * num_replicas,
          self.evaluate(strategy.reduce("SUM", result, axis=None)),
          rtol=5e-3,
          atol=5e-3)

  def test_spmd_with_summary(self):
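    """Checks summary writing inside strategy.run under SPMD."""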
    if test_util.is_mlir_bridge_enabled():
      self.skipTest("TODO(b/232580663): fix MLIR bridge")
    # Enable soft device placement so the summary writes can run on the host,
    # and restore the original setting even if the test fails midway.
    original_device_placement = config.get_soft_device_placement()
    self.addCleanup(config.set_soft_device_placement,
                    original_device_placement)
    config.set_soft_device_placement(True)

    strategy, _ = get_tpu_strategy(enable_spmd=True)
    summary_dir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(summary_dir)

    with strategy.scope():
      step = variables.Variable(0, dtype=dtypes.int64)

    @def_function.function
    def run():
      with writer.as_default():
        summary_ops.scalar("result", step * 2, step=step)
        step.assign_add(1)

    for _ in range(10):
      strategy.run(run, args=())

    for val in step.values:
      for var in val.variables:
        self.assertAllEqual(10, var)

  def test_spmd_with_outside_comp(self):
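    """Checks outside compilation inside a replicated function under SPMD."""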
    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)

    def host_inc(x):
      return x + 1

    @def_function.function
    def fn(x):
      y = x + 1
      z = tpu.outside_compilation(host_inc, y)
      a = z + 1
      return a

    arg = constant_op.constant(0, shape=(), dtype=dtypes.int64)
    result = strategy.run(fn, args=(arg,))
    self.assertEqual(3 * num_replicas,
                     self.evaluate(strategy.reduce("SUM", result, axis=None)))


if __name__ == "__main__":
  test.main()