# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
import torch.distributed as dist
from torch import nn


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.distributed.algorithms.ddp_comm_hooks import (
    DDPCommHookType,
    register_ddp_comm_hook,
)
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if TEST_WITH_DEV_DBG_ASAN:
    print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr)
    sys.exit(0)


def gpus_for_rank(world_size):
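    """Partition the visible CUDA devices evenly across ranks; each rank gets a contiguous slice."""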
    visible_devices = list(range(torch.cuda.device_count()))
    gpus_per_process = torch.cuda.device_count() // world_size
    gpus_for_rank = []
    for rank in range(world_size):
        gpus_for_rank.append(
            visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
        )
    return gpus_for_rank


class Task(nn.Module):
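    """Minimal module with a single (40, 20) learnable parameter, seeded so every rank starts identical."""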
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.p = nn.Parameter(torch.randn(40, 20))

    def forward(self, x):
        return self.p * x


class TestDdpCommHook(nn.Module):
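    """Toy model for exercising DDP communication hooks; the per-rank exponent makes local gradients differ."""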
    def __init__(self) -> None:
        super().__init__()
        self.t0 = Task()

    def forward(self, x, rank):
        return self.t0(x ** (1 + rank))


class DistributedDataParallelCommHookTest(MultiProcessTestCase):
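    """Checks that gradients produced with each built-in DDP communication hook match, or closely track, the no-hook baseline."""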
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _get_process_group_nccl(self):
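        """Initialize the default NCCL process group via a FileStore rendezvous and return it."""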
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        return dist.distributed_c10d._get_default_group()

    @property
    def world_size(self):
        return 2

    def _local_model(self):
        local_model = TestDdpCommHook().cpu()

        return local_model

    def _get_grads(self, process_group, hook_type=None):
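        """Wrap the model in DDP on this rank's GPU, optionally register a communication hook, and return the resulting gradients."""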
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            TestDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
        )

        # Register a DDP communication hook if one was requested.
        if hook_type is not None:
            register_ddp_comm_hook(
                comm_hook_type=hook_type, model=gpu_model, state=process_group
            )

        return self._run_and_get_grads(gpu_model)

    def _run_and_get_grads(self, model):
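        """Run one forward/backward pass on a fixed random input and return the gradient of the model's single parameter."""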
        torch.manual_seed(2020)
        input = torch.randn(40, 20)
        # Run forward
        output = model(input, self.rank)

        # Run backward
        output.mean().backward()

        # The only layer
        param = next(model.parameters())
        return param.grad

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_allreduce_hook(self):
125        """
126        This unit test verifies the ``allreduce`` hook registered case gives same result
127        with no hook registered case.
128        """
        process_group = self._get_process_group_nccl()

        # No hook registered case, get the reference grads.
        reference_grads = self._get_grads(process_group, None)
        # Register hook case, get the hook grads.
        hook_grads = self._get_grads(process_group, DDPCommHookType.ALLREDUCE)

        torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_fp16compress_hook(self):
141        """
142        This unit test verifies the ``fp16 compress`` hook registered case
143        gives close result with no hook registered case.
144        """
        process_group = self._get_process_group_nccl()

        # No hook registered case, get the reference grads.
        reference_grads = self._get_grads(process_group, None)
        # Register hook case, get the hook grads.
        hook_grads = self._get_grads(process_group, DDPCommHookType.FP16_COMPRESS)

        torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_quantize_per_tensor_hook(self):
157        """
158        This unit test verifies the ``quantize per tensor`` hook registered case
159        gives close result with no hook registered case.
160        """
        process_group = self._get_process_group_nccl()

        # No hook registered case, get the reference grads.
        reference_grads = self._get_grads(process_group, None)
        # Register hook case, get the hook grads.
        hook_grads = self._get_grads(process_group, DDPCommHookType.QUANTIZE_PER_TENSOR)

        torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_quantize_per_channel_hook(self):
173        """
174        This unit test verifies the ``quantize per channel`` hook registered case
175        gives close result with no hook registered case.
176        """
        process_group = self._get_process_group_nccl()

        # No hook registered case, get the reference grads.
        reference_grads = self._get_grads(process_group, None)
        # Register hook case, get the hook grads.
        hook_grads = self._get_grads(
            process_group, DDPCommHookType.QUANTIZE_PER_CHANNEL
        )

        torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_noop_hook(self):
191        """
192        This unit test verifies the ``noop`` hook registered case and a subsequent allreduce
193        gives same result with no hook registered case.
194        """
        process_group = self._get_process_group_nccl()

        # No hook registered case, get the reference grads.
        reference_grads = self._get_grads(process_group, None)
        # Register hook case, get the hook grads.
        hook_grads = self._get_grads(process_group, DDPCommHookType.NOOP)
        # Apply a subsequent allreduce to average grads.
        hook_grads.div_(self.world_size)
        dist.all_reduce(hook_grads, group=process_group)

        torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_is_last_hook(self):
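        """
        This unit test verifies that ``bucket.is_last()`` reports True only for the
        final gradient bucket processed in a backward pass.
        """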
        process_group = self._get_process_group_nccl()

        def hook(flags, bucket):
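            # Record whether this bucket is the last one, then hand back the bucket
            # buffer unreduced via an already-completed future.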
            flags.append(bucket.is_last())
            fut = torch.futures.Future()
            fut.set_result(bucket.buffer())
            return fut

        flags = []
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
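        # Use enough large layers that DDP splits the gradients into multiple buckets.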
        model = nn.Sequential(
            nn.Linear(2, 4000, bias=False),
            *[nn.Linear(4000, 4000, bias=False) for _ in range(10)],
        )
        gpu_model = DistributedDataParallel(
            model.to(device_id),
            device_ids=[device_id],
            process_group=process_group,
        )
        gpu_model.register_comm_hook(state=flags, hook=hook)
        input = torch.randn(10, 2)
        gpu_model(input).sum().backward()
        self.assertTrue(flags[-1])
        self.assertFalse(any(flags[:-1]))


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()