# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The step function abstraction represents a single training step."""

from tensorflow.python.eager import backprop
from tensorflow.python.training import optimizer as optimizer_lib


class Step(object):
  """Interface for performing each step of a training algorithm."""

  def __init__(self, distribution):
    self._distribution = distribution

  @property
  def distribution(self):
    return self._distribution

  def initialize(self):
    return []

  def __call__(self):
    """Perform one step of this training algorithm."""
    raise NotImplementedError("must be implemented in descendants")

  # TODO(priyag): Add a method to access initialization and finalize ops.


class StandardInputStep(Step):
  """Step with a standard implementation of input handling.

  Args:
    dataset_fn: a function that returns a tf.data Dataset that produces the
      input for the model.
  """

  def __init__(self, dataset_fn, distribution):
    super(StandardInputStep, self).__init__(distribution)
    self._iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())

  def initialize(self):
    return self._iterator.initializer


class StandardSingleLossStep(StandardInputStep):
  """A step function that implements a training step for a feed-forward network.

  An instance of this class is intended to be used as a callable:

  ```python
  ...
  step = step_fn.StandardSingleLossStep(
      dataset_fn, loss_fn, optimizer, distribution)

  # Run a single training step on a given DistributionStrategy:
  step()
  ...
  ```

  Args:
    dataset_fn: a function that returns a tf.data Dataset that produces the
      input for the model.
    loss_fn: a function that takes a context and inputs as arguments. It returns
      the loss for those inputs. `context` is an instance of
      `values.MultiStepContext` that will be passed when `loss_fn` is run.
      `context` can be used to specify the outputs to be returned from
      `loss_fn`, among other things.
    optimizer: an optimizer that implements an update rule.
    distribution: a `DistributionStrategy` object.
    iterations_per_step: the number of iterations to run per call, with a new
      batch drawn from the input iterator for each iteration. Defaults to 1.
82 """ 83 84 def __init__(self, dataset_fn, loss_fn, optimizer, distribution, 85 iterations_per_step=1): 86 super(StandardSingleLossStep, self).__init__(dataset_fn, distribution) 87 self._loss_fn = loss_fn 88 self._optimizer = optimizer 89 self._iterations_per_step = iterations_per_step 90 91 def __call__(self): 92 with self._distribution.scope(): 93 def step_fn(ctx, inputs): 94 """Function to run one iteration with one input.""" 95 gradients_fn = backprop.implicit_grad(self._loss_fn) 96 gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) 97 98 grads_and_vars = self.distribution.extended.call_for_each_replica( 99 gradients_fn, args=(ctx, inputs)) 100 # If threads use layers, then we need to run the first step 101 # sequentially, so that layers.build() is not executed in parallel. 102 # Otherwise, multiple sets of mirrored variables are going to be 103 # created. 104 return self._optimizer._distributed_apply( # pylint: disable=protected-access 105 self.distribution, grads_and_vars) 106 107 # TODO(priyag): Return the outputs, context, etc as well. 108 ctx = self.distribution.extended.experimental_run_steps_on_iterator( 109 step_fn, self._iterator, self._iterations_per_step) 110 return ctx.run_op 111