# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for custom training loops."""
16
17from absl.testing import parameterized
18
19from tensorflow.python import tf2
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.distribute import combinations
22from tensorflow.python.distribute import strategy_combinations
23from tensorflow.python.eager import backprop
24from tensorflow.python.eager import def_function
25from tensorflow.python.eager import test
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import variables
28

def get_dataset_from_tensor_slices(inp_array):
  """Returns a dataset built from `inp_array` via the V1 or V2 Dataset API."""
  # TODO(b/138326910): Remove the Dataset V1 branch once the bug is resolved.
  if tf2.enabled():
    return dataset_ops.DatasetV2.from_tensor_slices(inp_array)
  return dataset_ops.Dataset.from_tensor_slices(inp_array)


class AssertFlattenedMixin(object):
  """Mixin for specialized asserts."""

  def assert_equal_flattened(self, expected_results, actual_results):
    """Asserts that flattened results are equal.

    Depending on the number of replicas in the strategy, the output may have
    a different structure, so it is flattened before comparison.

    Args:
      expected_results: The expected results of the computation.
      actual_results: The actual results of the computation.
    """
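    # For example, with two replicas each actual result is a tuple of
    # per-replica tensors such as ([10.], [12.]); extending with each
    # tensor's .numpy() flattens it to [10., 12.]. With a single replica
    # the tuple would instead be ([10., 12.],), which flattens the same way.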
    self.assertEqual(len(expected_results), len(actual_results))

    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = actual_results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)


class GradientTapeTest(test.TestCase, parameterized.TestCase,
                       AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(x):
      def computation(x):
        return math_ops.square(x)
      with backprop.GradientTape() as tape:
        tape.watch(x)  # Manually watch non-variable tensors.
        y = computation(x)
      grads = tape.gradient(y, x)
      return grads

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
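    # d(x**2)/dx = 2*x, so the inputs [5., 6., 7., 8.], batched in pairs,
    # produce per-batch gradients [10., 12.] and [14., 16.].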
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

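    # Unlike testStepInFunctionGradient above, here the entire
    # distribution.run call, not just the per-replica step, is wrapped in a
    # tf.function, checking that gradients also flow in that configuration.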
    @def_function.function
    def run(x):
      def train_step(x):
        def computation(x):
          return math_ops.square(x)
        with backprop.GradientTape() as tape:
          tape.watch(x)  # Manually watch non-variable tensors.
          y = computation(x)
        grads = tape.gradient(y, x)
        return grads
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = run(x)
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"],
          model_in_tf_function=[True, False]
      ))
  def testNestedFunction(self, distribution, model_in_tf_function):
    def model(x):
      return x * x

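    # When model_in_tf_function is True, the model is wrapped in its own
    # tf.function, so the replica step below calls one tf.function nested
    # inside another.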
    if model_in_tf_function:
      model = def_function.function(model)

    with distribution.scope():
      x = variables.Variable(1.0)

      @def_function.function
      def train_step():
        def replica_step():
          with backprop.GradientTape() as tape:
            y = model(x)
          return tape.gradient(y, x)
        return distribution.run(replica_step)

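      # experimental_local_results returns one gradient per replica; a None
      # entry would indicate that the gradient failed to flow through the
      # nested tf.function calls.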
      grads = distribution.experimental_local_results(train_step())
      self.assertLen(grads, distribution.num_replicas_in_sync)
      self.assertTrue(all(g is not None for g in grads))


if __name__ == "__main__":
  test.main()