xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/step_fn.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The step function abstraction represents a single training step."""
16
17from tensorflow.python.eager import backprop
18from tensorflow.python.training import optimizer as optimizer_lib
19
20
class Step(object):
  """Interface for performing each step of a training algorithm.

  Concrete subclasses override `__call__` to run a single training step
  using `self.distribution` (a `DistributionStrategy` object).
  """

  def __init__(self, distribution):
    # Strategy object that descendants use to distribute the step.
    self._distribution = distribution

  @property
  def distribution(self):
    """The `DistributionStrategy` this step runs under."""
    return self._distribution

  def initialize(self):
    """Returns ops/objects that must run before the first step (none here)."""
    return []

  def __call__(self):
    """Perform one step of this training algorithm."""
    raise NotImplementedError("must be implemented in descendants")

  # TODO(priyag): Add a method to access initialization and finalize ops.
40
class StandardInputStep(Step):
  """Step with a standard implementation of input handling.

  Args:
    dataset_fn: a function that returns a tf.data Dataset that produces the
      input for the model.
  """

  def __init__(self, dataset_fn, distribution):
    super(StandardInputStep, self).__init__(distribution)

    def input_fn(ctx):
      # The per-replica input context is unused; delegate to `dataset_fn`.
      del ctx
      return dataset_fn()

    self._iterator = distribution.make_input_fn_iterator(input_fn)

  def initialize(self):
    # Unlike the base class, the input iterator needs explicit
    # initialization before the first step can run.
    return self._iterator.initializer
55
56
class StandardSingleLossStep(StandardInputStep):
  """A step function that implements a training step for a feed forward network.

  An instance of this class is intended to be used as a callable:

  ```python
  ...
  step = step_fn.StandardSingleLossStep(
      dataset, loss_fn, optimizer, distribution)

  # Run a single training step on a given DistributionStrategy:
  step(distribution)
  ...
  ```

  Args:
    dataset_fn: a function that returns a tf.data Dataset that produces the
      input for the model.
    loss_fn: a function that takes a context and inputs as arguments. It returns
      the loss for those inputs. `context` is an instance of
      `values.MultiStepContext` that will be passed when `loss_fn` is run.
      `context` can be used to specify the outputs to be returned from
      `loss_fn`, among other things.
    optimizer: an optimizer that implements an update rule.
    distribution: a `DistributionStrategy` object.
  """

  def __init__(self, dataset_fn, loss_fn, optimizer, distribution,
               iterations_per_step=1):
    super(StandardSingleLossStep, self).__init__(dataset_fn, distribution)
    self._loss_fn = loss_fn
    self._optimizer = optimizer
    self._iterations_per_step = iterations_per_step

  def __call__(self):
    with self._distribution.scope():

      def step_fn(ctx, inputs):
        """Function to run one iteration with one input."""
        # Build a gradient function from the loss and filter out variables
        # that received no gradient before applying updates.
        grad_fn = optimizer_lib.get_filtered_grad_fn(
            backprop.implicit_grad(self._loss_fn))
        grads_and_vars = self.distribution.extended.call_for_each_replica(
            grad_fn, args=(ctx, inputs))
        # If threads use layers, then we need to run the first step
        # sequentially, so that layers.build() is not executed in parallel.
        # Otherwise, multiple sets of mirrored variables are going to be
        # created.
        return self._optimizer._distributed_apply(  # pylint: disable=protected-access
            self.distribution, grads_and_vars)

      # TODO(priyag): Return the outputs, context, etc as well.
      run_ctx = self.distribution.extended.experimental_run_steps_on_iterator(
          step_fn, self._iterator, self._iterations_per_step)
      return run_ctx.run_op
111