# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
import torch.distributed as dist
from torch import nn


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.distributed.algorithms.ddp_comm_hooks import (
    DDPCommHookType,
    register_ddp_comm_hook,
)
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if TEST_WITH_DEV_DBG_ASAN:
    print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr)
    sys.exit(0)


def gpus_for_rank(world_size):
    visible_devices = list(range(torch.cuda.device_count()))
    gpus_per_process = torch.cuda.device_count() // world_size
    gpus_for_rank = []
    for rank in range(world_size):
        gpus_for_rank.append(
            visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
        )
    return gpus_for_rank


class Task(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.p = nn.Parameter(torch.randn(40, 20))

    def forward(self, x):
        return self.p * x


class TestDdpCommHook(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.t0 = Task()

    def forward(self, x, rank):
        return self.t0(x ** (1 + rank))


class DistributedDataParallelCommHookTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _get_process_group_nccl(self):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        return dist.distributed_c10d._get_default_group()

    @property
    def world_size(self):
        return 2

    def _local_model(self):
        local_model = TestDdpCommHook().cpu()

        return local_model

    def _get_grads(self, process_group, hook_type=None):
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            TestDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
        )

        # Register a DDP communication hook if one is specified.
        if hook_type is not None:
            register_ddp_comm_hook(
                comm_hook_type=hook_type, model=gpu_model, state=process_group
            )

        return self._run_and_get_grads(gpu_model)

    def _run_and_get_grads(self, model):
        torch.manual_seed(2020)
        input = torch.randn(40, 20)
        # Run forward
        output = model(input, self.rank)

        # Run backward
        output.mean().backward()

        # The only layer
        param = next(model.parameters())
        return param.grad

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_allreduce_hook(self):
        """
        This unit test verifies that the ``allreduce`` hook registered case gives
        the same result as the no hook registered case.
        """
        process_group = self._get_process_group_nccl()

        # No hook registered case, get the reference grads.
        reference_grads = self._get_grads(process_group, None)
        # Register hook case, get the hook grads.
        hook_grads = self._get_grads(process_group, DDPCommHookType.ALLREDUCE)

        torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_fp16compress_hook(self):
        """
        This unit test verifies that the ``fp16 compress`` hook registered case
        gives a result close to the no hook registered case.
        """
        process_group = self._get_process_group_nccl()

        # No hook registered case, get the reference grads.
        reference_grads = self._get_grads(process_group, None)
        # Register hook case, get the hook grads.
        hook_grads = self._get_grads(process_group, DDPCommHookType.FP16_COMPRESS)

        torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_quantize_per_tensor_hook(self):
        """
        This unit test verifies that the ``quantize per tensor`` hook registered case
        gives a result close to the no hook registered case.
        """
        process_group = self._get_process_group_nccl()

        # No hook registered case, get the reference grads.
        reference_grads = self._get_grads(process_group, None)
        # Register hook case, get the hook grads.
        hook_grads = self._get_grads(process_group, DDPCommHookType.QUANTIZE_PER_TENSOR)

        torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_quantize_per_channel_hook(self):
        """
        This unit test verifies that the ``quantize per channel`` hook registered case
        gives a result close to the no hook registered case.
        """
        process_group = self._get_process_group_nccl()

        # No hook registered case, get the reference grads.
        reference_grads = self._get_grads(process_group, None)
        # Register hook case, get the hook grads.
        hook_grads = self._get_grads(
            process_group, DDPCommHookType.QUANTIZE_PER_CHANNEL
        )

        torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_noop_hook(self):
        """
        This unit test verifies that the ``noop`` hook registered case, followed by
        a subsequent allreduce, gives the same result as the no hook registered case.
        """
        process_group = self._get_process_group_nccl()

        # No hook registered case, get the reference grads.
        reference_grads = self._get_grads(process_group, None)
        # Register hook case, get the hook grads.
        hook_grads = self._get_grads(process_group, DDPCommHookType.NOOP)
        # Apply a subsequent allreduce to average grads.
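        # The noop hook performs no gradient communication, so each rank still holds
        # only its local gradients here; dividing by the world size and all-reducing
        # reproduces DDP's default gradient averaging for comparison.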
        hook_grads.div_(self.world_size)
        dist.all_reduce(hook_grads, group=process_group)

        torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_is_last_hook(self):
        process_group = self._get_process_group_nccl()

        def hook(flags, bucket):
            # Record whether this bucket is the last one of the backward pass,
            # then complete the future with the bucket contents unchanged.
            flags.append(bucket.is_last())
            fut = torch.futures.Future()
            fut.set_result(bucket.buffer())
            return fut

        flags = []
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        model = nn.Sequential(
            nn.Linear(2, 4000, bias=False),
            *[nn.Linear(4000, 4000, bias=False) for _ in range(10)],
        )
        gpu_model = DistributedDataParallel(
            model.to(device_id),
            device_ids=[device_id],
            process_group=process_group,
        )
        gpu_model.register_comm_hook(state=flags, hook=hook)
        input = torch.randn(10, 2)
        gpu_model(input).sum().backward()
        # Only the final bucket of the backward pass should report is_last() == True.
        self.assertTrue(flags[-1])
        self.assertFalse(any(flags[:-1]))


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()