# Owner(s): ["oncall: distributed"]

import os
import weakref

import test_c10d_common

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._C._distributed_c10d import _create_work_from_future
from torch.futures import Future
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import run_tests


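# Wrap an already-completed result in a c10d Work object via
# _create_work_from_future, so a collective can return immediately.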
def create_work(result):
    future = Future()
    future.set_result(result)
    return _create_work_from_future(future)


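# A minimal dist._Work subclass whose result is available immediately; it
# holds a weak reference to the process group so the tests can count how
# often DDP calls wait() and get_future().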
class MyWork(dist._Work):
    def __init__(self, result, pg):
        super().__init__()
        self.result_ = result
        self.future_ = torch.futures.Future()
        self.future_.set_result(result)
        self.pg_ = weakref.ref(pg)

    def wait(self, timeout):
        self.pg_().wait_count += 1
        return True

    def get_future(self):
        self.pg_().get_future_count += 1
        return self.future_


class LonelyRankProcessGroup(dist.ProcessGroup):
    """
    This PG only supports a world size of 1. Depending on use_wrapper, its
    collectives return either MyWork instances or Work objects created from
    already-completed Futures.
    """

    def __init__(self, rank, world, use_wrapper):
        super().__init__(rank, world)
        assert rank == 0
        assert world == 1

        self._rank = rank
        self._world = world
        self.wait_count = 0
        self.get_future_count = 0
        self.use_wrapper = use_wrapper
        self._work = []

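    # Each collective below either returns a Work built from a completed
    # Future (use_wrapper=True) or a MyWork instance that is kept alive so
    # the test can inspect its wait/get_future counters.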
    def broadcast(self, tensor_list, opts):
        if self.use_wrapper:
            return create_work(tensor_list)
        res = MyWork(tensor_list, self)
        self._work.append(res)
        return res

    def allgather(self, output_tensors, input_tensor, opts):
        for o, i in zip(output_tensors[0], input_tensor):
            o.copy_(i)
        if self.use_wrapper:
            return create_work(output_tensors)
        res = MyWork(output_tensors, self)
        self._work.append(res)
        return res

    def allreduce(self, tensors, opts):
        if self.use_wrapper:
            return create_work(tensors)
        res = MyWork(tensors, self)
        self._work.append(res)
        return res

    def size(self):
        return self._world

    def getBackendName(self):
        return "lonely-pg"

    def __repr__(self):
        return f"PLG w:{self._world} r:{self._rank}"


# We cannot use parametrize because some tests are defined on the base class
# and rely on _get_process_group.
class AbstractDDPSingleRank(test_c10d_common.CommonDistributedDataParallelTest):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self):
        return 1

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _get_process_group(self):
        return LonelyRankProcessGroup(self.rank, self.world_size, self.use_wrapper)

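    # Run a tiny model under DDP on the single-rank PG and check that its
    # gradients match the plain (non-DDP) model; when MyWork is used, also
    # verify that DDP actually invoked wait()/get_future() on the Work.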
    def test_ddp_invoke_work_object(self):
        pg = self._get_process_group()

        torch.manual_seed(123)
        model = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
        # Keep a handle to the underlying module; `model` is rebound to the
        # DDP wrapper below.
        wrapped_model = model
        input_tensor = torch.rand(2)
        model = DDP(model, process_group=pg)
        model(input_tensor).sum().backward()

        ddp_grad = wrapped_model[0].bias.grad.clone()

        wrapped_model.zero_grad()
        wrapped_model(input_tensor).sum().backward()
        self.assertEqual(wrapped_model[0].bias.grad, ddp_grad)
        if not self.use_wrapper:
            self.assertTrue(pg.wait_count > 0)
            self.assertTrue(pg.get_future_count > 0)

    def test_ddp_with_pypg(self):
        pg = self._get_process_group()

        self._test_ddp_with_process_group(pg, [torch.device("cpu")], device_ids=None)

    def test_ddp_with_pypg_with_grad_views(self):
        pg = self._get_process_group()

        self._test_ddp_with_process_group(
            pg, [torch.device("cpu")], device_ids=None, gradient_as_bucket_view=True
        )


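# The two concrete test classes differ only in which Work flavor the PG
# returns: a MyWork subclass instance or a Work created from a Future.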
class TestDDPWithWorkSubclass(AbstractDDPSingleRank, MultiProcessTestCase):
    @property
    def use_wrapper(self):
        return False


class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiProcessTestCase):
    @property
    def use_wrapper(self):
        return True


if __name__ == "__main__":
    run_tests()