# mypy: allow-untyped-defs
import sys
import warnings
from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
from torch.distributed.device_mesh import DeviceMesh
from torch.fx.experimental.proxy_tensor import get_proxy_mode

from . import _functional_collectives_impl as fun_col_impl


try:
    from torch.utils._cxx_pytree import tree_map_only
except ImportError:
    from torch.utils._pytree import tree_map_only  # type: ignore[no-redef]


if torch._running_with_deploy():

    def is_torchdynamo_compiling():
        """Can't import torchdynamo in torchdeploy builds currently."""
        return False

else:
    try:
        from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
    except Exception:
        warnings.warn(
            "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
        )

        def is_torchdynamo_compiling():
            return False


39 """
40 New traceable, functional collectives.
41 RFC: https://github.com/pytorch/pytorch/issues/93173
42 
43   compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
44   eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses,
45          automatically calling .wait() on underlying/hidden async 'work' obj only when fed to
46          a downstream op.
47 
48 Issues:
49 * Where should these ops live? Couldn't `import torch` if putting these ops in existing torch.distributed files
50 * Proper support for eager requires inplace ops. We should explore having it as an option for the API.
51 """

"""
Functional collectives are asynchronous only, and we perform implicit stream synchronization
on behalf of the user.

We use AsyncCollectiveTensor to wrap the result tensor of a collective; it lets us witness the
first usage of the tensor and insert the cross-stream sync at the right place.

The above are the easy bits; the hard one is how we match the Work object returned by
c10d with the tensor that AsyncCollectiveTensor wraps. We allocate the tensor inside the collective
op implementation (see the ``clone()`` call in ``_all_reduce``) and then it's handled by the
dispatcher, which might call other implementations that are allowed to change the returned
tensor - even return a tensor with a different shape (see ``torch.vmap``).

This means the caller of our ops receives a Tensor that is not guaranteed to be the one
allocated by our implementations, and that makes pairing the AsyncCollectiveTensor with the original
tensor a lot harder. This pairing is needed so we can look up the Work object to use.

Originally, we tried a WeakKeyDictionary to map from Tensor to Work, but because a Tensor's
identity is not stable across dispatch, the op caller would end up with a different Tensor
instance that would not match any in the dictionary.

With Tensor identity out of the question, we decided to use the tensor data pointer, which
should be stable across all the Tensor changes done during dispatch.

We keep a dictionary of tensor::data_ptr -> Work that we insert into right after we call into c10d.

We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait().

Finally, we set up a finalizer against the tensor wrapper to observe it getting collected, so we
can clean up stale entries in the dictionary.

To eliminate the possibility of races, we have a global version counter that is used by the finalizer.

As a wise man once said: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)

"""

"""
Functional collectives can accept any of these types to describe the ranks participating in collectives.

The different types will be desugared to a canonical format
"""
RANK_TYPES = Union[
    List[int],
    List[List[int]],
    dist.ProcessGroup,
    DeviceMesh,
    Tuple["dist.tensor.DeviceMesh", int],
    str,
]
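
"""
Illustrative, equivalent ways to describe a group (a sketch; the ranks, mesh shape, and
group name below are assumptions for a hypothetical 4-rank job)::

    from torch.distributed.device_mesh import init_device_mesh

    group = [0, 1, 2, 3]                      # List[int]
    group = [[0, 1], [2, 3]]                  # List[List[int]] (2D mesh, MPMD)
    group = dist.group.WORLD                  # ProcessGroup
    group = init_device_mesh("cuda", (4,))    # DeviceMesh (SPMD over the whole mesh)
    group = (mesh_2d, 0)                      # (DeviceMesh, int): one mesh dimension
    group = "default_pg"                      # group name string
"""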


"""
User facing APIs for functional collectives
-------------------------------------------

These APIs are called by user code and are expected to work both in eager execution and compilation,
but there are significant differences in how the two modes are implemented underneath.

Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op)
just before the tensor is first used.  Compiled tracing currently relies on the compiler to perform this optimization,
and cannot yet correctly trace the AsyncTensor wrapper class.  In the future, these paths may be unified
if sufficient subclass support is added in dynamo.

Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern.

Here's how it works under torch.compile/dynamo:
all_reduce(...)
  |--> _expand_group(...)               - desugars the process group into a canonical/traceable format
  |--> c10d_functional.all_reduce(...)  - dynamo captures this op call, doesn't trace deeper
  |--> _maybe_wrap_tensor(...)          - wait_tensor() op is immediately called, no AsyncTensor subclass needed

And under eager execution:
all_reduce(...)
  |--> _expand_group(...)               - same as above, but less critical for eager
  |--> c10d_functional.all_reduce(...)  - dispatches to real kernel OR records op in trace
  |--> _maybe_wrap_tensor(...)          - AsyncTensor wrapper applied to returned tensor,
                                          which issues wait_tensor() at the time of first use
"""


def wait_tensor(tensor):
    """
    Wait on a tensor returned by the collective ops.

    Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
    """
    return torch.ops._c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]


def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""):
    """
    Broadcasts the tensor to all processes in the given process group.

    Args:
        src (int): Source rank
        group (ProcessGroup or List[int]): The process group to work on.
        tag (str, optional): A unique identifier for the collective. Default: empty string
    """
    group_name = _resolve_group_name(group, tag)
    tensor = torch.ops._c10d_functional.broadcast(self, src, group_name)
    return _maybe_wrap_tensor(tensor)


def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result.

    The input tensor is left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
    return _maybe_wrap_tensor(tensor)


def all_gather_tensor(
    self: torch.Tensor,
    gather_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
185     """
186     Gather tensor data across from all machines and concatenate over ``gather_dim``.
187 
188     Note that it currently only supports gather_dim = 0.
189 
190     The input tensor is left unmodified.
191     Group can be one of:
192         List[int]: ranks participating in the collective.
193         List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
194         ProcessGroup: Will perform a collective using the ranks and tag of the PG.
195         DeviceMesh: Do a SPMD collective over all ranks of the mesh
196         (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
197 
198     :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
199     that information and perform collective algebraic optimization. Use other forms of input for that.
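
    Example of the resulting shape (illustrative; the 4-rank world size and the tensor
    sizes are assumptions)::

        # on each of 4 ranks, x has shape [2, 8]
        out0 = all_gather_tensor(x, gather_dim=0, group=[0, 1, 2, 3])  # shape [8, 8]
        out1 = all_gather_tensor(x, gather_dim=1, group=[0, 1, 2, 3])  # shape [2, 32]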
200     """
201     assert self.is_contiguous()
202     group_name = _resolve_group_name(group, tag)
203     group_size = c10d._get_group_size_by_name(group_name)
204     tensor = torch.ops._c10d_functional.all_gather_into_tensor(
205         self, group_size, group_name
206     )
207     res = _maybe_wrap_tensor(tensor)
208     # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
209     if gather_dim != 0:
210         # torch.cat access the data so we already need to wait here, first do wait
211         # and then chunk + cat avoid us going through ACT dispatching logic again
212         if isinstance(res, AsyncCollectiveTensor):
213             res = res.wait()  # type: ignore[attr-defined]
214         res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
215     return res
216 

def all_gather_tensor_autograd(
    self: torch.Tensor,
    gather_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Gather tensor data from all machines and concatenate over ``gather_dim``.

    Note that it currently only supports gather_dim = 0.

    This function is the same as all_gather_tensor but will propagate the
    backwards gradient across workers.

    See all_gather_tensor for more details on usage.
    """
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)

    tensor = torch.ops._c10d_functional_autograd.all_gather_into_tensor(
        self, group_size, group_name
    )
    res = _FromTorchTensor.apply(tensor)
    # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
    if gather_dim != 0:
        # torch.cat accesses the data, so we need to wait here anyway; waiting first
        # and then doing chunk + cat avoids going through ACT dispatching logic again
        if isinstance(res, AsyncCollectiveTensor):
            res = res.wait()  # type: ignore[attr-defined]
        res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
    return res


def reduce_scatter_tensor(
    self: torch.Tensor,
    reduceOp: str,
    scatter_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    The input tensor is left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
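
    Example of the resulting shape (illustrative; the 4-rank world size and the tensor
    sizes are assumptions)::

        # on each of 4 ranks, x has shape [8, 16]
        out0 = reduce_scatter_tensor(x, "sum", scatter_dim=0, group=[0, 1, 2, 3])  # shape [2, 16]
        out1 = reduce_scatter_tensor(x, "sum", scatter_dim=1, group=[0, 1, 2, 3])  # shape [8, 4]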
272     """
273     group_name = _resolve_group_name(group, tag)
274     group_size = c10d._get_group_size_by_name(group_name)
275 
276     assert (
277         self.size(scatter_dim) % group_size == 0
278     ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
279     if scatter_dim != 0:
280         tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
281         self = torch.cat(tensor_list)
282 
283     tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
284         self,
285         reduceOp.lower(),
286         group_size,
287         group_name,  # type: ignore[possibly-undefined]
288     )
289     res = _maybe_wrap_tensor(tensor)
290     return res
291 

def reduce_scatter_tensor_autograd(
    self: torch.Tensor,
    reduceOp: str,
    scatter_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    This function is the same as reduce_scatter_tensor but will propagate the
    backwards gradient across workers.

    Currently only the "sum" reduceOp is supported.

    See reduce_scatter_tensor for more details on usage.
    """

    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)

    assert (
        self.size(scatter_dim) % group_size == 0
    ), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
    if scatter_dim != 0:
        tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
        self = torch.cat(tensor_list)

    tensor = torch.ops._c10d_functional_autograd.reduce_scatter_tensor(
        self,
        reduceOp.lower(),
        group_size,
        group_name,  # type: ignore[possibly-undefined]
    )
    res = _FromTorchTensor.apply(tensor)
    return res


def all_reduce_coalesced(
    self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
    """
    Reduces a list of tensors across all machines in such a way that all get
    the final result.

    All tensors in the input list are left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    tensor_list = torch.ops._c10d_functional.all_reduce_coalesced(  # type: ignore[attr-defined]
        self,
        reduceOp.lower(),
        group_name,
    )
    return list(map(_maybe_wrap_tensor, tensor_list))


def all_gather_into_tensor_coalesced(
    self: List[torch.Tensor], group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
    """
    Gather a list of tensors from all machines.

    Note that it currently only supports gather_dim = 0.

    The input tensors are left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)
    tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(  # type: ignore[attr-defined]
        self,
        group_size,
        group_name,
    )
    return list(map(_maybe_wrap_tensor, tensor_list))


def reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduceOp: str,
    scatter_dim: List[int],
    group: RANK_TYPES,
    tag: str = "",
) -> List[torch.Tensor]:
    """
    Reduces a list of tensors across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    The input tensors are left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)

    assert len(scatter_dim) == len(inputs)
    for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
        assert (
            tensor.size(dim) % group_size == 0
        ), f"input dimension {dim} ({tensor.size(dim)}) must be a multiple of group_size {group_size} for tensor at index {idx}"
        if dim != 0:
            tensor_list = torch.chunk(tensor, group_size, dim=dim)
            inputs[idx] = torch.cat(tensor_list)

    tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(  # type: ignore[attr-defined]
        inputs,
        reduceOp.lower(),
        group_size,
        group_name,  # type: ignore[possibly-undefined]
    )

    return list(map(_maybe_wrap_tensor, tensor_list))


# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
# Today, this maps 1:1 with "aten ops that are views".
def _is_view_op(tgt):
    assert isinstance(tgt, torch._ops.OpOverload)
    schema = tgt._schema
    if len(schema.arguments) > 0:
        first_arg = schema.arguments[0]
        # check if op is a view
        return first_arg.alias_info is not None and not first_arg.alias_info.is_write


def all_to_all_single(
    self: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Each process splits the input tensor and scatters the splits to all processes in
    the group, then concatenates the received tensors from all processes and returns
    a single output tensor.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
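
    Example of the split-size handling (illustrative; the 4-rank world size and the
    tensor sizes are assumptions)::

        # on each of 4 ranks, x has shape [8, 16]
        y = all_to_all_single(x, None, None, group=[0, 1, 2, 3])  # equal splits, shape [8, 16]
        # with explicit splits, the output row count is sum(output_split_sizes)
        z = all_to_all_single(x, [1, 2, 2, 3], [2, 2, 2, 2], group=[0, 1, 2, 3])  # shape [8, 16]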
465     """
466     if output_split_sizes is not None:
467         assert all(
468             isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
469         ), output_split_sizes
470     if input_split_sizes is not None:
471         assert all(
472             isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
473         ), input_split_sizes
474     group_name = _resolve_group_name(group, tag)
475     group_size = c10d._get_group_size_by_name(group_name)
476     if output_split_sizes is None or input_split_sizes is None:
477         assert output_split_sizes is None and input_split_sizes is None, (
478             "output_split_sizes and input_split_sizes must either be "
479             "specified together or both set to None"
480         )
481         output_split_sizes = [self.shape[0] // group_size] * group_size
482         input_split_sizes = output_split_sizes
483     tensor = torch.ops._c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
484         self,
485         output_split_sizes,
486         input_split_sizes,
487         group_name,
488     )
489     return _maybe_wrap_tensor(tensor)
490 
491 
def all_to_all_single_autograd(
    self: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Same as all_to_all_single but supports autograd.
    """
    if output_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
        ), output_split_sizes
    if input_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
        ), input_split_sizes

    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)
    if output_split_sizes is None or input_split_sizes is None:
        assert output_split_sizes is None and input_split_sizes is None, (
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
        output_split_sizes = [self.shape[0] // group_size] * group_size
        input_split_sizes = output_split_sizes
    tensor = torch.ops._c10d_functional_autograd.all_to_all_single(  # type: ignore[attr-defined]
        self,
        output_split_sizes,
        input_split_sizes,
        group_name,
    )
    return _FromTorchTensor.apply(tensor)


def permute_tensor(
    self: torch.Tensor,
    src_dst: List[int],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should
    be defined such that src_dst[m] == n means m sends to n.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
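
    Example of the `src_dst` semantics (illustrative; the 4-rank world size is an
    assumption)::

        # src_dst = [1, 2, 3, 0] means rank 0 sends to 1, 1 -> 2, 2 -> 3, 3 -> 0,
        # so each rank receives the tensor held by the previous rank in the ring.
        out = permute_tensor(x, [1, 2, 3, 0], group=[0, 1, 2, 3])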
545     """
546     t, rankset, group_size = _expand_group(group, tag)
547     local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size)
548 
549     output_split_sizes = [0] * group_size
550     input_split_sizes = [0] * group_size
551     for src, dst in enumerate(src_dst):
552         if src == dist.get_rank(local_pg):
553             input_split_sizes[dst] = self.numel()
554         if dst == dist.get_rank(local_pg):
555             output_split_sizes[src] = self.numel()
556 
557     return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag)
558 
559 
class AsyncCollectiveTensor(torch.Tensor):
    r"""
    A Tensor wrapper subclass that is used to trigger a call to wait
    prior to first use of the underlying tensor.
    Use it inside functional collective pytorch wrappers like the following:
    def functional_collective(self, group, tag):
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
        return _maybe_wrap_tensor(tensor)
    """
    elem: torch.Tensor
    completed: bool

    __slots__ = ["elem", "completed"]

    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=elem.device,
            requires_grad=elem.requires_grad,
        )
        r.elem = elem
        r.completed = False
        return r

    def __tensor_flatten__(self):
        return ["elem"], None

    def tolist(self):
        return self.trigger_wait().tolist()

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        elem = inner_tensors["elem"]
        return AsyncCollectiveTensor(elem)

    def __repr__(self):
        return f"AsyncCollectiveTensor({self.trigger_wait()})"

    def trigger_wait(self):
        if not self.completed:
            out = wait_tensor(self.elem)
            self.completed = True
            return out
        else:
            return self.elem

    def wait(self) -> torch.Tensor:
        return wait_tensor(self.elem)

    def _get_acs_underlying_tensor(self):
        """This method enables _functional_collectives_impl to test if a tensor is an ACS"""
        return self.elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func == torch.ops.aten.view.default:
            # Fast-path aten.view, since a lot of view-related ops eventually go to
            # aten.view; this avoids the pytree slowdown
            res = func(args[0].elem, args[1])
            wrapper_res = AsyncCollectiveTensor(res)
            return wrapper_res

        is_view_op = _is_view_op(func)

        def unwrap(e: AsyncCollectiveTensor):
            # wait_tensor is idempotent and will do the stream sync only once
            if not is_view_op:
                return e.trigger_wait()
            return e.elem

        def wrap(e: torch.Tensor):
            # only reached for view ops, whose outputs don't need to be waited on
            assert not isinstance(e, AsyncCollectiveTensor)
            res = AsyncCollectiveTensor(e)
            return res

        unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
        unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)

        # we don't wrap the result as it doesn't need to be waited on.
        out = func(*unwrapped_args, **unwrapped_kwargs)

        # View ops don't require a sync, so we should re-wrap the outputs.
        if is_view_op:
            out = tree_map_only(torch.Tensor, wrap, out)

        return out

    def numpy(self):
        return self.wait().numpy()

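"""
Illustrative eager-mode behavior of AsyncCollectiveTensor (a sketch; it assumes a default
process group is already initialized, e.g. via torchrun, with a gloo or nccl backend)::

    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol

    out = funcol.all_reduce(torch.ones(4), "sum", dist.group.WORLD)
    isinstance(out, funcol.AsyncCollectiveTensor)   # True when not tracing
    total = out.sum()    # first real use: the wait/stream sync is issued here
    plain = out.wait()   # or synchronize explicitly to get a plain torch.Tensor
"""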

"""
Utils and infrastructure for tracing support
"""


def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:
    """
    _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.

    By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside
    torchdynamo and can still interoperate with processgroup objects or other untraceable forms.
    """
    # had to define this hack _inside_ _expand_group to avoid a graph break
    # ("torch.* op returned non-Tensor int"), caused by `cast_*` functions being
    # treated as 'torch.*' ops (iiuc)
    if TYPE_CHECKING:

        def cast_listlistint(x):
            return cast(List[List[int]], x)

        def cast_listint(x):
            return cast(List[int], x)

    else:
        # fake cast op for use at runtime since dynamo doesn't support real cast
        # also, dynamo didn't like encountering 'typing' objects:
        # NotImplementedError: argument of type: <class 'typing._GenericAlias'>
        def cast_listlistint(x):
            return x

        def cast_listint(x):
            return x

    rankset: List[int]
    if isinstance(group, list):
        if isinstance(group[0], list):
            nested_list = cast_listlistint(group)
            rankset = []
            group_size = -1
            for rs in nested_list:
                rankset.extend(rs)
                if group_size != -1 and group_size != len(rs):
                    raise ValueError(
                        f"group sizes must be identical; found {group_size} and {len(rs)}"
                    )
                group_size = len(rs)
        else:
            rankset = cast_listint(group)
            group_size = len(rankset)
    elif isinstance(group, dist.ProcessGroup):
        rankset = dist.get_process_group_ranks(group)
        group_size = len(rankset)
        tag = tag or c10d._get_group_tag(group)
    elif isinstance(group, DeviceMesh):
        assert (
            group.ndim == 1
        ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
        # TODO: it should run collective in the whole mesh instead of dim 0
        tag, rankset, _ = group._dim_group_infos[0]
        group_size = len(rankset)
    elif isinstance(group, tuple):
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            tag, rankset, _ = dmesh._dim_group_infos[dim]
            group_size = len(rankset)
        else:
            raise ValueError("Invalid tuple for group, must be (DeviceMesh, int)")
    else:
        raise ValueError(
            "Invalid type for group, must be one of List, ProcessGroup, DeviceMesh or (DeviceMesh, int)."
        )

    return (tag, rankset, group_size)

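"""
Illustrative outputs of _expand_group (a sketch; the rank values are assumptions for a
hypothetical 4-rank job, and the ProcessGroup line only shows the general shape of the
result)::

    _expand_group([0, 1, 2, 3])        # -> ("", [0, 1, 2, 3], 4)
    _expand_group([[0, 1], [2, 3]])    # -> ("", [0, 1, 2, 3], 2)
    _expand_group(some_pg)             # -> (tag_of_pg, ranks_of_pg, len(ranks_of_pg))
"""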

def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
    """
    Given group in RANK_TYPES, return the group name.
    """
    # `tag` will be deprecated. See details in:
    # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
    if isinstance(group, dist.ProcessGroup):
        return group.group_name
    elif isinstance(group, str):
        return group
    elif isinstance(group, DeviceMesh):
        assert (
            group.ndim == 1
        ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
        return group._dim_group_infos[0][2]
    elif isinstance(group, tuple):
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            return dmesh._dim_group_infos[dim][2]
        else:
            raise ValueError("Invalid tuple for group, must be (DeviceMesh, int)")
    elif isinstance(group, list):
        if not is_torchdynamo_compiling():
            warnings.warn(
                "The combination of ranks + tag as process group "
                "identifier has been deprecated. Please switch to "
                "using ProcessGroup, DeviceMesh, or group name instead.",
                FutureWarning,
                stacklevel=3,
            )
        return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag)
    else:
        raise ValueError(f"Unsupported group type: {type(group)}, {group}")


class _FromTorchTensor(torch.autograd.Function):
    """
    _FromTorchTensor allows autograd to propagate from a normal Tensor to an
    AsyncCollectiveTensor.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,  # pyre-ignore[2]: Parameter must be annotated.
        input: torch.Tensor,
    ) -> torch.Tensor:
        return _maybe_wrap_tensor(input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return grad_output


def _are_we_tracing() -> bool:
    if is_torchdynamo_compiling():
        return True
    # If functionalization is turned on, we are almost definitely compiling/tracing.
    # (In particular, AOTAutograd traces a model once with functionalization on
    #  but proxy tracing turned off, so this is how we detect it).
    if (
        torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        is not None
    ):
        return True
    return get_proxy_mode() is not None


def _maybe_wrap_tensor(self) -> torch.Tensor:
    if _are_we_tracing():
        return wait_tensor(self)
    res = AsyncCollectiveTensor(self)
    return cast(torch.Tensor, res)


def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
    def mk_out_tensor(shard):
        out_size = list(shard.size())
        out_size[0] *= group_size
        out_tensor = shard.new_empty(out_size)
        return out_tensor

    return [mk_out_tensor(t) for t in self]


# We now register meta kernels to deal with tracing
def _broadcast_meta(self, *args):
    return torch.empty_like(self)


def _all_reduce_meta(self, *args):
    return torch.empty_like(self)


def _wait_tensor_meta(self, *args):
    return torch.empty_like(self)


def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
    out_size = list(shard.size())
    out_size[0] *= group_size
    return shard.new_empty(out_size)


def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
    out_size = list(input.size())
    out_size[0] //= group_size
    return input.new_empty(out_size)


def _all_reduce_coalesced_meta(self, *args):
    return [torch.empty_like(t) for t in self]


def _all_reduce__meta(inp, *args):
    return inp


def _broadcast__meta(inp, *args):
    return inp


def _all_reduce_coalesced__meta(inputs, *args):
    return inputs


def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
    def mk_out_tensor(input):
        out_size = list(input.size())
        out_size[0] //= group_size
        out_tensor = input.new_empty(out_size)
        return out_tensor

    return [mk_out_tensor(t) for t in inputs]


# NB: We often say all_to_all has dynamic output size, but this is not
# technically true: instead, what typically happens is you manually
# communicate the output_split_sizes ahead of time (which is dynamic),
# but then you pass those sizes explicitly, and the all to all itself
# isn't dynamic, it just follows the specified output splits
def _all_to_all_single_meta(
    input, output_split_sizes, input_split_sizes, *args, **kwargs
):
    if output_split_sizes is None:
        return input.new_empty(input.size())
    else:
        for s in output_split_sizes:
            torch._check_is_size(s)
        out_size = list(input.size())
        out_size[0] = sum(output_split_sizes)
        return input.new_empty(out_size)


def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out):
    shape = list(input.size())
    shape[0] *= group_size
    return input.new_empty(shape)


def _all_gather_into_tensor_native_meta(input, group_size, group_name):
    shape = list(input.size())
    shape[0] *= group_size
    return input.new_empty(shape)


def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
    return [
        _all_gather_into_tensor_native_meta(input, group_size, group_name)
        for input in inputs
    ]


def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
    shape = list(inp.size())
    shape[0] //= group_size
    return inp.new_empty(shape)


def _reduce_scatter_tensor_coalesced_native_meta(
    inputs, reduce_op, group_size, group_name
):
    return [
        _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
        for inp in inputs
    ]


if not torch._running_with_deploy():
    # Library MUST be defined at module scope or it doesn't work
    # Creating a "DEF" Library always crashes torch::deploy so we create our
    # Library instances here guarded against running inside it
    lib_impl = torch.library.Library("_c10d_functional", "IMPL")
    lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
    lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
    lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
    lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
    lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
    lib_impl.impl(
        "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta"
    )
    lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
    lib_impl.impl(
        "all_gather_into_tensor_coalesced",
        _all_gather_into_tensor_coalesced_native_meta,
        "Meta",
    )
    lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
    lib_impl.impl(
        "reduce_scatter_tensor_coalesced",
        _reduce_scatter_tensor_coalesced_native_meta,
        "Meta",
    )
    lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
    lib_impl.impl("broadcast", _broadcast_meta, "Meta")
    lib_impl.impl("broadcast_", _broadcast__meta, "Meta")

    # mark these ops as having side effects so that they won't be removed by DCE
    torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
    torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)

    # Register legacy ops for backward compatibility
    # TODO(yifu): remove these in functional collective beta release
    legacy_lib = torch.library.Library("c10d_functional", "DEF")
    legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL")
    ops_defs = [
        "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "wait_tensor(Tensor self) -> Tensor",
        "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
        "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]",
        "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor",  # noqa: B950
    ]

    my_module = sys.modules[__name__]
    for op_def in ops_defs:
        op_name = op_def[0 : op_def.index("(")]
        backend_impl = getattr(fun_col_impl, f"_{op_name}")
        legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
        legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd")

else:
    warnings.warn(
        "PyTorch Distributed functional collectives do not work with torch::deploy."
    )


"""
Dynamo Remappings allow seamless translation from non-functional collectives of a supportable form into
functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph.

We implement this by writing a decomposition and teaching dynamo how to associate it with a corresponding op via
the mapping dict below.

These schemas intentionally match the torch.distributed.distributed_c10d.* ops that we are trying to remap from.
"""


def all_gather_tensor_inplace(
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    group,  # TODO add a type,
    async_op: bool = False,
    tag: str = "",
    gather_dim: int = 0,
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))


def reduce_scatter_tensor_inplace(
    output: torch.Tensor,
    input: torch.Tensor,
    op: str = "sum",  # TODO type is actually c10d ReduceOp. is this ok?
    group=None,  # TODO add a type
    async_op: bool = False,
    scatter_dim: int = 0,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))


REDUCE_OP_TO_STR = {
    dist.ReduceOp.SUM: "sum",
    dist.ReduceOp.AVG: "avg",
    dist.ReduceOp.PRODUCT: "product",
    dist.ReduceOp.MIN: "min",
    dist.ReduceOp.MAX: "max",
    dist.ReduceOp.BAND: "band",
    dist.ReduceOp.BOR: "bor",
    dist.ReduceOp.BXOR: "bxor",
}


def all_reduce_inplace(
    tensor: torch.Tensor,
    op: str = "sum",
    group=None,
    async_op: bool = False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return tensor.copy_(all_reduce(tensor, op, group, tag))


def all_to_all_inplace(
    output: torch.Tensor,
    input: torch.Tensor,
    output_split_sizes=None,
    input_split_sizes=None,
    group=None,
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return output.copy_(
        all_to_all_single(
            input,
            output_split_sizes,
            input_split_sizes,
            group,
            tag,
        )
    )


def all_gather_inplace(
    tensor_list: List[torch.Tensor],
    tensor: torch.Tensor,
    group=None,
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    assert all(
        t.size(0) == tensor.size(0) for t in tensor_list
    ), "Remapping variable size all_gather is not yet supported"

    group = group or dist.group.WORLD
    assert group is not None

    output = all_gather_tensor(tensor, 0, group, tag)

    # Use aten.slice instead of aten.split because the latter causes
    # tensor.size(0) to be unnecessarily baked in when it's a SymInt.
    output_splits = []
    offset = 0
    for t in tensor_list:
        output_splits.append(output[offset : offset + t.size(0)])
        offset += t.size(0)
    for dst, src in zip(tensor_list, output_splits):
        dst.copy_(src)
    return tensor_list


from torch.distributed.distributed_c10d import (
    _all_gather_base as legacy_all_gather_base,
    _reduce_scatter_base as legacy_reduce_scatter_base,
    all_gather as legacy_all_gather,
    all_gather_into_tensor as legacy_allgather,
    all_reduce as legacy_allreduce,
    all_to_all_single as legacy_all_to_all_single,
    reduce_scatter_tensor as legacy_reducescatter,
)


# This dict contains the set of functions that dynamo is allowed to remap.
# Functions in this dict should accept the same args/kwargs 1:1 as their mapping.
traceable_collective_remaps = {
    legacy_allgather: all_gather_tensor_inplace,
    legacy_reducescatter: reduce_scatter_tensor_inplace,
    legacy_allreduce: all_reduce_inplace,
    legacy_all_to_all_single: all_to_all_inplace,
    legacy_all_gather: all_gather_inplace,
    legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
    legacy_all_gather_base: all_gather_tensor_inplace,
}