# Owner(s): ["oncall: distributed"]

import os
import weakref

import test_c10d_common

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._C._distributed_c10d import _create_work_from_future
from torch.futures import Future
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import run_tests


def create_work(result):
    # Wrap an already-completed Future in a c10d Work object.
    future = Future()
    future.set_result(result)
    return _create_work_from_future(future)


class MyWork(dist._Work):
    """A Python-side Work subclass that counts wait() and get_future() calls."""

    def __init__(self, result, pg):
        super().__init__()
        self.result_ = result
        self.future_ = torch.futures.Future()
        self.future_.set_result(result)
        # Hold the process group weakly to avoid a reference cycle.
        self.pg_ = weakref.ref(pg)

    def wait(self, timeout):
        self.pg_().wait_count += 1
        return True

    def get_future(self):
        self.pg_().get_future_count += 1
        return self.future_


class LonelyRankProcessGroup(dist.ProcessGroup):
    """
    This PG only supports a world_size of 1.
    """

    def __init__(self, rank, world, use_wrapper):
        super().__init__(rank, world)
        assert rank == 0
        assert world == 1

        self._rank = rank
        self._world = world
        self.wait_count = 0
        self.get_future_count = 0
        self.use_wrapper = use_wrapper
        self._work = []

    def broadcast(self, tensor_list, opts):
        if self.use_wrapper:
            return create_work(tensor_list)
        res = MyWork(tensor_list, self)
        self._work.append(res)
        return res

    def allgather(self, output_tensors, input_tensor, opts):
        # With a single rank, allgather is just a copy of the input.
        for o, i in zip(output_tensors[0], input_tensor):
            o.copy_(i)
        if self.use_wrapper:
            return create_work(output_tensors)

        res = MyWork(output_tensors, self)
        self._work.append(res)

        return res

    def allreduce(self, tensors, opts):
        # With a single rank, allreduce is a no-op on the data.
        if self.use_wrapper:
            return create_work(tensors)
        res = MyWork(tensors, self)
        self._work.append(res)
        return res

    def size(self):
        return self._world

    def getBackendName(self):
        return "lonely-pg"

    def __repr__(self):
        return f"PLG w:{self._world} r:{self._rank}"


# We cannot use parametrize because some tests are defined on the base class
# and use _get_process_group.
class AbstractDDPSingleRank(test_c10d_common.CommonDistributedDataParallelTest):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self):
        return 1

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _get_process_group(self):
        return LonelyRankProcessGroup(self.rank, self.world_size, self.use_wrapper)

    def test_ddp_invoke_work_object(self):
        pg = self._get_process_group()

        torch.manual_seed(123)
        model = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
        wrapped_model = model
        input_tensor = torch.rand(2)
        model = DDP(model, process_group=pg)
        model(input_tensor).sum().backward()

        ddp_grad = wrapped_model[0].bias.grad.clone()

        wrapped_model.zero_grad()
        wrapped_model(input_tensor).sum().backward()
        # With world_size == 1, DDP must produce the same gradients as the
        # plain model, and it must have exercised our Work objects.
        self.assertEqual(wrapped_model[0].bias.grad, ddp_grad)
        if not self.use_wrapper:
            self.assertTrue(pg.wait_count > 0)
            self.assertTrue(pg.get_future_count > 0)

    def test_ddp_with_pypg(self):
        pg = self._get_process_group()

        self._test_ddp_with_process_group(pg, [torch.device("cpu")], device_ids=None)

    def test_ddp_with_pypg_with_grad_views(self):
        pg = self._get_process_group()

        self._test_ddp_with_process_group(
            pg, [torch.device("cpu")], device_ids=None, gradient_as_bucket_view=True
        )


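# The two concrete classes below run every inherited test twice: once with
# MyWork (a Python dist._Work subclass) and once with the Work produced by
# _create_work_from_future, selected via the use_wrapper property.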
class TestDDPWithWorkSubclass(AbstractDDPSingleRank, MultiProcessTestCase):
    @property
    def use_wrapper(self):
        return False


class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiProcessTestCase):
    @property
    def use_wrapper(self):
        return True


if __name__ == "__main__":
    run_tests()