# mypy: allow-untyped-defs
import torch
import torch.distributed as dist
from torch.autograd import Function

# The two imports below are not always available depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise an import error
# if we're trying to use them.
from torch.distributed import group, ReduceOp


def broadcast(tensor, src, group=group.WORLD):
    """
    Broadcasts the tensor to the whole group.

    ``tensor`` must have the same number of elements in all processes
    participating in the collective.

    Arguments:
        tensor (Tensor): Data to be sent if ``src`` is the rank of current
            process.
        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Received tensor from the broadcast op.

    """
    return _Broadcast.apply(src, group, tensor)


def gather(tensor, dst=0, group=group.WORLD):
    """
    Gathers a list of tensors in a single process.

    Arguments:
        tensor (Tensor): Input tensor.
        dst (int, optional): Destination rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]: Tuple of appropriately-sized tensors with the gathered data.
    """
    return _Gather.apply(dst, group, tensor)


def scatter(tensors, src=0, group=group.WORLD):
    """
    Scatters a list of tensors to all processes in a group.

    Each process will receive exactly one tensor and store its data in the
    ``tensor`` argument.

    Arguments:
        tensors (list[Tensor]): List of tensors to scatter on the source rank.
            Receivers must pass ``None``.
        src (int, optional): Source rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output tensor from the scatter operation.

    """
    return _Scatter.apply(src, group, *tensors)


def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD):
    """
    Reduces the tensor data across all machines.

    Only the process with rank ``dst`` is going to receive the final result.

    Arguments:
        tensor (Tensor): Input of the collective.
        dst (int): Destination rank.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _Reduce.apply(dst, op, group, tensor)


def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Arguments:
        output (Tensor): Output tensor.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _Reduce_Scatter.apply(op, group, output, *input_list)


def all_gather(tensor, group=group.WORLD):
    """
    Gathers tensors from the whole group in a list.

    Arguments:
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]: Output of the collective.

    """
    return _AllGather.apply(group, tensor)
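

# Usage sketch (illustrative helper, never called by this module): the wrappers
# above are differentiable, so gradients flow back through the collective. This
# assumes the default process group has already been initialized elsewhere and
# that every rank runs the same code; the helper name is hypothetical.
def _example_all_gather_grad():
    x = torch.ones(3, requires_grad=True)
    # One tensor per rank comes back; autograd records the collective.
    gathered = all_gather(x)
    # Backward reduce-scatters (or all-to-alls and sums) the incoming
    # gradients, so each rank recovers the gradient of its own contribution:
    # every copy of ``x`` receives ones(3), giving world_size * ones(3) here.
    torch.stack(gathered).sum().backward()
    return x.grad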


def _all_gather_base(output_tensor, input_tensor, group=group.WORLD):
    """
    Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.

    Args:
        output_tensor (Tensor): Output tensor. It should be correctly sized
            to be used as the output of the collective.
        input_tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Examples:
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 ranks.
        >>> # xdoctest: +SKIP("incorrect want text")
        >>> output_tensor = torch.zeros(2, dtype=torch.int64)
        >>> output_tensor
        [tensor([0, 0])] # Rank 0 and 1
        >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank
        >>> tensor
        tensor([1]) # Rank 0
        tensor([2]) # Rank 1
        >>> dist.all_gather_base(output_tensor, tensor)
        >>> output_tensor
        tensor([1, 2]) # Rank 0
        tensor([1, 2]) # Rank 1

    .. warning::
        `_all_gather_base` is experimental and subject to change.
        It is the caller's responsibility to ensure the output_tensor
        is correctly sized.

    """
    return _AllGatherBase.apply(output_tensor, input_tensor, group)


def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD):
    """
    Each process scatters a list of input tensors to all processes in the group and returns a gathered list of tensors in the output list.

    Arguments:
        output_tensor_list (list[Tensor]): List of tensors to gather, one per rank.
        input_tensor_list (list[Tensor]): List of tensors to scatter, one per rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]: Output of the collective.

    """
    return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)


def all_to_all_single(
    output,
    input,
    output_split_sizes=None,
    input_split_sizes=None,
    group=group.WORLD,
):
    """
    Each process splits the input tensor and then scatters the split list to all processes in the group.

    Then it concatenates the received tensors from all processes in the group and returns a single output tensor.

    Arguments:
        output (Tensor): Gathered concatenated output tensor.
        input (Tensor): Input tensor to scatter.
        output_split_sizes (list[int], optional): Output split sizes for dim 0.
            If specified as None or empty, dim 0 of the ``output`` tensor must
            divide equally by ``world_size``.
        input_split_sizes (list[int], optional): Input split sizes for dim 0.
            If specified as None or empty, dim 0 of the ``input`` tensor must
            divide equally by ``world_size``.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _AlltoAllSingle.apply(
        group, output, output_split_sizes, input_split_sizes, input
    )
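

# Sketch (illustrative helper, never called by this module): sizing the buffers
# for an uneven all_to_all_single on a hypothetical 2-rank group whose backend
# supports all_to_all (e.g. NCCL). The concrete split sizes and the helper name
# are made up for the illustration; the process group is assumed to be
# initialized elsewhere.
def _example_uneven_all_to_all_single(rank):
    # Chunk i of ``input`` (along dim 0) goes to rank i, so input_split_sizes[i]
    # is how many rows this rank sends to rank i. Symmetrically,
    # output_split_sizes[j] must match how many rows rank j sends back here.
    if rank == 0:
        input_split_sizes, output_split_sizes = [1, 2], [1, 3]
    else:
        input_split_sizes, output_split_sizes = [3, 1], [2, 1]
    inp = torch.arange(sum(input_split_sizes), dtype=torch.float32).unsqueeze(1)
    out = torch.empty(sum(output_split_sizes), 1)
    return all_to_all_single(
        out,
        inp,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
    )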


def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD):
    """
    Reduces the tensor data across all machines in such a way that all get the final result.

    After the call the returned tensor is going to be bitwise
    identical in all processes.

    Arguments:
        tensor (Tensor): Input of the collective.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _AllReduce.apply(op, group, tensor)


class _Broadcast(Function):
    @staticmethod
    def forward(ctx, src, group, tensor):
        ctx.src = src
        ctx.group = group
        ctx.rank = dist.get_rank(group=group)
        # torch.distributed makes all the calls in place;
        # we allocate new tensors to avoid this
        tensor = tensor.clone()
        dist.broadcast(tensor, src, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output)
        if ctx.src != ctx.rank:
            gx.zero_()
        return (None, None, gx)


class _Gather(Function):
    @staticmethod
    def forward(ctx, dst, group, tensor):
        ctx.dst = dst
        ctx.group = group
        # Need a list of tensors here to do the aggregation; its length comes
        # from the group size, and each entry must be correctly sized for the
        # gather.
        tensor_list = [
            torch.zeros_like(tensor) for _ in range(dist.get_world_size(group=group))
        ]

        tensor = tensor.contiguous()
        if dist.get_rank(group=group) == dst:
            dist.gather(tensor, tensor_list, dst, group=group)
        else:
            dist.gather(tensor, None, dst, group=group)
        return tuple(tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),)


class _Scatter(Function):
    @staticmethod
    def forward(ctx, src, group, *tensors):
        ctx.src = src
        ctx.group = group
        assert all(t.size() == tensors[0].size() for t in tensors)
        output = torch.zeros_like(tensors[0])
        if dist.get_rank(group=group) == src:
            dist.scatter(output, list(tensors), src, group=group)
        else:
            dist.scatter(output, None, src, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)


class _Reduce(Function):
    @staticmethod
    def forward(ctx, src, op, group, tensor):
        ctx.src = src
        ctx.group = group
        tensor = tensor.clone()
        dist.reduce(tensor, src, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)


class _Reduce_Scatter(Function):
    @staticmethod
    def forward(ctx, op, group, tensor, *input_tensor_list):
        ctx.group = group
        # Need contiguous tensors for collectives.
        tensor = tensor.contiguous()
        input_tensor_list = tuple(t.contiguous() for t in input_tensor_list)
        dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + _AllGather.apply(ctx.group, grad_output)
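

# Sketch (illustrative helper, never called by this module): the Function
# classes above pair each collective with its adjoint in backward -- broadcast
# with reduce(SUM), gather with scatter, reduce with broadcast, and
# reduce_scatter with all_gather. The toy check below uses ``broadcast``; it
# assumes an initialized default process group and that every rank calls it.
def _example_broadcast_grad(src=0):
    x = torch.ones(3, requires_grad=True)
    y = broadcast(x, src)
    # Every rank contributes d(sum(y))/dy = ones(3); backward reduces these
    # contributions to ``src`` and zeroes the gradient on the other ranks,
    # so x.grad is world_size * ones(3) on rank ``src`` and zeros elsewhere.
    y.sum().backward()
    return x.grad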


class _AllGather(Function):
    @staticmethod
    def forward(ctx, group, tensor):
        # Need contiguous tensors for collectives.
        tensor = tensor.contiguous()

        ctx.group = group
        out_tensor_list = [
            torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group))
        ]

        dist.all_gather(out_tensor_list, tensor, group=group)
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            rank = dist.get_rank(group=ctx.group)
            gx = torch.empty_like(grad_outputs[rank])
            gx = _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs)
        else:
            # As many backends don't support ReduceScatter, we use AlltoAll
            # with .sum() to emulate the ReduceScatter behavior
            tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs]
            gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
            gx = torch.sum(torch.stack(gxs), dim=0)
        return (None, gx)


class _AllGatherBase(Function):
    @staticmethod
    def forward(ctx, output_tensor, input_tensor, group):
        ctx.group = group
        dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            world_size = dist.get_world_size(group=ctx.group)
            out_size = list(grad_output.size())
            if out_size[0] % world_size != 0:
                raise RuntimeError(
                    f"Tensor with dimensions: {out_size} does "
                    f"not have first dimension divisible by world_size: {world_size}"
                )
            out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group)
            gx = torch.empty(
                out_size, device=grad_output.device, dtype=grad_output.dtype
            )
            dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group)
        else:
            raise RuntimeError("Backend not supported!")
        return (None, gx, None)


class _AlltoAll(Function):
    @staticmethod
    def forward(ctx, group, out_tensor_list, *tensors):
        ctx.group = group
        ctx.input_tensor_size_list = [
            tensors[i].size() for i in range(dist.get_world_size(group=group))
        ]
        my_rank = dist.get_rank(group=group)
        tensors = tuple(t.contiguous() for t in tensors)
        # Implement it by means of scatter, since async send/recv operations
        # have issues
        if dist.get_backend(group=group) is dist.Backend.GLOO:
            for i in range(dist.get_world_size(group=group)):
                to_send = None
                if i == my_rank:
                    to_send = list(tensors)
                dist.scatter(out_tensor_list[i], to_send, i, group=group)
        else:
            dist.all_to_all(
                out_tensor_list,
                list(tensors),
                group=group,
            )
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        tensor_list = [
            torch.empty(
                size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype
            )
            for size in ctx.input_tensor_size_list
        ]
        return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
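

# Sketch (illustrative helper, never called by this module): _AllGatherBase's
# backward reduce-scatters the incoming gradient, which is why dim 0 of the
# flat output must be world_size * local length. Assumes an initialized NCCL
# group with one CUDA device per rank; the helper name and sizes are made up.
def _example_all_gather_base_grad(local_len=4):
    world_size = dist.get_world_size()
    x = torch.randn(local_len, device="cuda", requires_grad=True)
    out = torch.empty(world_size * local_len, device="cuda")
    out = _all_gather_base(out, x)
    # d(sum(out))/d(out) is all ones; the backward reduce-scatter sums the
    # per-rank gradients back into a (local_len,) tensor, giving
    # x.grad == world_size * ones(local_len).
    out.sum().backward()
    return x.grad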


class _AlltoAllSingle(Function):
    @staticmethod
    def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
        ctx.group = group
        ctx.input_size = input.size()
        # Save the split sizes swapped: the backward all-to-all reverses the
        # forward exchange.
        ctx.output_split_sizes = input_split_sizes
        ctx.input_split_sizes = output_split_sizes
        dist.all_to_all_single(
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        tensor = torch.empty(
            ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
        )
        return (None, None, None, None) + (
            _AlltoAllSingle.apply(
                ctx.group,
                tensor,
                ctx.output_split_sizes,
                ctx.input_split_sizes,
                grad_output.contiguous(),
            ),
        )


class _AllReduce(Function):
    @staticmethod
    def forward(ctx, op, group, tensor):
        ctx.group = group
        ctx.op = op
        tensor = tensor.clone()
        dist.all_reduce(tensor, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)
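

# End-to-end sketch (illustrative only, never called by this module): running
# the autograd-aware all_reduce on a hypothetical 2-process Gloo group. A real
# run would launch one process per rank, e.g. with torch.multiprocessing.spawn
# or torchrun; the address/port below are placeholders.
def _example_all_reduce_grad(rank, world_size=2):
    import os

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    x = torch.full((2,), float(rank + 1), requires_grad=True)
    y = all_reduce(x, op=ReduceOp.SUM)  # y == [3., 3.] on both ranks
    # The backward of a SUM all_reduce is itself a SUM all_reduce, so every
    # rank ends up with x.grad == world_size * ones(2), i.e. [2., 2.] here.
    y.sum().backward()
    grad = x.grad
    dist.destroy_process_group()
    return grad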