# Owner(s): ["module: c10d"]
import threading
import unittest
from typing import List

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch._C import FileCheck
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
from torch.distributed._functional_collectives import (
    all_gather_into_tensor_coalesced,
    all_gather_tensor,
    all_reduce,
    all_reduce_coalesced,
    all_to_all_single,
    AsyncCollectiveTensor,
    reduce_scatter_tensor,
    reduce_scatter_tensor_coalesced,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
    run_tests,
    TestCase,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton


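# Load a sibling test module (here, the AOT Inductor test utilities) by file
# path, temporarily adding the top-level test directory to sys.path so the
# loaded module's own imports resolve.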
def load_test_module(name):
    import sys
    from importlib.machinery import SourceFileLoader
    from pathlib import Path
    from unittest import mock

    testdir = Path(__file__).absolute().parent.parent
    with mock.patch("sys.path", [*sys.path, str(testdir)]):
        return SourceFileLoader(
            name, str(testdir / f"{name.replace('.', '/')}.py")
        ).load_module()


AOTIRunnerUtil = load_test_module("inductor.test_aot_inductor_utils").AOTIRunnerUtil

import sys


if not dist.is_available():
    print("distributed package not available, skipping tests", file=sys.stderr)
    sys.exit(0)


@requires_nccl()
class TestWithNCCL(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def ranks(self) -> List[int]:
        return list(range(self.world_size))

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process_group(self) -> None:
        # Allow testing aoti after torch.compile
        torch._inductor.config.triton.store_cubin = True
        torch._inductor.config.debug = True

        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
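        # Register the world group under the name "default" so the
        # _c10d_functional ops below can resolve it by that string.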
        torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_single(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_reduce(
            input,
            "avg",
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_single_(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce_(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) == id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) != id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            assert not output.completed
            assert output.eq(sum(self.ranks) / self.world_size * i).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced_(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced_(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) == id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

    @skip_if_lt_x_gpu(2)
    def test_all_gather_into_tensor_single(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_gather_into_tensor(
            input,
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((10, 10), float(rank), device=self.device)
                for rank in self.ranks
            ]
        )
        assert torch.allclose(output, expect)
        assert output.eq(expect).all()

        # Test out-variant of all_gather_into_tensor
        output = torch.empty(expect.shape, device=self.device)
        output = torch.ops._c10d_functional.all_gather_into_tensor_out(
            input,
            self.world_size,
            "default",
            out=output,
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert torch.allclose(output, expect)
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_gather_tensor(
            input,
            0,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    # https://github.com/pytorch/pytorch/issues/126338
    def test_inductor_dtypeview_memory_leak(self):
        self._init_process_group()

        def func(arg: torch.Tensor) -> torch.Tensor:
            ag0 = torch.ops._c10d_functional.all_gather_into_tensor.default(
                arg,
                self.world_size,
                "default",
            )
            ag0_view = torch.ops.aten.view.dtype(ag0, torch.int32)
            return funcol.wait_tensor(ag0_view)

        arg = torch.full(
            (10, 10),
            float(self.rank),
            device=self.device,
            dtype=torch.float32,
        )
        compiled = torch.compile(func)
        mem_usage = {}
        # check that the dtype view (aten.view.dtype) is preserved in the generated code
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("torch.ops._c10d_functional.wait_tensor.default(aten.view.dtype")
            .run(code)
        )
        # check memory leak
        for i in range(1, 10):
            mem_usage[i] = torch.cuda.max_memory_allocated()
            compiled(arg)

        assert mem_usage[9] == mem_usage[8]

    @skip_if_lt_x_gpu(2)
    def test_all_gather_into_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((10, 10), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
            inputs,
            self.world_size,
            "default",
        )
        expect = [
            torch.cat(
                [
                    torch.full((10, 10), float(rank) * i, device=self.device)
                    for rank in self.ranks
                ]
            )
            for i in range(10)
        ]
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(expect[i]).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = all_gather_into_tensor_coalesced(
            inputs,
            "default",
        )
        for i, output in enumerate(outputs):
            assert not output.completed
            assert output.eq(expect[i]).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_single(self) -> None:
        self._init_process_group()

        input = torch.tensor(self.ranks, device=self.device)
        output = torch.ops._c10d_functional.reduce_scatter_tensor(
            input,
            "avg",
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert output.eq(self.rank).all()

        # Test Python API and AsyncCollectiveTensor
        output = reduce_scatter_tensor(
            input,
            "avg",
            0,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(self.rank).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [torch.tensor(self.ranks, device=self.device) * i for i in range(10)]
        outputs = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            self.world_size,
            "default",
        )
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(self.rank * i).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            [0] * 10,
            "default",
        )
        for i, output in enumerate(outputs):
            assert not output.completed
            assert output.eq(self.rank * i).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_all_to_all_single(self) -> None:
        self._init_process_group()
        torch.cuda.set_device(self.device)

        torch.manual_seed(42)
        send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))

        input_split_sizes = send_sz_matrix[self.rank].tolist()
        output_split_sizes = send_sz_matrix[:, self.rank].tolist()
        input = torch.full((sum(input_split_sizes),), float(self.rank)).cuda()

        output = torch.ops._c10d_functional.all_to_all_single(
            input,
            output_split_sizes,
            input_split_sizes,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((sz,), float(rank)).cuda()
                for rank, sz in enumerate(output_split_sizes)
            ]
        )
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_to_all_single(
            input, output_split_sizes, input_split_sizes, "default"
        )
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_broadcast(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.broadcast(
            input,
            1,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = 1
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = funcol.broadcast(
            input,
            1,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_unwaited(self) -> None:
        # Verify that the process can terminate gracefully
        # even with unwaited tensors
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )

    @skip_if_lt_x_gpu(2)
    def test_py_work(self) -> None:
        self._init_process_group()

        wait_called = False

        class MyWork(dist.Work):
            def wait(self, _):
                nonlocal wait_called
                wait_called = True

        tensor = torch.rand(2, 2)
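        # _register_work associates the custom Work object with the tensor so
        # that wait_tensor dispatches to MyWork.wait below.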
        torch._C._distributed_c10d._register_work(tensor, MyWork())
        torch.ops._c10d_functional.wait_tensor(tensor)
        self.assertTrue(wait_called)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @fresh_inductor_cache()
    def test_threading(self):
        self._init_process_group()
        device = torch.device(f"cuda:{self.rank}")

        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            return ar0 + 1

        arg = torch.rand(4, 4, device=device)
        func(arg)

        compiled = torch.compile(func, fullgraph=True)
        code = run_and_get_triton_code(compiled, arg)
        FileCheck().check("all_reduce_.default(buf0, 'avg', '0')").run(code)

        # Unless explicitly specified (e.g. in a custom runtime), the process
        # group registry is shared among all threads in a process. Here we
        # verify that a process group registered in the main thread can be
        # resolved in a different thread.
        class TestThread(threading.Thread):
            def run(self):
                self.exc = None
                try:
                    func(arg)
                    compiled(arg)
                except BaseException as exc:
                    self.exc = exc

            def join(self):
                threading.Thread.join(self)
                if self.exc:
                    raise self.exc

        t = TestThread()
        t.start()
        t.join()


class CompileTest(TestCase):
    def setUp(self):
        # Allow testing aoti after torch.compile
        torch._inductor.config.triton.store_cubin = True
        torch._inductor.config.debug = True

        self.rank = 0
        self.world_size = 2
        torch.cuda.set_device("cuda:0")

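        # The "fake" backend performs no real communication, so these
        # compilation-only tests can run in a single process.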
        store = FakeStore()
        dist.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

    def tearDown(self):
        dist.destroy_process_group()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_reduce_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", "0")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf7 is a clone of the input)
            .check("torch.ops._c10d_functional.all_reduce_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )
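        # wait_tensor should be invoked only for its side effect; its return
        # value is never assigned in the generated code.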
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_reduce_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            bufs = [arg + 42 for arg in args]
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce_coalesced(bufs, "avg", "0")
            ar0 = [funcol.wait_tensor(out) for out in ar0]
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce_coalesced(args, "avg", "0")
            ar1 = [funcol.wait_tensor(out) for out in ar1]
            return ar0, ar1

        args = [torch.rand(4, 4, device="cuda") for _ in range(2)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf5 = empty")
            .check("buf1 = empty")
            .check("buf6 = empty")
            # Expect in-place with inductor allocated buf
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf0, buf1]"
            )
            # Expect no in-place with graph input (buf5, buf6 are clones)
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf5, buf6]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf5")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf6")
            # Expect no extra copy on return
            .check("return (buf0, buf1, buf5, buf6, )")
            .run(code)
        )
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_inplace_op_on_view(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = (arg + 10)[:2]
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            return ar0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            # Ensure the all_reduce_ input is a view
            .check(
                "torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf0"
            )
            .check(
                "torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf0"
            )
            .check("return (reinterpret_tensor(buf0")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reuse_buffer_after_inplace_collective(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            # Expect allocation
            buf0 = arg + 42
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect allocation
            buf1 = torch.mm(arg, ar0)
            # Expect buf0 to be reused
            buf2 = torch.mm(arg, buf1)
            return buf1, buf2

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            # Expect allocation
            .check("buf0 = empty")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect allocation
            .check("buf7 = empty")
            .check("extern_kernels.mm(arg0_1, buf0, out=buf7")
            # Expect buf0 to be reused
            .check("buf8 = buf0; del buf0  # reuse")
            .check("extern_kernels.mm(arg0_1, buf7, out=buf8")
            # Expect no extra copy on return
            .check("return (buf7, buf8, )")
            .run(code)
        )
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_gather_into_tensor_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            ag0 = funcol.all_gather_tensor(arg, 0, "0")
            ag0 = funcol.wait_tensor(ag0)
            return ag0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_gather_into_tensor_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            ag0 = funcol.all_gather_into_tensor_coalesced(args, "0")
            ag0 = [funcol.wait_tensor(out) for out in ag0]
            return ag0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reduce_scatter_tensor_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            rs0 = funcol.reduce_scatter_tensor(arg, "avg", 0, "0")
            rs0 = funcol.wait_tensor(rs0)
            return rs0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reduce_scatter_tensor_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            rs0 = funcol.reduce_scatter_tensor_coalesced(
                args, "avg", [0] * len(args), "0"
            )
            rs0 = [funcol.wait_tensor(out) for out in rs0]
            return rs0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_to_all_single(self):
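        # Hint to the compiler that each data-dependent value is a valid
        # tensor size, so the dynamic split sizes can be traced.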
        def _tolist_with_constrain_as_size(tensor):
            lst = tensor.tolist()
            for elem in lst:
                torch._check_is_size(elem)
            return lst

        def func(
            input: torch.Tensor,
            output_split_sizes: torch.Tensor,
            input_split_sizes: torch.Tensor,
        ) -> torch.Tensor:
            output = funcol.all_to_all_single(
                input,
                _tolist_with_constrain_as_size(output_split_sizes),
                _tolist_with_constrain_as_size(input_split_sizes),
                "0",
            )
            return funcol.wait_tensor(output)

        torch.manual_seed(42)
        send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))

        input_split_sizes = send_sz_matrix[self.rank]
        output_split_sizes = send_sz_matrix[:, self.rank].contiguous()
        input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).cuda()

        with torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ):
            compiled = torch.compile(func, dynamic=True)
            code = run_and_get_triton_code(
                compiled, input, output_split_sizes, input_split_sizes
            )
        (
            FileCheck()
            .check_regex(
                "torch.ops._c10d_functional.all_to_all_single.default\\("
                "arg\\d+_\\d+, \\[u\\d+, u\\d+\\], \\[u\\d+, u\\d+\\]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_broadcast(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            br0 = funcol.broadcast(buf0, 1, "0")
            br0 = funcol.wait_tensor(br0)
            # Expect no in-place with graph input
            br1 = funcol.broadcast(arg, 0, "0")
            br1 = funcol.wait_tensor(br1)
            return br0, br1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.broadcast_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf7 is a clone of the input)
            .check("torch.ops._c10d_functional.broadcast_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_ranks_and_tag(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", [0, 1], "")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", [0, 1], "")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func, fullgraph=True)

        code = run_and_get_triton_code(compiled, arg)
        (FileCheck().check("all_reduce_.default(buf0, 'avg', '0')").run(code))


if __name__ == "__main__":
    run_tests()