# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for make_template used with MirroredStrategy."""
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class TemplateMirroredStrategyTest(test.TestCase):
  """Checks a `make_template` function run through a `MirroredStrategy`."""

  @test_util.disable_tfrt("Strategy not supported yet.")
  def test_merge_call(self):
    # Runs a templated function that issues a `merge_call` between two
    # `get_variable` calls on a CPU + GPU mirrored strategy, and expects
    # every local result to equal 21 * 2 = 42.
    with ops.Graph().as_default():
      # The test is testing a v1 only function.
      gpu_present = test.is_gpu_available()
      if not gpu_present:
        self.skipTest("No GPU available")

      def fn():
        """Creates "var1" and "var2" around a merge_call; returns their product."""
        first = variable_scope.get_variable(
            "var1",
            shape=[],
            initializer=init_ops.constant_initializer(21.))
        # A no-op merge_call placed between the two variable creations.
        replica_ctx = ds_context.get_replica_context()
        replica_ctx.merge_call(lambda _: ())
        second = variable_scope.get_variable(
            "var2",
            shape=[],
            initializer=init_ops.constant_initializer(2.))
        return first * second

      temp = template.make_template("my_template", fn)

      devices = ["/cpu:0", "/gpu:0"]
      strategy = mirrored_strategy.MirroredStrategy(devices)
      per_replica = strategy.run(temp)
      out = strategy.experimental_local_results(per_replica)

      # Variables must be initialized before the outputs can be evaluated.
      init_op = variables.global_variables_initializer()
      self.evaluate(init_op)

      actual = self.evaluate(out)
      self.assertAllEqual([42., 42.], actual)


if __name__ == "__main__":
  test.main()