# Owner(s): ["module: dynamo"]
import functools
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case

# For some reason, importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.testing import CompileCounter
from torch._dynamo.utils import same
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed.distributed_c10d import GroupMember
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    DynamoDistributedMultiProcTestCase,
    DynamoDistributedSingleProcTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    requires_cuda,
)
from torch.utils._triton import has_triton


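# A helper for tests that feed .tolist() results into collectives: roughly,
# torch._check_is_size tells the compiler that each unbacked SymInt coming out
# of .tolist() is a valid (non-negative) size, so the values can later be used
# as split sizes without data-dependent guard errors.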
def _tolist_with_constrain_as_size(tensor):
    lst = tensor.tolist()
    for elem in lst:
        torch._check_is_size(elem)
    return lst


@requires_nccl()
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in the multi-proc runner; mark each test with the
    minimum number of GPUs it needs.
    """

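    # The (tag, ranks, group_size) triple below is what the legacy
    # c10d_functional ops take to identify the process group; with
    # world_size == 2 this is roughly {"tag": "", "ranks": [0, 1], "group_size": 2}.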
    def get_world_trs(self):
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2, 3, or 4 GPUs, just run on 2.
        # Works around an issue with skipif<2 and workers with an unpredictable number of GPUs.
        return 2

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_broadcast_inductor(self):
        """
        Test that broadcast works correctly when compiled with inductor.
        """

        def example(tensor, src, *, tag, ranks, group_size):
            res = torch.ops.c10d_functional.broadcast(
                tensor, src, tag, ranks, group_size
            )
            res = torch.ops.c10d_functional.wait_tensor(res)
            return res

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            t = torch.randn(4, 4, device="cuda")
            inputs = (t if self.rank == 0 else torch.zeros(4, 4, device="cuda"), 0)
            eager_out = example(*inputs)
            self.assertTrue(same(t, eager_out))

            compiled_func = compile(example, inputs)
            compiled_out = compiled_func(*inputs)
            self.assertTrue(same(eager_out, compiled_out))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allreduce_inductor(self):
        """
        This matmul/cat/allreduce pattern is one we aim to optimize.
        """

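        # The wait_tensor is issued only after the independent matmul of e and f,
        # which roughly gives inductor room to overlap the all_reduce with that
        # compute before the result is consumed.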
        def matmul_cat_col(a, b, c, d, e, f, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            matmul_cat_col = functools.partial(
                matmul_cat_col,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6

            eager_out = matmul_cat_col(*inputs)
            compiled_matmul_cat_col = compile(matmul_cat_col, inputs)
            inductor_out = compiled_matmul_cat_col(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allreduce_inductor_cudagraph_trees(self):
        """
        Test whether cudagraph trees support all_reduce from NCCL.
        """
        import torch.distributed as dist

        # dist.all_reduce is an in-place op in eager mode but a functionalized op in compiled mode,
        # so we define eager_func and func separately to express the same semantics.
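        # Under torch.compile, dynamo rewrites dist.all_reduce into the traceable
        # functional collective, which returns a new tensor rather than mutating
        # its input; that is (roughly) why func below assigns the result to y.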
        def eager_func(x):
            y = x * x
            dist.all_reduce(y, op=dist.ReduceOp.SUM)
            x = torch.nn.functional.silu(x)
            return x * y

        def func(x):
            y = x * x
            y = dist.all_reduce(y, op=dist.ReduceOp.SUM)
            x = torch.nn.functional.silu(x)
            return x * y

        options = {
            "triton.cudagraphs": True,
            "triton.cudagraph_trees": True,
        }

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            compiled_func = torch.compile(
                func, backend="inductor", fullgraph=True, options=options, dynamic=None
            )

            for nelem in [1024, 2048, 4096]:
                x = torch.randn(nelem, device="cuda", dtype=torch.bfloat16)
                golden_out = eager_func(x)

                for _ in range(3):
                    compiled_out = compiled_func(x)
                    self.assertEqual(golden_out, compiled_out)

    def test_c10d_functional_tagged_pt2_compliant(self):
        op = torch.ops._c10d_functional.all_reduce.default
        self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
        op = torch.ops.c10d_functional.all_reduce.default
        self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_eager_allreduce_inductor_wait(self):
        def eager_func(a, b, c, d, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def inductor_func(ar, e, f):
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            eager_func = functools.partial(
                eager_func,
                **self.get_world_trs(),
            )
            eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
            inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = inductor_func(eager_func(*eager_inputs), *inductor_inputs)
            compiled_inductor_func = compile(
                inductor_func, [eager_func(*eager_inputs)] + list(inductor_inputs)
            )
            inductor_out = compiled_inductor_func(
                eager_func(*eager_inputs), *inductor_inputs
            )
            print(f"eager_out, {eager_out}")
            print(f"inductor_out, {inductor_out}")
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_inductor_allreduce_eager_wait(self):
        def inductor_func(a, b, c, d, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def eager_func(ar, e, f):
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inductor_func = functools.partial(
                inductor_func,
                **self.get_world_trs(),
            )
            inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
            eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = eager_func(inductor_func(*inductor_inputs), *eager_inputs)
            compiled_inductor_func = compile(inductor_func, inductor_inputs)
            inductor_out = eager_func(
                compiled_inductor_func(*inductor_inputs), *eager_inputs
            )
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_allreduce_input_buffer_reuse(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            e = d + ar
            return (e,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = torch.ones(4, 4, device="cuda") + self.rank
            compiled = torch.compile(func)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_permute_tensor(self):
        def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
            return _functional_collectives.permute_tensor(
                tensor, src_dst_pairs, ranks, tag
            )

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = (
                # rank0: [0., 1.], rank1: [2., 3.]
                torch.arange(2, dtype=torch.float32, device="cuda") + 2 * self.rank,
                [1, 0],
            )
            compiled = torch.compile(func)
            out = compiled(*inputs, **self.get_world_trs())
            correct = func(*inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

            # rank0: [2., 3.], rank1: [0., 1.]
            expected = torch.arange(2, dtype=torch.float32, device="cuda") + 2 * (
                (self.rank - 1 + self.world_size) % self.world_size
            )
            self.assertEqual(out, expected)
            self.assertEqual(correct, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_allgather_output_buffer_reuse(self):
        class Model(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.emb = torch.nn.Embedding(4, 4)

            def forward(self, x, world_size, tag, ranks, group_size):
                y = self.emb(x)
                last_dim = y.dim() - 1
                res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
                out = torch.cat(torch.chunk(res, world_size, dim=0), dim=last_dim)
                return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model = Model().cuda()
            model_compiled = torch.compile(model)
            inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
            out = model_compiled(inp, self.world_size, **self.get_world_trs())
            correct = model(inp, self.world_size, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allgather_contiguous_input(self):
        class Model(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.emb = torch.nn.Embedding(4, 4)

            def forward(self, x, world_size, tag, ranks, group_size):
                y = self.emb(x)
                last_dim = y.dim() - 1
                y = y.transpose_(0, last_dim).contiguous()
                res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
                out = y.transpose_(0, last_dim).contiguous()
                return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model = Model().cuda()
            model_compiled = torch.compile(model)
            inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
            out = model_compiled(inp, self.world_size, **self.get_world_trs())
            correct = model(inp, self.world_size, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allgather_into_tensor_inductor(self):
        """
        This matmul/all_gather_into_tensor pattern is one we aim to optimize.
        """

        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
            ag = torch.ops.c10d_functional.all_gather_into_tensor(
                c, tag, ranks, group_size
            )
            ag = torch.ops.c10d_functional.wait_tensor(ag)
            return (ag,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = example(*inputs)
            compiled_matmul_cat_col = compile(example, inputs)
            inductor_out = compiled_matmul_cat_col(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_inductor(self):
        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
            ag = torch.ops.c10d_functional.reduce_scatter_tensor(
                c, "sum", tag, ranks, group_size
            )
            ag = torch.ops.c10d_functional.wait_tensor(ag)
            return (ag,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = example(*inputs)
            compiled_fn = compile(example, inputs)
            inductor_out = compiled_fn(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_all_to_all_single_inductor(self):
        def example(
            inp,
            input_split_sizes_tensor,
            output_split_sizes_tensor,
            *,
            tag,
            ranks,
            group_size,
        ):
            input_split_sizes = _tolist_with_constrain_as_size(input_split_sizes_tensor)
            output_split_sizes = _tolist_with_constrain_as_size(
                output_split_sizes_tensor
            )
            a2a = torch.ops.c10d_functional.all_to_all_single(
                inp,
                output_split_sizes,
                input_split_sizes,
                tag,
                ranks,
                group_size,
            )
            a2a = torch.ops.c10d_functional.wait_tensor(a2a)
            out = a2a / a2a.sum(dim=0)
            return out

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size
        ), torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ):
            row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
            input_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            output_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            inputs = (
                torch.ones(int(row), 5, device="cuda") * (self.rank + 1),
                input_split_sizes_tensor,
                output_split_sizes_tensor,
            )
            trs = self.get_world_trs()

            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
            code = run_and_get_triton_code(compiled_fn, *inputs, **trs)
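            # The u<N> placeholders matched below are the unbacked SymInts that
            # _tolist_with_constrain_as_size produces from the split-size tensors;
            # the regex roughly checks that they are passed through symbolically
            # rather than being specialized to constants.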
            (
                FileCheck()
                .check_regex(
                    "torch.ops._c10d_functional.all_to_all_single.default\\("
                    "arg\\d+_\\d+, "
                    "\\[u\\d+, u\\d+\\], "
                    "\\[u\\d+, u\\d+\\]"
                )
                .run(code)
            )

            eager_out = example(*inputs, **trs)
            inductor_out = compiled_fn(*inputs, **trs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_all_to_all_single_inductor_split_sizes_none(self):
        def example(inp, *, tag, ranks, group_size):
            a2a = torch.ops.c10d_functional.all_to_all_single(
                inp,
                None,
                None,
                tag,
                ranks,
                group_size,
            )
            a2a = torch.ops.c10d_functional.wait_tensor(a2a)
            out = a2a / a2a.sum(dim=0)
            return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = (
                torch.ones(self.world_size, self.world_size, device="cuda")
                * (self.rank + 1),
            )
            trs = self.get_world_trs()

            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
            code = run_and_get_triton_code(compiled_fn, *inputs, **trs)
            (
                FileCheck()
                .check_regex(
                    "torch.ops._c10d_functional.all_to_all_single.default\\("
                    "arg\\d+_\\d+, "
                    "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\], "
                    "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\]"
                )
                .run(code)
            )

            eager_out = example(*inputs, **trs)
            inductor_out = compiled_fn(*inputs, **trs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))


@instantiate_parametrized_tests
@requires_nccl()
@requires_cuda
class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
    """
    Prefer the single-proc test runner for basic tests, as it is easier to work with.
    """

    def get_world_trs(self, world_size=1):
        return {
            "tag": "",
            "ranks": list(range(world_size)),
            "group_size": world_size,
        }

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(debug=True)
    def test_inductor_single_op(self):
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.c10d_functional.all_reduce(
                inp, "sum", tag, ranks, group_size
            )
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            return ar

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        out = compiled(inputs, **self.get_world_trs())
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check(".run(arg0_1, buf0, 16")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("return (buf0")
            .run(code)
        )
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(debug=True)
    def test_inductor_steal_buffer(self):
        """
        It's OK (and optimal) if the inductor allreduce mutates the buffer of an
        intermediate that isn't going to be used again.
        """

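        # Here x = inp + 1 is only consumed by the all_reduce, so the generated
        # code is expected to run the in-place all_reduce_ directly on x's buffer
        # (buf0 in the FileCheck below) rather than allocating a fresh output.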
        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, other

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check(".run(arg0_1, buf0")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("buf5 = empty_strided")
            .check(".run(buf5, 16")
            .check("return (buf0, buf5")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_doesnt_mutate_shared(self):
        """
        Make sure that an intermediate that's going to be reused isn't mutated unless copied.
        """

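        # Unlike the previous test, x = inp + 1 is also read by y = x + 2 after the
        # collective, so the value y depends on must not be clobbered in place; the
        # FileCheck below roughly verifies that a separate buffer is produced for y.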
        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size)
            y = x + 2
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, y, other

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check("buf5 = empty_strided")
            .check(".run(arg0_1, buf0, buf5, 16")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("buf6 = empty_strided")
            .check(".run(buf6, 16")
            .check("return (buf0, buf5, buf6")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_allreduce(self):
        def func(inp):
            ar = _functional_collectives.all_reduce(inp, "sum", "0")
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs)
        correct = func(inputs)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (all_reduce, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_all_gather_tensor(self):
        def func(inp):
            ar = _functional_collectives.all_gather_tensor(inp, 0, "0")
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs)
        correct = func(inputs)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (all_gather, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_all_gather_tensor_pg(self):
        def func(inp, *, pg):
            ar = _functional_collectives.all_gather_tensor(inp, 0, pg)
            return ar

        inputs = torch.ones(4, 4, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        out = compiled(inputs, pg=GroupMember.WORLD)
        correct = func(inputs, pg=GroupMember.WORLD)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (all_gather, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_rewrite_dist_all_gather(self):
        def func(inp, out, *, pg):
            torch.distributed.all_gather_into_tensor(
                out,
                inp,
                pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1

        # should test more precisely, but the 3 is supposed to be (all_gather, wait, copy_)
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)

    def test_dynamo_rewrite_dist_all_gather_list(self):
        def func(inp, out, *, pg):
            torch.distributed.all_gather(
                out,
                inp,
                pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = [torch.empty(global_size, device=self.device)]
        correct_outputs = [torch.empty(global_size, device=self.device)]
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1
        assert same(outputs, correct_outputs)

    def test_dynamo_rewrite_dist_all_gather_args_match(self):
        # Duplicates most of the structure of test_dynamo_rewrite_dist_all_gather,
        # except that it uses kwargs to ensure the rewrite has matching arg names.
        def func(inp, out, *, pg):
            torch.distributed.all_gather_into_tensor(
                output_tensor=out,
                input_tensor=inp,
                group=pg,
                async_op=False,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1

        # should test more precisely, but the 3 is supposed to be (all_gather, wait, copy_)
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)

    def test_dynamo_rewrite_dist_reduce_scatter(self):
        def func(inp, out, *, pg):
            torch.distributed.reduce_scatter_tensor(
                out,
                inp,
                group=pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1

        # should test more precisely, but the 3 is supposed to be (reduce_scatter, wait, copy_)
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)

    @parametrize(
        "pg_mode",
        [
            "positional",
            "positional_none",
            "kwargs",
            "kwargs_none",
            "unspecified",
        ],
    )
    def test_dynamo_rewrite_dist_allreduce(self, pg_mode):
        def func(tensor, *args, **kwargs):
            torch.distributed.all_reduce(
                tensor,
                *args,
                **kwargs,
            )

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)

        args = []
        kwargs = {}

        if pg_mode == "positional":
            args.append(torch.distributed.ReduceOp.MAX)
            args.append(GroupMember.WORLD)
        elif pg_mode == "positional_none":
            args.append(torch.distributed.ReduceOp.MAX)
            args.append(None)
        elif pg_mode == "kwargs":
            kwargs["group"] = GroupMember.WORLD
        elif pg_mode == "kwargs_none":
            kwargs["group"] = None
        else:
            assert pg_mode == "unspecified"

        inputs_compiled = torch.ones(2, device=self.device)
        inputs_eager = torch.ones(2, device=self.device)

        compiled(inputs_compiled, *args, **kwargs)
        func(inputs_eager, *args, **kwargs)

        assert counter.frame_count == 1
        # should test more precisely, but the 3 is supposed to be (all_reduce, wait, copy_)
        assert counter.op_count == 3
        assert same(inputs_compiled, inputs_eager)

    def test_dynamo_rewrite_dist_all_to_all_single(self):
        def func(output, input, pg):
            torch.distributed.all_to_all_single(output, input, group=pg)

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)

        input_compiled = torch.ones(2, device=self.device)
        input_eager = torch.ones(2, device=self.device)
        output_compiled = torch.empty(2, device=self.device)
        output_eager = torch.empty(2, device=self.device)

        compiled(output_compiled, input_compiled, GroupMember.WORLD)
        func(output_eager, input_eager, GroupMember.WORLD)

        assert counter.frame_count == 1
        assert same(output_compiled, output_eager)

    @parametrize(
        "reduce_op",
        [
            torch.distributed.ReduceOp.SUM,
            torch.distributed.ReduceOp.AVG,
            torch.distributed.ReduceOp.PRODUCT,
            torch.distributed.ReduceOp.MIN,
            torch.distributed.ReduceOp.MAX,
        ],
    )
    def test_dynamo_rewrite_dist_allreduce_reduce_op(self, reduce_op):
        from torch.distributed._functional_collectives import REDUCE_OP_TO_STR

        def verify_rewrite(gm, _):
            ar_nodes = []
            for node in gm.graph.nodes:
                if node.target in [
                    torch.ops.c10d_functional.all_reduce,
                    torch.ops._c10d_functional.all_reduce,
                ]:
                    ar_nodes.append(node)
            self.assertEqual(len(ar_nodes), 1)
            reduce_op_str = ar_nodes[0].args[1]
            self.assertEqual(REDUCE_OP_TO_STR[reduce_op], reduce_op_str)
            return gm

        compiled = torch.compile(
            torch.distributed.all_reduce,
            backend=verify_rewrite,
            fullgraph=True,
        )
        inputs = (
            torch.ones(2, device=self.device),
            reduce_op,
            GroupMember.WORLD,
        )
        compiled(*inputs)

    @parametrize(
        "source",
        [
            "GroupMember.WORLD",
            "group.WORLD",
            "_get_default_group",
        ],
    )
    def test_dynamo_get_world_group(self, source):
        def func(tensor):
            if source == "GroupMember.WORLD":
                group = torch.distributed.GroupMember.WORLD
            elif source == "group.WORLD":
                group = torch.distributed.group.WORLD
            else:
                assert source == "_get_default_group"
                group = torch.distributed.distributed_c10d._get_default_group()

            torch.distributed.all_reduce(
                tensor,
                group=group,
            )

        def verify(gm, _):
            ar_nodes = []
            for node in gm.graph.nodes:
                if node.target in [
                    torch.ops.c10d_functional.all_reduce,
                    torch.ops._c10d_functional.all_reduce,
                ]:
                    ar_nodes.append(node)
            self.assertEqual(len(ar_nodes), 1)
            return gm

        compiled = torch.compile(func, backend=verify, fullgraph=True)
        input = torch.ones(2, device=self.device)
        compiled(input)

    def test_dynamo_support_collective_op_with_async_op_False(self):
        def func(inp, out, *, pg):
            # The user explicitly set `async_op` to False,
            # so there should be no graph break.
            torch.distributed.reduce_scatter_tensor(out, inp, group=pg, async_op=False)

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)

    def test_dynamo_graphbreaks_unsupported_async_op(self):
        def func(inp, out, *, pg):
            work = torch.distributed.reduce_scatter_tensor(
                out, inp, group=pg, async_op=True
            )
            work.wait()

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
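        # async_op=True returns a Work handle that dynamo cannot trace, so it
        # graph-breaks and (as the asserts below check) ends up compiling no
        # frames and no ops at all.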
        assert counter.frame_count == 0
        assert counter.op_count == 0
        assert same(outputs, correct_outputs)

    def test_dynamo_pg_var(self):
        def func(inp, *, pg):
            x = pg.rank() + 1 % pg.size()
            return inp + x

        local_size = [4, 4]
        inputs = torch.ones(local_size, device=self.device)
        correct_outputs = torch.empty(local_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        outputs = compiled(inputs, pg=GroupMember.WORLD)
        correct_outputs = func(inputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1
        assert counter.op_count == 1
        assert same(outputs, correct_outputs)

    def test_dynamo_trace_reduce_scatter_tensor(self):
        def func(inp):
            ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, "0")
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs)
        correct = func(inputs)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (reduce_scatter, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_allgather_coalesced(self):
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(
                inp, tag, ranks, group_size
            )
            return ar

        inputs = [torch.ones(4, 4, device="cuda"), torch.ones(6, 6, device="cuda")]
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        assert counter.frame_count == 1
        assert counter.op_count == 3  # It generates 2 getattr calls to unpack the output list
        assert same(out, correct)

    def test_backwards(self):
        """
        It's probably not that common to need backwards support for collectives.

        However, I wanted to at least see if it was possible to support it as a design goal.
        """

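        # Roughly, the gradient of a "sum" all_reduce is itself an all_reduce of
        # the incoming gradients, so the functional collective can participate in
        # autograd; aot_eager is used below because of the inductor issue noted
        # in the inline comment.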
        def func(inp):
            ar = _functional_collectives.all_reduce(inp, "sum", "0")
            return ar

        input = torch.ones(4, 4, device="cuda", requires_grad=True)
        compiled = torch.compile(
            func, backend="aot_eager"
        )  # inductor bug with single-op allreduce graph
        out = compiled(input)
        out.sum().backward()

        correct_input = input.clone().detach().requires_grad_()
        correct = func(correct_input)
        correct.sum().backward()
        self.assertTrue(same(out, correct))
        self.assertTrue(same(input.grad, correct_input.grad))

    def test_meta(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs())
        self.assertEqual(x.size(), out.size())

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_all_gather_coalesced(self):
        """
        Make sure that an intermediate that's going to be reused isn't mutated unless copied.
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            tensor_list = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(
                [x, inp], tag, ranks, group_size
            )
            y = x + 2
            ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0])
            ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1])
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar0, y, other, ar1

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check("buf6 = empty_strided")
            .check(".run(arg0_1, buf0, buf6, 16")
            .check(
                "buf1 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced.default([buf0, arg0_1]"
            )
            .check("buf2 = buf1[0]")
            .check("buf3 = buf1[1]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("buf7 = buf0; del buf0  # reuse")
            .check(".run(buf7, 16")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("return (buf2, buf6, buf7, buf3")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_reduce_scatter_coalesced(self):
        """
        Make sure that an intermediate that's going to be reused isn't mutated unless copied.
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            tensor_list = torch.ops.c10d_functional.reduce_scatter_tensor_coalesced(
                [x, inp], "sum", tag, ranks, group_size
            )
            y = x + 2
            ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0])
            ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1])
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar0, y, other, ar1

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: The first return value should be the output of the first wait_tensor.
        # We want to make sure no unnecessary copy is made.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check("buf6 = empty_strided")
            .check(".run(arg0_1, buf0, buf6, 16")
            .check(
                "buf1 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced.default([buf0, arg0_1]"
            )
            .check("buf2 = buf1[0]")
            .check("buf3 = buf1[1]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("buf7 = buf0; del buf0  # reuse")
            .check(".run(buf7, 16")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("return (buf2, buf6, buf7, buf3")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        assert same(out, correct), f"{out} vs {correct}"


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()