# coding=utf-8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for collectives."""

import copy
import enum

from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# TODO(b/170340570): print deprecation warning for CollectiveCommunication.
@tf_export("distribute.experimental.CommunicationImplementation",
           "distribute.experimental.CollectiveCommunication")
class CommunicationImplementation(enum.Enum):
  """Cross device communication implementation.

  Warning: The alias `tf.distribute.experimental.CollectiveCommunication` is
  deprecated and will be removed in a future version. Use
  `tf.distribute.experimental.CommunicationImplementation` instead.

  * `AUTO`: Automatically chosen by TensorFlow.
  * `RING`: TensorFlow's ring algorithms for all-reduce and
    all-gather.
  * `NCCL`: NVIDIA®'s NCCL library. This is now only used for all-reduce on
    GPUs; all-reduce on CPU, all-gather, and broadcast fall back to RING.
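
  For illustration, an implementation is typically passed through
  `tf.distribute.experimental.CommunicationOptions` (see below); the values
  here are examples only:

  ```python
  options = tf.distribute.experimental.CommunicationOptions(
      implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
  )
  ```
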
40  """
41  AUTO = "AUTO"
42  RING = "RING"
43  NCCL = "NCCL"
44  # TODO(ayushd): add ncclAllGather implementation.
45
46
47CollectiveCommunication = CommunicationImplementation


@tf_export("distribute.experimental.CommunicationOptions")
class _OptionsExported(object):
  """Options for cross device communications like All-reduce.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies and
  are ignored by others.

  One common optimization is to break gradients all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  ```python
  options = tf.distribute.experimental.CommunicationOptions(
      bytes_per_pack=50 * 1024 * 1024,
      timeout_seconds=120.0,
      implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
  )
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, options=options)
  optimizer.apply_gradients(zip(grads, vars),
      experimental_aggregate_gradients=False)
  ```

  """

  def __new__(cls, *args, **kwargs):
    # We expose a dummy class so that we can separate internal and public APIs.
    # Note that __init__ won't be called on the returned object if it's a
    # different class [1].
    # [1] https://docs.python.org/3/reference/datamodel.html#object.__new__
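    # For illustration (a sketch, not executed here): a call such as
    #   opts = _OptionsExported(bytes_per_pack=1)
    # returns an `Options` instance, so the `__init__` below never runs on it;
    # it exists only to carry the public signature and docstring.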
    return Options(*args, **kwargs)

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    """Creates a CommunicationOptions.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of a certain size. If it's zero, the value is determined
        automatically. This hint is respected by all multi-replica strategies
        except `TPUStrategy`.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. Zero disables timeout. This can be useful when
        debugging hanging issues. This should only be used for debugging since
        it creates a new thread for each collective, i.e. an overhead of
        `timeout_seconds * num_collectives_per_second` more threads. This only
        works for `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
      implementation: a
        `tf.distribute.experimental.CommunicationImplementation`. This is a
        hint on the preferred communication implementation. Possible values
        include `AUTO`, `RING`, and `NCCL`. NCCL is generally more performant
        on GPUs, but doesn't work on CPUs. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass


class Options(object):
  """Implementation of OptionsInterface."""

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    if bytes_per_pack < 0:
      raise ValueError(
          f"Argument `bytes_per_pack` must be >=0. Received {bytes_per_pack}.")
    if isinstance(implementation, str):
      implementation = CommunicationImplementation(implementation.upper())
    if not isinstance(implementation, CommunicationImplementation):
      raise ValueError(
          "Argument `implementation` must be an instance of "
          "`tf.distribute.experimental.CommunicationImplementation`.")
    self.bytes_per_pack = bytes_per_pack
    self.timeout_seconds = timeout_seconds
    self.implementation = implementation

  __init__.__doc__ = _OptionsExported.__init__.__doc__

  def merge(self, options):
    """Merges with another options object and returns a new one.

    Values specified in `options` take precedence if they're not the
    default.

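    For example (a minimal sketch; the names and values are illustrative only):

    ```python
    default = tf.distribute.experimental.CommunicationOptions(
        timeout_seconds=30.0)
    override = tf.distribute.experimental.CommunicationOptions(
        bytes_per_pack=4 * 1024 * 1024)
    merged = default.merge(override)
    # merged.bytes_per_pack == 4 * 1024 * 1024 (non-default value wins)
    # merged.timeout_seconds == 30.0 (no override given)
    ```
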
    Args:
      options: a `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A new `tf.distribute.experimental.CommunicationOptions`.
    """
    merged = copy.deepcopy(self)
    if options is None:
      return merged
    if options.bytes_per_pack != 0:
      merged.bytes_per_pack = options.bytes_per_pack
    if options.timeout_seconds is not None:
      merged.timeout_seconds = options.timeout_seconds
    if options.implementation != CommunicationImplementation.AUTO:
      merged.implementation = options.implementation
    return merged

  def __str__(self):
    return (f"Options(bytes_per_pack={self.bytes_per_pack}, "
            f"timeout_seconds={self.timeout_seconds}, "
            f"implementation={self.implementation})")


@tf_export("distribute.experimental.CollectiveHints")
class Hints(object):
  """Hints for collective operations like AllReduce.

  Warning: `tf.distribute.experimental.CollectiveHints` is deprecated and will
  be removed in a future version. Use
  `tf.distribute.experimental.CommunicationOptions` instead.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies and
  are ignored by others.

  One common optimization is to break gradients all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  - bytes_per_pack

  ```python
  hints = tf.distribute.experimental.CollectiveHints(
      bytes_per_pack=50 * 1024 * 1024)
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, experimental_hints=hints)
  optimizer.apply_gradients(zip(grads, vars),
      experimental_aggregate_gradients=False)
  ```

  - timeout_seconds

  ```python
  strategy = tf.distribute.MirroredStrategy()
  hints = tf.distribute.experimental.CollectiveHints(
      timeout_seconds=120.0)
  try:
    strategy.reduce("sum", v, axis=None, experimental_hints=hints)
  except tf.errors.DeadlineExceededError:
    do_something()
  ```

  """

  @deprecation.deprecated(
      None, "use distribute.experimental.CommunicationOptions instead")
  def __new__(cls, bytes_per_pack=0, timeout_seconds=None):
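    # Like `_OptionsExported.__new__` above, this returns an `Options` instance
    # directly, so the deprecated `Hints` constructor and
    # `CommunicationOptions` share one implementation.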
    return Options(
        bytes_per_pack=bytes_per_pack, timeout_seconds=timeout_seconds)

  def __init__(self, bytes_per_pack=0, timeout_seconds=None):
    """Creates a CollectiveHints.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of a certain size. If it's zero, the value is determined
        automatically. This currently only applies to all-reduce with
        `MultiWorkerMirroredStrategy`.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. This can be useful when debugging hanging issues.
        This should only be used for debugging since it creates a new thread
        for each collective, i.e. an overhead of `timeout_seconds *
        num_collectives_per_second` more threads. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass
234