xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/shared_variable_creator_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for SharedVariableCreator."""
16
17from tensorflow.python.distribute import shared_variable_creator
18from tensorflow.python.eager import test
19from tensorflow.python.framework import test_util
20from tensorflow.python.ops import variable_scope
21
22
class CanonicalizeVariableNameTest(test.TestCase):
  """Tests for shared_variable_creator._canonicalize_variable_name."""

  def _canonicalize(self, name):
    """Forwards `name` to the private helper under test and returns its result."""
    canonical = shared_variable_creator._canonicalize_variable_name(name)
    return canonical

  def testNoName(self):
    # A `None` name is expected to canonicalize to "Variable".
    expected = "Variable"
    self.assertEqual(expected, self._canonicalize(None))

  def testPatternInMiddle(self):
    # The "_1" suffix is stripped from every path component, not only the last.
    raw, expected = "foo_1/bar_1/baz", "foo/bar/baz"
    self.assertEqual(expected, self._canonicalize(raw))

  def testPatternAtEnd(self):
    raw, expected = "foo_1", "foo"
    self.assertEqual(expected, self._canonicalize(raw))

  def testWrongPatterns(self):
    # Each of these is expected to come back unchanged: an ":0" tail after the
    # suffix, digits with no underscore, and an underscore before non-digits.
    for unchanged in ("foo_1:0", "foo1", "foo_a"):
      self.assertEqual(unchanged, self._canonicalize(unchanged))
41
42
class SharedVariableCreatorTest(test.TestCase):
  """Tests for the creator functions built by shared_variable_creator.make_fn."""

  @test_util.run_in_graph_and_eager_modes
  def testSharedVariable(self):
    """Same-named variables on devices 1..num_devices-1 are device 0's object."""
    shared_variable_store = {}
    num_devices = 3
    # One creator per device index, all backed by the same shared store.
    creator_fns = [
        shared_variable_creator.make_fn(shared_variable_store, i)
        for i in range(num_devices)
    ]

    # Create a variable named "foo" under each creator in device order, so
    # device 0's creator runs first. A plain loop is used here (rather than a
    # comprehension) because each creation needs its own creator scope.
    created = []
    for creator_fn in creator_fns:
      with variable_scope.variable_creator_scope(creator_fn):
        created.append(variable_scope.variable(1.0, name="foo"))

    # Every later device should get back the very same object as device 0.
    # Looping over `created` keeps the check in sync with `num_devices`.
    v0 = created[0]
    for v in created[1:]:
      self.assertIs(v, v0)
67
68
if __name__ == "__main__":
  # Entry point when this file is executed directly: hand off to the test
  # runner exposed by tensorflow.python.eager.test.
  test.main()
71