# Owner(s): ["module: dynamo"]
import functools
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case

# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.testing import CompileCounter
from torch._dynamo.utils import same
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed.distributed_c10d import GroupMember
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    DynamoDistributedMultiProcTestCase,
    DynamoDistributedSingleProcTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    requires_cuda,
)
from torch.utils._triton import has_triton


def _tolist_with_constrain_as_size(tensor):
    lst = tensor.tolist()
    for elem in lst:
        torch._check_is_size(elem)
    return lst


@requires_nccl()
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
    """

    def get_world_trs(self):
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2, 3, or 4 GPUs, just run on 2
        # works around issues with skipif<2 and workers with unpredictable numbers of GPUs
        return 2

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_broadcast_inductor(self):
        """
        Testing if broadcast works correctly when using inductor
        """

        def example(tensor, src, *, tag, ranks, group_size):
            res = torch.ops.c10d_functional.broadcast(
                tensor, src, tag, ranks, group_size
            )
            res = torch.ops.c10d_functional.wait_tensor(res)
            return res

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            t = torch.randn(4, 4, device="cuda")
            inputs = (t if self.rank == 0 else torch.zeros(4, 4, device="cuda"), 0)
            eager_out = example(*inputs)
            self.assertTrue(same(t, eager_out))

            compiled_func = compile(example, inputs)
            compiled_out = compiled_func(*inputs)
            self.assertTrue(same(eager_out, compiled_out))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allreduce_inductor(self):
        """
        This matmul/cat/all_reduce pattern is one we aim to optimize.
        """

        def matmul_cat_col(a, b, c, d, e, f, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            matmul_cat_col = functools.partial(
                matmul_cat_col,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6

            eager_out = matmul_cat_col(*inputs)
            compiled_matmul_cat_col = compile(matmul_cat_col, inputs)
            inductor_out = compiled_matmul_cat_col(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allreduce_inductor_cudagraph_trees(self):
        """
        Tests whether cudagraph trees support all_reduce from nccl
        """
        import torch.distributed as dist

        # dist.all_reduce is an in-place op in eager mode but a functionalized op in compiled mode,
        # so we define eager_func and func separately for the same semantics.
        def eager_func(x):
            y = x * x
            dist.all_reduce(y, op=dist.ReduceOp.SUM)
            x = torch.nn.functional.silu(x)
            return x * y

        def func(x):
            y = x * x
            y = dist.all_reduce(y, op=dist.ReduceOp.SUM)
            x = torch.nn.functional.silu(x)
            return x * y

        options = {
            "triton.cudagraphs": True,
            "triton.cudagraph_trees": True,
        }

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            compiled_func = torch.compile(
                func, backend="inductor", fullgraph=True, options=options, dynamic=None
            )

            for nelem in [1024, 2048, 4096]:
                x = torch.randn(nelem, device="cuda", dtype=torch.bfloat16)
                golden_out = eager_func(x)

                for _ in range(3):
                    compiled_out = compiled_func(x)
                    self.assertEqual(golden_out, compiled_out)

    def test_c10d_functional_tagged_pt2_compliant(self):
        op = torch.ops._c10d_functional.all_reduce.default
        self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
        op = torch.ops.c10d_functional.all_reduce.default
        self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_eager_allreduce_inductor_wait(self):
        def eager_func(a, b, c, d, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def inductor_func(ar, e, f):
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            eager_func = functools.partial(
                eager_func,
                **self.get_world_trs(),
            )
            eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
            inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = inductor_func(eager_func(*eager_inputs), *inductor_inputs)
            compiled_inductor_func = compile(
                inductor_func, [eager_func(*eager_inputs)] + list(inductor_inputs)
            )
            inductor_out = compiled_inductor_func(
                eager_func(*eager_inputs), *inductor_inputs
            )
            print(f"eager_out, {eager_out}")
            print(f"inductor_out, {inductor_out}")
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_inductor_allreduce_eager_wait(self):
        def inductor_func(a, b, c, d, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def eager_func(ar, e, f):
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inductor_func = functools.partial(
                inductor_func,
                **self.get_world_trs(),
            )
            inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
            eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = eager_func(inductor_func(*inductor_inputs), *eager_inputs)
            compiled_inductor_func = compile(inductor_func, inductor_inputs)
            inductor_out = eager_func(
                compiled_inductor_func(*inductor_inputs), *eager_inputs
            )
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_allreduce_input_buffer_reuse(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            e = d + ar
            return (e,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = torch.ones(4, 4, device="cuda") + self.rank
            compiled = torch.compile(func)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_permute_tensor(self):
        def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
            return _functional_collectives.permute_tensor(
                tensor, src_dst_pairs, ranks, tag
            )

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = (
                # rank0: [0., 1.], rank1: [2., 3.]
                torch.arange(2, dtype=torch.float32, device="cuda") + 2 * self.rank,
                [1, 0],
            )
            compiled = torch.compile(func)
            out = compiled(*inputs, **self.get_world_trs())
            correct = func(*inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

            # rank0: [2., 3.], rank1: [0., 1.]
            expected = torch.arange(2, dtype=torch.float32, device="cuda") + 2 * (
                (self.rank - 1 + self.world_size) % self.world_size
            )
            self.assertEqual(out, expected)
            self.assertEqual(correct, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_allgather_output_buffer_reuse(self):
        class Model(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.emb = torch.nn.Embedding(4, 4)

            def forward(self, x, world_size, tag, ranks, group_size):
                y = self.emb(x)
                last_dim = y.dim() - 1
                res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
                out = torch.cat(torch.chunk(res, world_size, dim=0), dim=last_dim)
                return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model = Model().cuda()
            model_compiled = torch.compile(model)
            inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
            out = model_compiled(inp, self.world_size, **self.get_world_trs())
            correct = model(inp, self.world_size, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allgather_contiguous_input(self):
        class Model(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.emb = torch.nn.Embedding(4, 4)

            def forward(self, x, world_size, tag, ranks, group_size):
                y = self.emb(x)
                last_dim = y.dim() - 1
                y = y.transpose_(0, last_dim).contiguous()
                res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
                out = y.transpose_(0, last_dim).contiguous()
                return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model = Model().cuda()
            model_compiled = torch.compile(model)
            inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
            out = model_compiled(inp, self.world_size, **self.get_world_trs())
            correct = model(inp, self.world_size, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allgather_into_tensor_inductor(self):
        """
        This matmul/all_gather_into_tensor pattern is one we aim to optimize.
        """

        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
            ag = torch.ops.c10d_functional.all_gather_into_tensor(
                c, tag, ranks, group_size
            )
            ag = torch.ops.c10d_functional.wait_tensor(ag)
            return (ag,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = example(*inputs)
            compiled_matmul_cat_col = compile(example, inputs)
            inductor_out = compiled_matmul_cat_col(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_inductor(self):
        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
            ag = torch.ops.c10d_functional.reduce_scatter_tensor(
                c, "sum", tag, ranks, group_size
            )
            ag = torch.ops.c10d_functional.wait_tensor(ag)
            return (ag,)

        def compile(func, example_inputs):
            graph = make_fx(func)(*example_inputs)
            return inductor_compile_fx(graph, example_inputs)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = example(*inputs)
            compiled_fn = compile(example, inputs)
            inductor_out = compiled_fn(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_all_to_all_single_inductor(self):
        def example(
            inp,
            input_split_sizes_tensor,
            output_split_sizes_tensor,
            *,
            tag,
            ranks,
            group_size,
        ):
            input_split_sizes = _tolist_with_constrain_as_size(input_split_sizes_tensor)
            output_split_sizes = _tolist_with_constrain_as_size(
                output_split_sizes_tensor
            )
            a2a = torch.ops.c10d_functional.all_to_all_single(
                inp,
                output_split_sizes,
                input_split_sizes,
                tag,
                ranks,
                group_size,
            )
            a2a = torch.ops.c10d_functional.wait_tensor(a2a)
            out = a2a / a2a.sum(dim=0)
            return out

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size
        ), torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ):
            row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
            input_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            output_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            inputs = (
                torch.ones(int(row), 5, device="cuda") * (self.rank + 1),
                input_split_sizes_tensor,
                output_split_sizes_tensor,
            )
            trs = self.get_world_trs()

            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
            code = run_and_get_triton_code(compiled_fn, *inputs, **trs)
            (
                FileCheck()
                .check_regex(
                    "torch.ops._c10d_functional.all_to_all_single.default\\("
                    "arg\\d+_\\d+, "
                    "\\[u\\d+, u\\d+\\], "
                    "\\[u\\d+, u\\d+\\]"
                )
                .run(code)
            )

            eager_out = example(*inputs, **trs)
            inductor_out = compiled_fn(*inputs, **trs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_all_to_all_single_inductor_split_sizes_none(self):
        def example(inp, *, tag, ranks, group_size):
            a2a = torch.ops.c10d_functional.all_to_all_single(
                inp,
                None,
                None,
                tag,
                ranks,
                group_size,
            )
            a2a = torch.ops.c10d_functional.wait_tensor(a2a)
            out = a2a / a2a.sum(dim=0)
            return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = (
                torch.ones(self.world_size, self.world_size, device="cuda")
                * (self.rank + 1),
            )
            trs = self.get_world_trs()

            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
            code = run_and_get_triton_code(compiled_fn, *inputs, **trs)
            (
                FileCheck()
                .check_regex(
                    "torch.ops._c10d_functional.all_to_all_single.default\\("
                    "arg\\d+_\\d+, "
                    "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\], "
                    "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\]"
                )
                .run(code)
            )

            eager_out = example(*inputs, **trs)
            inductor_out = compiled_fn(*inputs, **trs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))


@instantiate_parametrized_tests
@requires_nccl()
@requires_cuda
class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
    """
    Prefer single-proc test runner for basic tests as it is easier to work with.
    """

    def get_world_trs(self, world_size=1):
        return {
            "tag": "",
            "ranks": list(range(world_size)),
            "group_size": world_size,
        }

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(debug=True)
    def test_inductor_single_op(self):
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.c10d_functional.all_reduce(
                inp, "sum", tag, ranks, group_size
            )
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            return ar

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        out = compiled(inputs, **self.get_world_trs())
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
538 ( 539 FileCheck() 540 .check("buf0 = empty_strided") 541 .check(".run(arg0_1, buf0, 16") 542 .check("torch.ops._c10d_functional.all_reduce_.default(buf0") 543 .check("torch.ops._c10d_functional.wait_tensor.default(buf0") 544 .check("return (buf0") 545 .run(code) 546 ) 547 correct = func(inputs, **self.get_world_trs()) 548 self.assertTrue(same(out, correct)) 549 550 @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") 551 @torch._inductor.config.patch(debug=True) 552 def test_inductor_steal_buffer(self): 553 """ 554 it's ok and optimal if inductor allreduce mutates the buffer of an intermediate 555 that isn't going to be used again 556 """ 557 558 def func(inp, *, tag, ranks, group_size): 559 x = inp + 1 560 ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size) 561 ar = torch.ops.c10d_functional.wait_tensor(ar) 562 # ensure other is not incorrectly aliasing ar's buffer 563 other = torch.ones_like(inp) + 22 564 return ar, other 565 566 inputs = torch.ones(4, 4, device="cuda") 567 568 compiled = torch.compile(func) 569 code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) 570 ( 571 FileCheck() 572 .check("buf0 = empty_strided") 573 .check(".run(arg0_1, buf0") 574 .check("torch.ops._c10d_functional.all_reduce_.default(buf0") 575 .check("torch.ops._c10d_functional.wait_tensor.default(buf0") 576 .check("buf5 = empty_strided") 577 .check(".run(buf5, 16") 578 .check("return (buf0, buf5") 579 .run(code) 580 ) 581 out = compiled(inputs, **self.get_world_trs()) 582 correct = func(inputs, **self.get_world_trs()) 583 self.assertTrue(same(out, correct)) 584 585 @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") 586 @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False}) 587 def test_inductor_doesnt_mutate_shared(self): 588 """ 589 make sure that an intermediate that's going to be reuse isn't mutated unless copied 590 """ 591 592 def func(inp, *, tag, ranks, group_size): 593 x = inp + 1 594 ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size) 595 y = x + 2 596 ar = torch.ops.c10d_functional.wait_tensor(ar) 597 # ensure other is not incorrectly aliasing ar's buffer 598 other = torch.ones_like(inp) + 22 599 return ar, y, other 600 601 inputs = torch.ones(4, 4, device="cuda") 602 603 compiled = torch.compile(func) 604 code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) 605 # NOTE: Make sure we are not unneccessarily copying the outputs of 606 # wait_tensors before they are returned from the graph. 
607 ( 608 FileCheck() 609 .check("buf0 = empty_strided") 610 .check("buf5 = empty_strided") 611 .check(".run(arg0_1, buf0, buf5, 16") 612 .check("torch.ops._c10d_functional.all_reduce_.default(buf0") 613 .check("torch.ops._c10d_functional.wait_tensor.default(buf0") 614 .check("buf6 = empty_strided") 615 .check(".run(buf6, 16") 616 .check("return (buf0, buf5, buf6") 617 .run(code) 618 ) 619 out = compiled(inputs, **self.get_world_trs()) 620 correct = func(inputs, **self.get_world_trs()) 621 self.assertTrue(same(out, correct)) 622 623 def test_dynamo_trace_allreduce(self): 624 def func(inp): 625 ar = _functional_collectives.all_reduce(inp, "sum", "0") 626 return ar 627 628 inputs = torch.ones(4, 4, device="cuda") 629 counter = CompileCounter() 630 compiled = torch.compile(func, backend=counter) 631 out = compiled(inputs) 632 correct = func(inputs) 633 self.assertEqual(counter.frame_count, 1) 634 635 # should test more precisely, but the 2 is supposed to be (all_reduce, wait) 636 self.assertEqual(counter.op_count, 2) 637 self.assertTrue(same(out, correct)) 638 639 def test_dynamo_trace_all_gather_tensor(self): 640 def func(inp): 641 ar = _functional_collectives.all_gather_tensor(inp, 0, "0") 642 return ar 643 644 inputs = torch.ones(4, 4, device="cuda") 645 counter = CompileCounter() 646 compiled = torch.compile(func, backend=counter) 647 out = compiled(inputs) 648 correct = func(inputs) 649 self.assertEqual(counter.frame_count, 1) 650 651 # should test more precisely, but the 2 is supposed to be (all_gather, wait) 652 self.assertEqual(counter.op_count, 2) 653 self.assertTrue(same(out, correct)) 654 655 def test_dynamo_trace_all_gather_tensor_pg(self): 656 def func(inp, *, pg): 657 ar = _functional_collectives.all_gather_tensor(inp, 0, pg) 658 return ar 659 660 inputs = torch.ones(4, 4, device=self.device) 661 counter = CompileCounter() 662 compiled = torch.compile(func, backend=counter, fullgraph=True) 663 out = compiled(inputs, pg=GroupMember.WORLD) 664 correct = func(inputs, pg=GroupMember.WORLD) 665 self.assertEqual(counter.frame_count, 1) 666 667 # should test more precisely, but the 2 is supposed to be (all_gather, wait) 668 self.assertEqual(counter.op_count, 2) 669 self.assertTrue(same(out, correct)) 670 671 def test_dynamo_rewrite_dist_all_gather(self): 672 def func(inp, out, *, pg): 673 torch.distributed.all_gather_into_tensor( 674 out, 675 inp, 676 pg, 677 ) 678 679 local_size = [4, 4] 680 # single-proc test 681 global_size = local_size 682 683 inputs = torch.ones(local_size, device=self.device) 684 outputs = torch.empty(global_size, device=self.device) 685 correct_outputs = torch.empty(global_size, device=self.device) 686 counter = CompileCounter() 687 compiled = torch.compile(func, backend=counter, fullgraph=True) 688 compiled(inputs, outputs, pg=GroupMember.WORLD) 689 func(inputs, correct_outputs, pg=GroupMember.WORLD) 690 assert counter.frame_count == 1 691 692 # should test more precisely, but the 3 is supposed to be (all_gather, wait, copy_) 693 assert counter.op_count == 3 694 assert same(outputs, correct_outputs) 695 696 def test_dynamo_rewrite_dist_all_gather_list(self): 697 def func(inp, out, *, pg): 698 torch.distributed.all_gather( 699 out, 700 inp, 701 pg, 702 ) 703 704 local_size = [4, 4] 705 # single-proc test 706 global_size = local_size 707 708 inputs = torch.ones(local_size, device=self.device) 709 outputs = [torch.empty(global_size, device=self.device)] 710 correct_outputs = [torch.empty(global_size, device=self.device)] 711 counter = CompileCounter() 712 compiled 
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1
        assert same(outputs, correct_outputs)

    def test_dynamo_rewrite_dist_all_gather_args_match(self):
        # Duplicated most of the structure from test_dynamo_rewrite_dist_all_gather
        # except uses kwargs to ensure rewrite has matching arg names
        def func(inp, out, *, pg):
            torch.distributed.all_gather_into_tensor(
                output_tensor=out,
                input_tensor=inp,
                group=pg,
                async_op=False,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1

        # should test more precisely, but the 3 is supposed to be (all_gather, wait, copy_)
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)

    def test_dynamo_rewrite_dist_reduce_scatter(self):
        def func(inp, out, *, pg):
            torch.distributed.reduce_scatter_tensor(
                out,
                inp,
                group=pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1

        # should test more precisely, but the 3 is supposed to be (reduce_scatter, wait, copy_)
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)

    @parametrize(
        "pg_mode",
        [
            "positional",
            "positional_none",
            "kwargs",
            "kwargs_none",
            "unspecified",
        ],
    )
    def test_dynamo_rewrite_dist_allreduce(self, pg_mode):
        def func(tensor, *args, **kwargs):
            torch.distributed.all_reduce(
                tensor,
                *args,
                **kwargs,
            )

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)

        args = []
        kwargs = {}

        if pg_mode == "positional":
            args.append(torch.distributed.ReduceOp.MAX)
            args.append(GroupMember.WORLD)
        elif pg_mode == "positional_none":
            args.append(torch.distributed.ReduceOp.MAX)
            args.append(None)
        elif pg_mode == "kwargs":
            kwargs["group"] = GroupMember.WORLD
        elif pg_mode == "kwargs_none":
            kwargs["group"] = None
        else:
            assert pg_mode == "unspecified"

        inputs_compiled = torch.ones(2, device=self.device)
        inputs_eager = torch.ones(2, device=self.device)

        compiled(inputs_compiled, *args, **kwargs)
        func(inputs_eager, *args, **kwargs)

        assert counter.frame_count == 1
        # should test more precisely, but the 3 is supposed to be (all_reduce, wait, copy_)
        assert counter.op_count == 3
        assert same(inputs_compiled, inputs_eager)

    def test_dynamo_rewrite_dist_all_to_all_single(self):
        def func(output, input, pg):
            torch.distributed.all_to_all_single(output, input, group=pg)

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)

        input_compiled = torch.ones(2, device=self.device)
        input_eager = torch.ones(2, device=self.device)
        output_compiled = torch.empty(2, device=self.device)
        output_eager = torch.empty(2, device=self.device)

        compiled(output_compiled, input_compiled, GroupMember.WORLD)
        func(output_eager, input_eager, GroupMember.WORLD)

        assert counter.frame_count == 1
        assert same(output_compiled, output_eager)

    @parametrize(
        "reduce_op",
        [
            torch.distributed.ReduceOp.SUM,
            torch.distributed.ReduceOp.AVG,
            torch.distributed.ReduceOp.PRODUCT,
            torch.distributed.ReduceOp.MIN,
            torch.distributed.ReduceOp.MAX,
        ],
    )
    def test_dynamo_rewrite_dist_allreduce_reduce_op(self, reduce_op):
        from torch.distributed._functional_collectives import REDUCE_OP_TO_STR

        def verify_rewrite(gm, _):
            ar_nodes = []
            for node in gm.graph.nodes:
                if node.target in [
                    torch.ops.c10d_functional.all_reduce,
                    torch.ops._c10d_functional.all_reduce,
                ]:
                    ar_nodes.append(node)
            self.assertEqual(len(ar_nodes), 1)
            reduce_op_str = ar_nodes[0].args[1]
            self.assertEqual(REDUCE_OP_TO_STR[reduce_op], reduce_op_str)
            return gm

        compiled = torch.compile(
            torch.distributed.all_reduce,
            backend=verify_rewrite,
            fullgraph=True,
        )
        inputs = (
            torch.ones(2, device=self.device),
            reduce_op,
            GroupMember.WORLD,
        )
        compiled(*inputs)

    @parametrize(
        "source",
        [
            "GroupMember.WORLD",
            "group.WORLD",
            "_get_default_group",
        ],
    )
    def test_dynamo_get_world_group(self, source):
        def func(tensor):
            if source == "GroupMember.WORLD":
                group = torch.distributed.GroupMember.WORLD
            elif source == "group.WORLD":
                group = torch.distributed.group.WORLD
            else:
                assert source == "_get_default_group"
                group = torch.distributed.distributed_c10d._get_default_group()

            torch.distributed.all_reduce(
                tensor,
                group=group,
            )

        def verify(gm, _):
            ar_nodes = []
            for node in gm.graph.nodes:
                if node.target in [
                    torch.ops.c10d_functional.all_reduce,
                    torch.ops._c10d_functional.all_reduce,
                ]:
                    ar_nodes.append(node)
            self.assertEqual(len(ar_nodes), 1)
            return gm

        compiled = torch.compile(func, backend=verify, fullgraph=True)
        input = torch.ones(2, device=self.device)
        compiled(input)

    def test_dynamo_support_collective_op_with_async_op_False(self):
        def func(inp, out, *, pg):
            # user explicitly sets the attribute `async_op` to False,
            # so there should be no graph break
            torch.distributed.reduce_scatter_tensor(out, inp, group=pg, async_op=False)

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1
        assert counter.op_count == 3
        assert same(outputs, correct_outputs)

    def test_dynamo_graphbreaks_unsupported_async_op(self):
        def func(inp, out, *, pg):
            work = torch.distributed.reduce_scatter_tensor(
                out, inp, group=pg, async_op=True
            )
            work.wait()

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 0
        assert counter.op_count == 0
        assert same(outputs, correct_outputs)

    def test_dynamo_pg_var(self):
        def func(inp, *, pg):
            x = pg.rank() + 1 % pg.size()
            return inp + x

        local_size = [4, 4]
        inputs = torch.ones(local_size, device=self.device)
        correct_outputs = torch.empty(local_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        outputs = compiled(inputs, pg=GroupMember.WORLD)
        correct_outputs = func(inputs, pg=GroupMember.WORLD)
        assert counter.frame_count == 1
        assert counter.op_count == 1
        assert same(outputs, correct_outputs)

    def test_dynamo_trace_reduce_scatter_tensor(self):
        def func(inp):
            ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, "0")
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs)
        correct = func(inputs)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (reduce_scatter, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_allgather_coalesced(self):
        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(
                inp, tag, ranks, group_size
            )
            return ar

        inputs = [torch.ones(4, 4, device="cuda"), torch.ones(6, 6, device="cuda")]
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        assert counter.frame_count == 1
        assert counter.op_count == 3  # It generates 2 getattr to unpack the array
        assert same(out, correct)

    def test_backwards(self):
        """
        It's probably not that common to need backwards support for collectives.

        However, I wanted to at least see if it was possible to support it as a design goal.
        """

        def func(inp):
            ar = _functional_collectives.all_reduce(inp, "sum", "0")
            return ar

        input = torch.ones(4, 4, device="cuda", requires_grad=True)
        compiled = torch.compile(
            func, backend="aot_eager"
        )  # inductor bug with single-op allreduce graph
        out = compiled(input)
        out.sum().backward()

        correct_input = input.clone().detach().requires_grad_()
        correct = func(correct_input)
        correct.sum().backward()
        self.assertTrue(same(out, correct))
        self.assertTrue(same(input.grad, correct_input.grad))

    def test_meta(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs())
        self.assertEqual(x.size(), out.size())

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_all_gather_coalesced(self):
        """
        Make sure that an intermediate that's going to be reused isn't mutated unless copied.
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            tensor_list = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(
                [x, inp], tag, ranks, group_size
            )
            y = x + 2
            ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0])
            ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1])
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar0, y, other, ar1

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
1058 ( 1059 FileCheck() 1060 .check("buf0 = empty_strided") 1061 .check("buf6 = empty_strided") 1062 .check(".run(arg0_1, buf0, buf6, 16") 1063 .check( 1064 "buf1 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced.default([buf0, arg0_1]" 1065 ) 1066 .check("buf2 = buf1[0]") 1067 .check("buf3 = buf1[1]") 1068 .check("torch.ops._c10d_functional.wait_tensor.default(buf2") 1069 .check("buf7 = buf0; del buf0 # reuse") 1070 .check(".run(buf7, 16") 1071 .check("torch.ops._c10d_functional.wait_tensor.default(buf3") 1072 .check("return (buf2, buf6, buf7, buf3") 1073 .run(code) 1074 ) 1075 out = compiled(inputs, **self.get_world_trs()) 1076 correct = func(inputs, **self.get_world_trs()) 1077 assert same(out, correct), f"{out} va {correct}" 1078 1079 @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") 1080 @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False}) 1081 def test_inductor_reduce_scatter_coalesced(self): 1082 """ 1083 make sure that an intermediate that's going to be reuse isn't mutated unless copied 1084 """ 1085 1086 def func(inp, *, tag, ranks, group_size): 1087 x = inp + 1 1088 tensor_list = torch.ops.c10d_functional.reduce_scatter_tensor_coalesced( 1089 [x, inp], "sum", tag, ranks, group_size 1090 ) 1091 y = x + 2 1092 ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0]) 1093 ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1]) 1094 # ensure other is not incorrectly aliasing ar's buffer 1095 other = torch.ones_like(inp) + 22 1096 return ar0, y, other, ar1 1097 1098 inputs = torch.ones(4, 4, device="cuda") 1099 1100 compiled = torch.compile(func) 1101 code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) 1102 # NOTE: The first return value should be the output of the first wait_tensor. 1103 # We want to make sure no unneccessary copy is made. 1104 ( 1105 FileCheck() 1106 .check("buf0 = empty_strided") 1107 .check("buf6 = empty_strided") 1108 .check(".run(arg0_1, buf0, buf6, 16") 1109 .check( 1110 "buf1 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced.default([buf0, arg0_1]" 1111 ) 1112 .check("buf2 = buf1[0]") 1113 .check("buf3 = buf1[1]") 1114 .check("torch.ops._c10d_functional.wait_tensor.default(buf2") 1115 .check("buf7 = buf0; del buf0 # reuse") 1116 .check(".run(buf7, 16") 1117 .check("torch.ops._c10d_functional.wait_tensor.default(buf3") 1118 .check("return (buf2, buf6, buf7, buf3") 1119 .run(code) 1120 ) 1121 out = compiled(inputs, **self.get_world_trs()) 1122 correct = func(inputs, **self.get_world_trs()) 1123 assert same(out, correct), f"{out} va {correct}" 1124 1125 1126if __name__ == "__main__": 1127 from torch._dynamo.test_case import run_tests 1128 1129 run_tests() 1130