# coding=utf-8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for collectives."""

import copy
import enum

from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# TODO(b/170340570): print deprecation warning for CollectiveCommunication.
@tf_export("distribute.experimental.CommunicationImplementation",
           "distribute.experimental.CollectiveCommunication")
class CommunicationImplementation(enum.Enum):
  """Cross device communication implementation.

  Warning: The alias `tf.distribute.experimental.CollectiveCommunication` is
  deprecated and will be removed in a future version. Use
  `tf.distribute.experimental.CommunicationImplementation` instead.

  * `AUTO`: Automatically chosen by TensorFlow.
  * `RING`: TensorFlow's ring algorithms for all-reduce and all-gather.
  * `NCCL`: NVIDIA®'s NCCL library. This is currently only used for all-reduce
    on GPUs; all-reduce on CPUs, all-gather, and broadcast fall back to RING.
  """
  AUTO = "AUTO"
  RING = "RING"
  NCCL = "NCCL"
  # TODO(ayushd): add ncclAllGather implementation.


CollectiveCommunication = CommunicationImplementation
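

# A minimal sketch (illustrative only; `_example_implementation_lookup` is a
# hypothetical helper, not part of the TensorFlow API): enum members can be
# looked up by their string value, which is how `Options` below normalizes
# string inputs, and `CollectiveCommunication` is just an alias.
def _example_implementation_lookup():
  """Hypothetical helper showing enum lookup by value and the alias."""
  impl = CommunicationImplementation("nccl".upper())
  assert impl is CommunicationImplementation.NCCL
  # The deprecated alias refers to the same enum class.
  assert CollectiveCommunication is CommunicationImplementation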


@tf_export("distribute.experimental.CommunicationOptions")
class _OptionsExported(object):
  """Options for cross device communications like all-reduce.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies
  and are ignored by others.

  One common optimization is to break gradient all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  ```python
  options = tf.distribute.experimental.CommunicationOptions(
      bytes_per_pack=50 * 1024 * 1024,
      timeout_seconds=120.0,
      implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
  )
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, options=options)
  optimizer.apply_gradients(zip(grads, vars),
                            experimental_aggregate_gradients=False)
  ```

  """

  def __new__(cls, *args, **kwargs):
    # We expose a dummy class so that we can separate internal and public APIs.
    # Note that __init__ won't be called on the returned object if it's a
    # different class [1].
    # [1] https://docs.python.org/3/reference/datamodel.html#object.__new__
    return Options(*args, **kwargs)

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    """Creates a CommunicationOptions.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of a certain size. If it's zero, the value is determined
        automatically. This hint is respected by all multi-replica strategies
        except `TPUStrategy`.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. Zero disables the timeout. This can be useful when
        debugging hanging issues, but should only be used for debugging since
        it creates a new thread for each collective, i.e. an overhead of
        `timeout_seconds * num_collectives_per_second` more threads. This only
        works for `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
      implementation: a
        `tf.distribute.experimental.CommunicationImplementation`. This is a
        hint on the preferred communication implementation. Possible values
        include `AUTO`, `RING`, and `NCCL`. NCCL is generally more performant
        on GPUs but doesn't work on CPUs. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass


class Options(object):
  """Implementation of OptionsInterface."""

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    if bytes_per_pack < 0:
      raise ValueError(
          f"Argument `bytes_per_pack` must be >= 0. Received {bytes_per_pack}.")
    if isinstance(implementation, str):
      implementation = CommunicationImplementation(implementation.upper())
    if not isinstance(implementation, CommunicationImplementation):
      raise ValueError(
          "Argument `implementation` must be an instance of "
          "`tf.distribute.experimental.CommunicationImplementation`.")
    self.bytes_per_pack = bytes_per_pack
    self.timeout_seconds = timeout_seconds
    self.implementation = implementation

  __init__.__doc__ = _OptionsExported.__init__.__doc__

  def merge(self, options):
    """Merges with another `Options` object and returns a new one.

    Values specified in `options` take precedence if they're not the default.

    Args:
      options: a `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A new `tf.distribute.experimental.CommunicationOptions`.
    """
    merged = copy.deepcopy(self)
    if options is None:
      return merged
    if options.bytes_per_pack != 0:
      merged.bytes_per_pack = options.bytes_per_pack
    if options.timeout_seconds is not None:
      merged.timeout_seconds = options.timeout_seconds
    if options.implementation != CommunicationImplementation.AUTO:
      merged.implementation = options.implementation
    return merged

  def __str__(self):
    return (f"Options(bytes_per_pack={self.bytes_per_pack}, "
            f"timeout_seconds={self.timeout_seconds}, "
            f"implementation={self.implementation})")
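

# A minimal sketch of `Options.merge` semantics (illustrative only;
# `_example_merge_options` is a hypothetical helper, not part of the TensorFlow
# API): non-default values in the argument win, default values leave the
# receiver's values unchanged.
def _example_merge_options():
  """Hypothetical helper demonstrating `Options.merge`."""
  base = Options(bytes_per_pack=32 * 1024 * 1024)
  override = Options(timeout_seconds=60.0,
                     implementation=CommunicationImplementation.NCCL)
  merged = base.merge(override)
  assert merged.bytes_per_pack == 32 * 1024 * 1024  # Kept from `base`.
  assert merged.timeout_seconds == 60.0  # Taken from `override`.
  assert merged.implementation is CommunicationImplementation.NCCL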


@tf_export("distribute.experimental.CollectiveHints")
class Hints(object):
  """Hints for collective operations like all-reduce.

  Warning: `tf.distribute.experimental.CollectiveHints` is deprecated. Use
  `tf.distribute.experimental.CommunicationOptions` instead.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies
  and are ignored by others.

  One common optimization is to break gradient all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  - bytes_per_pack

  ```python
  hints = tf.distribute.experimental.CollectiveHints(
      bytes_per_pack=50 * 1024 * 1024)
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, experimental_hints=hints)
  optimizer.apply_gradients(zip(grads, vars),
                            experimental_aggregate_gradients=False)
  ```

  - timeout_seconds

  ```python
  strategy = tf.distribute.MirroredStrategy()
  hints = tf.distribute.experimental.CollectiveHints(
      timeout_seconds=120.0)
  try:
    strategy.reduce("sum", v, axis=None, experimental_hints=hints)
  except tf.errors.DeadlineExceededError:
    do_something()
  ```

  """

  @deprecation.deprecated(
      None, "use distribute.experimental.CommunicationOptions instead")
  def __new__(cls, bytes_per_pack=0, timeout_seconds=None):
    return Options(
        bytes_per_pack=bytes_per_pack, timeout_seconds=timeout_seconds)

  def __init__(self, bytes_per_pack=0, timeout_seconds=None):
    """Creates a CollectiveHints.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of a certain size. If it's zero, the value is determined
        automatically. This currently only applies to all-reduce with
        `MultiWorkerMirroredStrategy`.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. This can be useful when debugging hanging issues,
        but should only be used for debugging since it creates a new thread for
        each collective, i.e. an overhead of `timeout_seconds *
        num_collectives_per_second` more threads. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass
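

# A minimal sketch (illustrative only; `_example_hints_returns_options` is a
# hypothetical helper, not part of the TensorFlow API): constructing the
# deprecated `Hints` actually yields an `Options` instance, because
# `Hints.__new__` forwards to `Options` and `Hints.__init__` is never called.
def _example_hints_returns_options():
  """Hypothetical helper showing that `Hints(...)` produces an `Options`."""
  hints = Hints(bytes_per_pack=50 * 1024 * 1024, timeout_seconds=120.0)
  assert isinstance(hints, Options)
  assert hints.bytes_per_pack == 50 * 1024 * 1024
  assert hints.timeout_seconds == 120.0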