# mypy: allow-untyped-defs
import torch
import torch.distributed as dist
from torch.autograd import Function

# The two imports below are not always available depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise an import error
# if we're trying to use them.
from torch.distributed import group, ReduceOp


def broadcast(tensor, src, group=group.WORLD):
13    """
14    Broadcasts the tensor to the whole group.
15
16    ``tensor`` must have the same number of elements in all processes
17    participating in the collective.
18
19    Arguments:
20        tensor (Tensor): Data to be sent if ``src`` is the rank of current
21            process.
22        src (int): Source rank.
23        group (ProcessGroup, optional): The process group to work on.
24
25    Returns:
26        Tensor: Received tensor from the broadcast op.
27
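
    Example (illustrative; assumes the default process group is already
    initialized with 2 ranks and that ``rank`` is the calling process's rank):
        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> from torch.distributed.nn.functional import broadcast
        >>> tensor = torch.tensor([1, 2]) + 2 * rank
        >>> tensor
        tensor([1, 2])  # rank 0
        tensor([3, 4])  # rank 1
        >>> broadcast(tensor, src=0)
        tensor([1, 2])  # rank 0 and rank 1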
28    """
29    return _Broadcast.apply(src, group, tensor)
30
31
32def gather(tensor, dst=0, group=group.WORLD):
33    """
34    Gathers a list of tensors in a single process.
35
36    Arguments:
37        tensor (Tensor): Input tensor.
38        dst (int, optional): Destination rank (default is 0).
39        group (ProcessGroup, optional): The process group to work on.
40
41    Returns:
42        tuple[Tensor]: List of appropriately-sized tensors with the gathered data.
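
    Example (illustrative; assumes the default process group is already
    initialized with 2 ranks and that ``rank`` is the calling process's rank):
        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> from torch.distributed.nn.functional import gather
        >>> tensor = torch.tensor([1, 2]) + 2 * rank
        >>> gather(tensor, dst=0)
        (tensor([1, 2]), tensor([3, 4]))  # on rank 0, the destination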
43    """
44    return _Gather.apply(dst, group, tensor)
45
46
47def scatter(tensors, src=0, group=group.WORLD):
48    """
49    Scatters a list of tensors to all processes in a group.
50
51    Each process will receive exactly one tensor and store its data in the
52    ``tensor`` argument.
53
54    Arguments:
55        tensors (list[Tensor]): List of tensors to scatter on the source rank.
56            Receivers must pass ``None`.
57        src (int, optional): Source rank (default is 0).
58        group (ProcessGroup, optional): The process group to work on.
59
60    Returns:
61        Tensor: Output tensor from the scatter operation.
62
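
    Example (illustrative; assumes the default process group is already
    initialized with 2 ranks and that ``rank`` is the calling process's rank):
        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> from torch.distributed.nn.functional import scatter
        >>> # every rank passes equally-sized tensors; only the values on the
        >>> # source rank are actually used
        >>> tensors = [torch.tensor([1, 2]), torch.tensor([3, 4])]
        >>> scatter(tensors, src=0)
        tensor([1, 2])  # rank 0
        tensor([3, 4])  # rank 1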
63    """
64    return _Scatter.apply(src, group, *tensors)
65
66
67def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD):
68    """
69    Reduces the tensor data across all machines.
70
71    Only the process with rank ``dst`` is going to receive the final result.
72
73    Arguments:
74        tensor (Tensor): Input of the collective.
75        dst (int): Destination rank.
76        op (optional): One of the values from
77            ``torch.distributed.ReduceOp``
78            enum.  Specifies an operation used for element-wise reductions.
79        group (ProcessGroup, optional): The process group to work on.
80
81    Returns:
82        Tensor: Output of the collective.
83
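
    Example (illustrative; assumes the default process group is already
    initialized with 2 ranks and that ``rank`` is the calling process's rank):
        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> from torch.distributed.nn.functional import reduce
        >>> tensor = torch.tensor([1, 2]) + 2 * rank
        >>> reduce(tensor, dst=0)
        tensor([4, 6])  # rank 0, the destination; other ranks are unspecified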
84    """
85    return _Reduce.apply(dst, op, group, tensor)
86
87
88def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD):
89    """
90    Reduces, then scatters a list of tensors to all processes in a group.
91
92    Arguments:
93        output (Tensor): Output tensor.
94        input_list (list[Tensor]): List of tensors to reduce and scatter.
95        op (optional): One of the values from
96            ``torch.distributed.ReduceOp``
97            enum.  Specifies an operation used for element-wise reductions.
98        group (ProcessGroup, optional): The process group to work on.
99
100    Returns:
101        Tensor: Output of the collective.
102
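
    Example (illustrative; assumes the default process group is already
    initialized with 2 ranks and that ``rank`` is the calling process's rank):
        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> from torch.distributed.nn.functional import reduce_scatter
        >>> # rank 0 passes [tensor([1, 2]), tensor([3, 4])],
        >>> # rank 1 passes [tensor([5, 6]), tensor([7, 8])]
        >>> inputs = [torch.tensor([1, 2]) + 4 * rank, torch.tensor([3, 4]) + 4 * rank]
        >>> output = torch.zeros(2, dtype=torch.int64)
        >>> reduce_scatter(output, inputs)
        tensor([6, 8])    # rank 0: element-wise sum of the first tensors
        tensor([10, 12])  # rank 1: element-wise sum of the second tensors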
103    """
104    return _Reduce_Scatter.apply(op, group, output, *input_list)
105
106
107def all_gather(tensor, group=group.WORLD):
108    """
109    Gathers tensors from the whole group in a list.
110
111    Arguments:
112        tensor (Tensor): Tensor to be broadcast from current process.
113        group (ProcessGroup, optional): The process group to work on.
114
115    Returns:
116        tuple([Tensor]): Output of the collective.
117
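
    Example (illustrative; assumes the default process group is already
    initialized with 2 ranks and that ``rank`` is the calling process's rank):
        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> from torch.distributed.nn.functional import all_gather
        >>> tensor = torch.tensor([1, 2]) + 2 * rank
        >>> all_gather(tensor)
        (tensor([1, 2]), tensor([3, 4]))  # same tuple on every rank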
118    """
119    return _AllGather.apply(group, tensor)
120
121
122def _all_gather_base(output_tensor, input_tensor, group=group.WORLD):
123    """
124    Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.
125
126    Args:
127        output_tensor (Tensor): Output tensor. It should contain
128            correctly-sized tensors to be used for output of the collective.
129        input_tensor (Tensor): Tensor to be broadcast from current process.
130        group (ProcessGroup, optional): The process group to work on. If None,
131            the default process group will be used.
132
133    Examples:
134        >>> # All tensors below are of torch.int64 dtype.
135        >>> # We have 2 process groups, 2 ranks.
136        >>> # xdoctest: +SKIP("incorrect want text")
137        >>> output_tensor = torch.zeros(2, dtype=torch.int64)
138        >>> output_tensor
139        [tensor([0, 0])] # Rank 0 and 1
140        >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank
141        >>> tensor
142        tensor([1]) # Rank 0
143        tensor([2]) # Rank 1
144        >>> dist.all_gather_base(output_tensor, tensor)
145        >>> output_tensor
146        tensor([1,2]) # Rank 0
147        tensor([1,2]) # Rank 1
148
149    .. warning::
150        `_all_gather_base` is experimental and subject to change.
151        It is the caller's responsibility to ensure the output_tensor
152        is correctly sized.
153
154    """
155    return _AllGatherBase.apply(output_tensor, input_tensor, group)
156
157
158def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD):
159    """
160    Each process scatters list of input tensors to all processes in a group and return gathered list of tensors in output list.
161
162    Arguments:
163        output_tensor_list (list[Tensor]): list of tensors to gather one per rank.
164        input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.
165        group (ProcessGroup, optional): The process group to work on.
166
167    Returns:
168        tuple([Tensor]): Output of the collective.
169
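
    Example (illustrative; assumes the default process group is already
    initialized with 2 ranks and that ``rank`` is the calling process's rank):
        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> from torch.distributed.nn.functional import all_to_all
        >>> # rank 0 sends [tensor([0]), tensor([1])],
        >>> # rank 1 sends [tensor([10]), tensor([11])]
        >>> inputs = [torch.tensor([10 * rank + i]) for i in range(2)]
        >>> outputs = [torch.empty(1, dtype=torch.int64) for _ in range(2)]
        >>> all_to_all(outputs, inputs)
        (tensor([0]), tensor([10]))  # rank 0
        (tensor([1]), tensor([11]))  # rank 1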
170    """
171    return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)
172
173
174def all_to_all_single(
175    output,
176    input,
177    output_split_sizes=None,
178    input_split_sizes=None,
179    group=group.WORLD,
180):
181    """
182    Each process splits input tensor and then scatters the split list to all processes in a group.
183
184    Then concatenate the received tensors from all the processes in the group and return single output tensor.
185
186    Arguments:
187        output (Tensor): Gathered concatenated output tensor.
188        input (Tensor): Input tensor to scatter.
189        output_split_sizes: (list[Int], optional): Output split sizes for dim 0
190            if specified None or empty, dim 0 of ``output`` tensor must divide
191            equally by ``world_size``.
192        input_split_sizes: (list[Int], optional): Input split sizes for dim 0
193            if specified None or empty, dim 0 of ``input`` tensor must divide
194            equally by ``world_size``.
195
196    Returns:
197        Tensor: Output of the collective.
198
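
    Example (illustrative; assumes the default process group is already
    initialized with 2 ranks and that ``rank`` is the calling process's rank):
        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> from torch.distributed.nn.functional import all_to_all_single
        >>> input = torch.arange(4, dtype=torch.int64) + 4 * rank
        >>> input
        tensor([0, 1, 2, 3])  # rank 0
        tensor([4, 5, 6, 7])  # rank 1
        >>> output = torch.empty(4, dtype=torch.int64)
        >>> all_to_all_single(output, input)
        tensor([0, 1, 4, 5])  # rank 0
        tensor([2, 3, 6, 7])  # rank 1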
199    """
200    return _AlltoAllSingle.apply(
201        group, output, output_split_sizes, input_split_sizes, input
202    )
203
204
205def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD):
206    """
207    Reduces the tensor data across all machines in such a way that all get the final result.
208
209    After the call the returned tensor is going to be bitwise
210    identical in all processes.
211
212    Arguments:
213        tensor (Tensor): Input of the collective.
214        op (optional): One of the values from
215            ``torch.distributed.ReduceOp``
216            enum.  Specifies an operation used for element-wise reductions.
217        group (ProcessGroup, optional): The process group to work on.
218
219    Returns:
220        Tensor: Output of the collective
221
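
    Example (illustrative; assumes the default process group is already
    initialized with 2 ranks and that ``rank`` is the calling process's rank):
        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> from torch.distributed.nn.functional import all_reduce
        >>> tensor = torch.tensor([1, 2]) + 2 * rank
        >>> all_reduce(tensor)
        tensor([4, 6])  # rank 0 and rank 1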
222    """
223    return _AllReduce.apply(op, group, tensor)
224
225
226class _Broadcast(Function):
227    @staticmethod
228    def forward(ctx, src, group, tensor):
        ctx.src = src
        ctx.group = group
        ctx.rank = dist.get_rank(group=group)
        # torch.distributed makes all the calls in place;
        # we allocate new tensors to avoid this
        tensor = tensor.clone()
        dist.broadcast(tensor, src, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
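        # The gradient of a broadcast is the sum of the output gradients
        # across ranks, delivered to the source rank; the inputs on non-source
        # ranks do not affect the output, so their gradient is zeroed.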
        gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output)
        if ctx.src != ctx.rank:
            gx.zero_()
        return (None, None, gx)


class _Gather(Function):
    @staticmethod
    def forward(ctx, dst, group, tensor):
        ctx.dst = dst
        ctx.group = group
        # Need to create a list of tensors here to do the aggregation;
        # its length comes from the group size, and each tensor must be
        # correctly sized for the gather to work.
        tensor_list = [
            torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group))
        ]

        tensor = tensor.contiguous()
        if dist.get_rank(group=group) == dst:
            dist.gather(tensor, tensor_list, dst, group=group)
        else:
            dist.gather(tensor, None, dst, group=group)
        return tuple(tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),)


class _Scatter(Function):
    @staticmethod
    def forward(ctx, src, group, *tensors):
        ctx.src = src
        ctx.group = group
        assert all(t.size() == tensors[0].size() for t in tensors)
        output = torch.zeros_like(tensors[0])
        if dist.get_rank(group=group) == src:
            dist.scatter(output, list(tensors), src, group=group)
        else:
            dist.scatter(output, None, src, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)


class _Reduce(Function):
    @staticmethod
    def forward(ctx, src, op, group, tensor):
        ctx.src = src
        ctx.group = group
        tensor = tensor.clone()
        dist.reduce(tensor, src, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)


class _Reduce_Scatter(Function):
    @staticmethod
    def forward(ctx, op, group, tensor, *input_tensor_list):
        ctx.group = group
        # Need contiguous tensors for collectives.
        tensor = tensor.contiguous()
        input_tensor_list = tuple(t.contiguous() for t in input_tensor_list)
        dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + _AllGather.apply(ctx.group, grad_output)


class _AllGather(Function):
    @staticmethod
    def forward(ctx, group, tensor):
        # Need contiguous tensors for collectives.
        tensor = tensor.contiguous()

        ctx.group = group
        out_tensor_list = [
            torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group))
        ]

        dist.all_gather(out_tensor_list, tensor, group=group)
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
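        # The gradient of all_gather is a reduce-scatter of the incoming
        # gradients: each rank receives the sum, over all ranks, of the
        # gradient of its own gathered copy.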
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            rank = dist.get_rank(group=ctx.group)
            gx = torch.empty_like(grad_outputs[rank])
            gx = _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs)
        else:
            # As many backends don't support ReduceScatter, we use AlltoAll with .sum()
            # to emulate the ReduceScatter behavior
            tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs]
            gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
            gx = torch.sum(torch.stack(gxs), dim=0)
        return (None, gx)


class _AllGatherBase(Function):
    @staticmethod
    def forward(ctx, output_tensor, input_tensor, group):
        ctx.group = group
        dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            world_size = dist.get_world_size(group=ctx.group)
            out_size = list(grad_output.size())
            if out_size[0] % world_size != 0:
                raise RuntimeError(
                    f"Tensor with dimensions: {out_size} does "
                    f"not have first dimension divisible by world_size: {world_size}"
                )
            out_size[0] = out_size[0] // world_size
            gx = torch.empty(
                out_size, device=grad_output.device, dtype=grad_output.dtype
            )
            dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group)
        else:
            raise RuntimeError("Backend not supported!")
        return (None, gx, None)


class _AlltoAll(Function):
    @staticmethod
    def forward(ctx, group, out_tensor_list, *tensors):
        ctx.group = group
        ctx.input_tensor_size_list = [
            tensors[i].size() for i in range(dist.get_world_size(group=group))
        ]
        my_rank = dist.get_rank(group=group)
        tensors = tuple(t.contiguous() for t in tensors)
        # Implement it by means of scatter/gather; async send/recv operations have issues
        if dist.get_backend(group=group) is dist.Backend.GLOO:
            for i in range(dist.get_world_size(group=group)):
                to_send = None
                if i == my_rank:
                    to_send = list(tensors)
                dist.scatter(out_tensor_list[i], to_send, i, group=group)
        else:
            dist.all_to_all(
                out_tensor_list,
                list(tensors),
                group=group,
            )
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        tensor_list = [
            torch.empty(
                size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype
            )
            for size in ctx.input_tensor_size_list
        ]
        return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)


class _AlltoAllSingle(Function):
    @staticmethod
    def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
        ctx.group = group
        ctx.input_size = input.size()
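        # The split sizes are swapped for the backward pass, since the
        # gradient flows in the opposite direction of the forward exchange.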
        ctx.output_split_sizes = input_split_sizes
        ctx.input_split_sizes = output_split_sizes
        dist.all_to_all_single(
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        tensor = torch.empty(
            ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
        )
        return (None, None, None, None) + (
            _AlltoAllSingle.apply(
                ctx.group,
                tensor,
                ctx.output_split_sizes,
                ctx.input_split_sizes,
                grad_output.contiguous(),
            ),
        )


class _AllReduce(Function):
    @staticmethod
    def forward(ctx, op, group, tensor):
        ctx.group = group
        ctx.op = op
        tensor = tensor.clone()
        dist.all_reduce(tensor, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)