# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SharedVariableCreator."""

from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope


class CanonicalizeVariableNameTest(test.TestCase):
  """Checks the results of `_canonicalize_variable_name` on sample names."""

  def _canonicalize(self, name):
    """Forwards `name` to `shared_variable_creator._canonicalize_variable_name`."""
    canonical = shared_variable_creator._canonicalize_variable_name(name)
    return canonical

  def testNoName(self):
    # A missing name is expected to come back as "Variable".
    result = self._canonicalize(None)
    self.assertEqual("Variable", result)

  def testPatternInMiddle(self):
    # "_1" suffixes on inner path components are expected to be dropped.
    result = self._canonicalize("foo_1/bar_1/baz")
    self.assertEqual("foo/bar/baz", result)

  def testPatternAtEnd(self):
    # A "_1" suffix on the final component is expected to be dropped too.
    result = self._canonicalize("foo_1")
    self.assertEqual("foo", result)

  def testWrongPatterns(self):
    # Each of these names is expected to be returned unchanged.
    for unchanged_name in ("foo_1:0", "foo1", "foo_a"):
      self.assertEqual(unchanged_name, self._canonicalize(unchanged_name))


class SharedVariableCreatorTest(test.TestCase):
  """Checks that creators built by `make_fn` hand back one shared variable."""

  @test_util.run_in_graph_and_eager_modes
  def testSharedVariable(self):
    """Creates "foo" under three creators and expects a single object."""
    store = {}
    device_count = 3

    # Build every creator up front, all backed by the same store dict.
    creators = [
        shared_variable_creator.make_fn(store, device_id)
        for device_id in range(device_count)
    ]

    # Create a variable with the same name under each creator scope in turn.
    created = []
    for creator in creators:
      with variable_scope.variable_creator_scope(creator):
        created.append(variable_scope.variable(1.0, name="foo"))

    # Every later variable should be the very same object as the first one.
    first, *later = created
    for candidate in later:
      self.assertIs(candidate, first)


if __name__ == "__main__":
  test.main()