xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/template_mirrored_strategy_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for make_template used with MirroredStrategy."""
16from tensorflow.python.distribute import distribution_strategy_context as ds_context
17from tensorflow.python.distribute import mirrored_strategy
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import test_util
20from tensorflow.python.ops import init_ops
21from tensorflow.python.ops import template
22from tensorflow.python.ops import variable_scope
23from tensorflow.python.ops import variables
24from tensorflow.python.platform import test
25
26
class TemplateMirroredStrategyTest(test.TestCase):
  """Checks `template.make_template` when run through a MirroredStrategy."""

  @test_util.disable_tfrt("Strategy not supported yet.")
  def test_merge_call(self):
    # The API under test is v1-only; everything below is built inside an
    # explicit graph.
    with ops.Graph().as_default():
      if not test.is_gpu_available():
        self.skipTest("No GPU available")

      def product_of_two_variables():
        # A `merge_call` sits between the two `get_variable` calls; the
        # template body must still create both variables and multiply them.
        first = variable_scope.get_variable(
            "var1", shape=[], initializer=init_ops.constant_initializer(21.))
        replica_ctx = ds_context.get_replica_context()
        replica_ctx.merge_call(lambda _: ())
        second = variable_scope.get_variable(
            "var2", shape=[], initializer=init_ops.constant_initializer(2.))
        return first * second

      templated_fn = template.make_template(
          "my_template", product_of_two_variables)

      devices = ["/cpu:0", "/gpu:0"]
      strategy = mirrored_strategy.MirroredStrategy(devices)
      per_replica_result = strategy.run(templated_fn)
      local_results = strategy.experimental_local_results(per_replica_result)

      self.evaluate(variables.global_variables_initializer())
      # One result per device, each 21. * 2. == 42.
      self.assertAllEqual([42., 42.], self.evaluate(local_results))
52
53
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()
56