# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from absl.testing import parameterized

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables


def get_dataset_from_tensor_slices(inp_array):
  dataset = dataset_ops.DatasetV2.from_tensor_slices(inp_array)
  # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
  if not tf2.enabled():
    dataset = dataset_ops.Dataset.from_tensor_slices(inp_array)
  return dataset


class AssertFlattenedMixin(object):
  """Mixin for specialized asserts."""

  def assert_equal_flattened(self, expected_results, actual_results):
    """Asserts that flattened results are equal.

    Depending on the number of replicas in the strategy, the output may have
    a different structure and needs to be flattened for comparison.

    Args:
      expected_results: The results expected as a result of a computation.
      actual_results: The actual results of a computation.
    """
    self.assertEqual(len(expected_results), len(actual_results))

    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = actual_results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)


class GradientTapeTest(test.TestCase, parameterized.TestCase,
                       AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(x):

      def computation(x):
        return math_ops.square(x)

      with backprop.GradientTape() as tape:
        tape.watch(x)  # Manually watch non-variable tensors.
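        # The tape only tracks trainable variables automatically; `x` here is
        # a plain tensor coming from the dataset, so without the explicit
        # watch above, `tape.gradient(y, x)` below would return None.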
        y = computation(x)
      grads = tape.gradient(y, x)
      return grads

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def run(x):

      def train_step(x):

        def computation(x):
          return math_ops.square(x)

        with backprop.GradientTape() as tape:
          tape.watch(x)  # Manually watch non-variable tensors.
          y = computation(x)
        grads = tape.gradient(y, x)
        return grads

      return distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = run(x)
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"],
          model_in_tf_function=[True, False]
      ))
  def testNestedFunction(self, distribution, model_in_tf_function):

    def model(x):
      return x * x

    if model_in_tf_function:
      model = def_function.function(model)

    with distribution.scope():
      x = variables.Variable(1.0)

    @def_function.function
    def train_step():

      def replica_step():
        with backprop.GradientTape() as tape:
          y = model(x)
        return tape.gradient(y, x)

      return distribution.run(replica_step)

    grads = distribution.experimental_local_results(train_step())
    self.assertLen(grads, distribution.num_replicas_in_sync)
    self.assertTrue(all(g is not None for g in grads))


if __name__ == "__main__":
  test.main()