# Owner(s): ["module: c10d"]
import threading
import unittest
from typing import List

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch._C import FileCheck
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
from torch.distributed._functional_collectives import (
    all_gather_into_tensor_coalesced,
    all_gather_tensor,
    all_reduce,
    all_reduce_coalesced,
    all_to_all_single,
    AsyncCollectiveTensor,
    reduce_scatter_tensor,
    reduce_scatter_tensor_coalesced,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
    run_tests,
    TestCase,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton


def load_test_module(name):
    import sys
    from importlib.machinery import SourceFileLoader
    from pathlib import Path
    from unittest import mock

    testdir = Path(__file__).absolute().parent.parent
    with mock.patch("sys.path", [*sys.path, str(testdir)]):
        return SourceFileLoader(
            name, str(testdir / f"{name.replace('.', '/')}.py")
        ).load_module()


AOTIRunnerUtil = load_test_module("inductor.test_aot_inductor_utils").AOTIRunnerUtil

import sys


if not dist.is_available():
    print("distributed package not available, skipping tests", file=sys.stderr)
    sys.exit(0)


@requires_nccl()
class TestWithNCCL(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def ranks(self) -> List[int]:
        return list(range(self.world_size))

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process_group(self) -> None:
        # Allow testing aoti after torch.compile
        torch._inductor.config.triton.store_cubin = True
        torch._inductor.config.debug = True

        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_single(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_reduce(
            input,
            "avg",
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_single_(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce_(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) == id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) != id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            assert not output.completed
            assert output.eq(sum(self.ranks) / self.world_size * i).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced_(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced_(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) == id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

    @skip_if_lt_x_gpu(2)
    def test_all_gather_into_tensor_single(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_gather_into_tensor(
            input,
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((10, 10), float(rank), device=self.device)
                for rank in self.ranks
            ]
        )
        assert torch.allclose(output, expect)
        assert output.eq(expect).all()

        # Test out-variant of all_gather_into_tensor
        output = torch.empty(expect.shape, device=self.device)
        output = torch.ops._c10d_functional.all_gather_into_tensor_out(
            input,
            self.world_size,
            "default",
            out=output,
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert torch.allclose(output, expect)
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_gather_tensor(
            input,
            0,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    # https://github.com/pytorch/pytorch/issues/126338
    def test_inductor_dtypeview_memory_leak(self):
        self._init_process_group()

        def func(arg: torch.Tensor) -> torch.Tensor:
            ag0 = torch.ops._c10d_functional.all_gather_into_tensor.default(
                arg,
                self.world_size,
                "default",
            )
            ag0_view = torch.ops.aten.view.dtype(ag0, torch.int32)
            return funcol.wait_tensor(ag0_view)

        arg = torch.full(
            (10, 10),
            float(self.rank),
            device=self.device,
            dtype=torch.float32,
        )
        compiled = torch.compile(func)
        mem_usage = {}
        # check that aten.view.dtype is preserved as aten.view.dtype in the compiled code
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("torch.ops._c10d_functional.wait_tensor.default(aten.view.dtype")
            .run(code)
        )
        # check memory leak
        for i in range(1, 10):
            mem_usage[i] = torch.cuda.max_memory_allocated()
            compiled(arg)

        assert mem_usage[9] == mem_usage[8]

    @skip_if_lt_x_gpu(2)
    def test_all_gather_into_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((10, 10), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
            inputs,
            self.world_size,
            "default",
        )
        expect = [
            torch.cat(
                [
                    torch.full((10, 10), float(rank) * i, device=self.device)
                    for rank in self.ranks
                ]
            )
            for i in range(10)
        ]
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(expect[i]).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = all_gather_into_tensor_coalesced(
            inputs,
            "default",
        )
        for i, output in enumerate(outputs):
            assert not output.completed
            assert output.eq(expect[i]).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_single(self) -> None:
        self._init_process_group()

        input = torch.tensor(self.ranks, device=self.device)
        output = torch.ops._c10d_functional.reduce_scatter_tensor(
            input,
            "avg",
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert output.eq(self.rank).all()

        # Test Python API and AsyncCollectiveTensor
        output = reduce_scatter_tensor(
            input,
            "avg",
            0,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(self.rank).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [torch.tensor(self.ranks, device=self.device) * i for i in range(10)]
        outputs = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            self.world_size,
            "default",
        )
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(self.rank * i).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            [0] * 10,
            "default",
        )
        for i, output in enumerate(outputs):
            assert not output.completed
            assert output.eq(self.rank * i).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_all_to_all_single(self) -> None:
        self._init_process_group()
        torch.cuda.set_device(self.device)

        torch.manual_seed(42)
        send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))

        input_split_sizes = send_sz_matrix[self.rank].tolist()
        output_split_sizes = send_sz_matrix[:, self.rank].tolist()
        input = torch.full((sum(input_split_sizes),), float(self.rank)).cuda()

        output = torch.ops._c10d_functional.all_to_all_single(
            input,
            output_split_sizes,
            input_split_sizes,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((sz,), float(rank)).cuda()
                for rank, sz in enumerate(output_split_sizes)
            ]
        )
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_to_all_single(
            input, output_split_sizes, input_split_sizes, "default"
        )
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_broadcast(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.broadcast(
            input,
            1,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = 1
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = funcol.broadcast(
            input,
            1,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    def test_unwaited(self) -> None:
        # Verify that the process can terminate gracefully
        # even with unwaited tensors
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )

    @skip_if_lt_x_gpu(2)
    def test_py_work(self) -> None:
        self._init_process_group()

        wait_called = False

        class MyWork(dist.Work):
            def wait(self, _):
                nonlocal wait_called
                wait_called = True

        tensor = torch.rand(2, 2)
        torch._C._distributed_c10d._register_work(tensor, MyWork())
        torch.ops._c10d_functional.wait_tensor(tensor)
        self.assertTrue(wait_called)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @fresh_inductor_cache()
    def test_threading(self):
        self._init_process_group()
        device = torch.device(f"cuda:{self.rank}")

        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            return ar0 + 1

        arg = torch.rand(4, 4, device=device)
        func(arg)

        compiled = torch.compile(func, fullgraph=True)
        code = run_and_get_triton_code(compiled, arg)
        FileCheck().check("all_reduce_.default(buf0, 'avg', '0')").run(code)

        # Unless explicitly specified (e.g. in a custom runtime), the process
        # group registry is shared among all threads in a process. Here we
        # verify that a process group registered in the main thread can be
        # resolved in a different thread.
        class TestThread(threading.Thread):
            def run(self):
                self.exc = None
                try:
                    func(arg)
                    compiled(arg)
                except BaseException as exc:
                    self.exc = exc

            def join(self):
                threading.Thread.join(self)
                if self.exc:
                    raise self.exc

        t = TestThread()
        t.start()
        t.join()


class CompileTest(TestCase):
    def setUp(self):
        # Allow testing aoti after torch.compile
        torch._inductor.config.triton.store_cubin = True
        torch._inductor.config.debug = True

        self.rank = 0
        self.world_size = 2
        torch.cuda.set_device("cuda:0")

        store = FakeStore()
        dist.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

    def tearDown(self):
        dist.destroy_process_group()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_reduce_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", "0")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf7 is a clone)
            .check("torch.ops._c10d_functional.all_reduce_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_reduce_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            bufs = [arg + 42 for arg in args]
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce_coalesced(bufs, "avg", "0")
            ar0 = [funcol.wait_tensor(out) for out in ar0]
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce_coalesced(args, "avg", "0")
            ar1 = [funcol.wait_tensor(out) for out in ar1]
            return ar0, ar1

        args = [torch.rand(4, 4, device="cuda") for _ in range(2)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf5 = empty")
            .check("buf1 = empty")
            .check("buf6 = empty")
            # Expect in-place with inductor allocated buf
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf0, buf1]"
            )
            # Expect no in-place with graph input (buf5, buf6 are clones)
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf5, buf6]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf5")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf6")
            # Expect no extra copy on return
            .check("return (buf0, buf1, buf5, buf6, )")
            .run(code)
        )
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_inplace_op_on_view(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = (arg + 10)[:2]
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            return ar0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            # Ensure the all_reduce_ input is a view
            .check(
                "torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf0"
            )
            .check(
                "torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf0"
            )
            .check("return (reinterpret_tensor(buf0")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reuse_buffer_after_inplace_collective(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            # Expect allocation
            buf0 = arg + 42
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect allocation
            buf1 = torch.mm(arg, ar0)
            # Expect buf0 to be reused
            buf2 = torch.mm(arg, buf1)
            return buf1, buf2

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            # Expect allocation
            .check("buf0 = empty")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect allocation
            .check("buf7 = empty")
            .check("extern_kernels.mm(arg0_1, buf0, out=buf7")
            # Expect buf0 to be reused
            .check("buf8 = buf0; del buf0 # reuse")
            .check("extern_kernels.mm(arg0_1, buf7, out=buf8")
            # Expect no extra copy on return
            .check("return (buf7, buf8, )")
            .run(code)
        )
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_gather_into_tensor_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            ag0 = funcol.all_gather_tensor(arg, 0, "0")
            ag0 = funcol.wait_tensor(ag0)
            return ag0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_gather_into_tensor_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            ag0 = funcol.all_gather_into_tensor_coalesced(args, "0")
            ag0 = [funcol.wait_tensor(out) for out in ag0]
            return ag0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reduce_scatter_tensor_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            rs0 = funcol.reduce_scatter_tensor(arg, "avg", 0, "0")
            rs0 = funcol.wait_tensor(rs0)
            return rs0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reduce_scatter_tensor_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            rs0 = funcol.reduce_scatter_tensor_coalesced(
                args, "avg", [0] * len(args), "0"
            )
            rs0 = [funcol.wait_tensor(out) for out in rs0]
            return rs0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_to_all_single(self):
        def _tolist_with_constrain_as_size(tensor):
            lst = tensor.tolist()
            for elem in lst:
                torch._check_is_size(elem)
            return lst

        def func(
            input: torch.Tensor,
            output_split_sizes: torch.Tensor,
            input_split_sizes: torch.Tensor,
        ) -> torch.Tensor:
            output = funcol.all_to_all_single(
                input,
                _tolist_with_constrain_as_size(output_split_sizes),
                _tolist_with_constrain_as_size(input_split_sizes),
                "0",
            )
            return funcol.wait_tensor(output)

        torch.manual_seed(42)
        send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))

        input_split_sizes = send_sz_matrix[self.rank]
        output_split_sizes = send_sz_matrix[:, self.rank].contiguous()
        input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).cuda()

        with torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ):
            compiled = torch.compile(func, dynamic=True)
            code = run_and_get_triton_code(
                compiled, input, output_split_sizes, input_split_sizes
            )
        (
            FileCheck()
            .check_regex(
                "torch.ops._c10d_functional.all_to_all_single.default\\("
                "arg\\d+_\\d+, \\[u\\d+, u\\d+\\], \\[u\\d+, u\\d+\\]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_broadcast(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            br0 = funcol.broadcast(buf0, 1, "0")
            br0 = funcol.wait_tensor(br0)
            # Expect no in-place with graph input
            br1 = funcol.broadcast(arg, 0, "0")
            br1 = funcol.wait_tensor(br1)
            return br0, br1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.broadcast_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf7 is a clone)
            .check("torch.ops._c10d_functional.broadcast_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_ranks_and_tag(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", [0, 1], "")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", [0, 1], "")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func, fullgraph=True)

        code = run_and_get_triton_code(compiled, arg)
        (FileCheck().check("all_reduce_.default(buf0, 'avg', '0')").run(code))


if __name__ == "__main__":
    run_tests()