# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the parameter-server distributed values library (`ps_values`)."""

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.central_storage_strategy_with_two_gpus
        ],
        mode=["graph", "eager"]))
class AggregatingVariableTest(test.TestCase, parameterized.TestCase):
  """Exercises `ps_values.AggregatingVariable` under central storage.

  The decorator runs every test once per combination: the central storage
  strategy on two GPUs, in both graph and eager mode. The combination's
  strategy is passed to each test as the `distribution` argument.
  """

  def testAssignOutOfScope(self, distribution):
    """A variable created inside the scope can be assigned after leaving it."""
    with distribution.scope():
      agg_var = variables_lib.Variable(1.)
    # Under this strategy, variable creation is expected to produce the
    # aggregating wrapper rather than a plain variable.
    self.assertIsInstance(agg_var, ps_values.AggregatingVariable)

    # This assignment deliberately happens outside `distribution.scope()`.
    self.evaluate(agg_var.assign(3.))

    # The wrapper and the variable it wraps (`_v`) must both observe the
    # new value.
    wrapper_value = self.evaluate(agg_var.read_value())
    self.assertEqual(wrapper_value, 3.)
    inner_value = self.evaluate(agg_var._v.read_value())
    self.assertEqual(inner_value, 3.)

  def testAssignAdd(self, distribution):
    """`assign_add` inside `distribution.run` updates a MEAN variable."""
    with distribution.scope():
      mean_var = variable_scope.variable(
          1, aggregation=variables_lib.VariableAggregation.MEAN)
    # Explicit initialization is required for the graph-mode combination.
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def add_two():
      return mean_var.assign_add(2)

    # Run the update through the strategy, then gather the locally held
    # results before evaluating them.
    run_output = distribution.run(add_two)
    local_results = distribution.experimental_local_results(run_output)
    per_replica_results = self.evaluate(local_results)
    # Expect exactly one local result, equal to the initial 1 plus 2.
    self.assertAllEqual([3], per_replica_results)


if __name__ == "__main__":
  test.main()