# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy

from model_registry import MLPModule

import torch
from torch.distributed.pipelining._backward import (
    stage_backward,
    stage_backward_input,
    stage_backward_weight,
)
from torch.testing._internal.common_utils import run_tests, TestCase


d_hid = 512
batch_size = 256


class StageBackwardTests(TestCase):
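    """
    Tests for the split backward helpers in
    torch.distributed.pipelining._backward, each compared against a plain
    loss.backward() on a deep-copied reference module.
    """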
    def test_stage_backward(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        x = torch.randn(batch_size, d_hid)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_x = x.detach().requires_grad_(x.requires_grad)
        ref_target = target.detach()

        # Forward and backward in stage manner
        out = mod(x)
        loss = loss_fn(out, target)
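        # stage_backward mirrors what loss.backward() would do for this stage:
        # it returns gradients for the given input_values and accumulates the
        # module's parameter .grad in place. output_grads is None because the
        # loss is the final (scalar) stage output, so no upstream gradient is
        # chained in.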
        grad_inputs = stage_backward(
            stage_output=loss,
            output_grads=None,
            input_values=(x,),
        )

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        torch.testing.assert_close(grad_inputs[0], ref_x.grad)

        # Check that parameter gradients match the reference
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    def test_stage_backward_input(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        x = torch.randn(batch_size, d_hid)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_x = x.detach().requires_grad_(x.requires_grad)
        ref_target = target.detach()

        # Forward, then backward of loss with respect to inputs
        out = mod(x)
        loss = loss_fn(out, target)
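        # stage_backward_input computes only the gradients with respect to the
        # stage inputs (dinputs); parameter .grad stays unset (checked below).
        # param_groups carries the bookkeeping needed to finish the weight
        # gradients later via stage_backward_weight.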
        dinputs, param_groups = stage_backward_input(
            stage_outputs=(loss,),
            output_grads=None,
            input_values=[x],
            weights=mod.parameters(),
        )

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        torch.testing.assert_close(x.grad, ref_x.grad)
        torch.testing.assert_close(dinputs[0], ref_x.grad)
        for name, p in mod.named_parameters():
            # Check that the weight gradients were not updated
            self.assertEqual(p.grad, None)

    def test_stage_backward_weight(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        x = torch.randn(batch_size, d_hid)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_x = x.detach().requires_grad_(x.requires_grad)
        ref_target = target.detach()

        # Forward, then backward of loss with respect to inputs
        out = mod(x)
        loss = loss_fn(out, target)
        dinputs, param_groups = stage_backward_input(
            stage_outputs=(loss,),
            output_grads=None,
            input_values=[x],
            weights=mod.parameters(),
        )

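        # stage_backward_weight completes the deferred parameter gradients
        # using the param_groups recorded by stage_backward_input above; the
        # resulting .grad values should match the full-graph reference below.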
        # Backward of loss with respect to weights
        dweights = stage_backward_weight(mod.parameters(), param_groups)

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        # Check that parameter gradients match the reference
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    def test_stage_backward_weight_multiple_iters(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        inputs = []
        for _ in range(10):
            x = torch.randn(batch_size, d_hid)
            inputs.append(x)
            # As in a pipeline stage, the inputs to this stage require gradients
            x.requires_grad_(True)

        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_inputs = []
        for x in inputs:
            ref_inputs.append(x.detach().requires_grad_(x.requires_grad))
        ref_target = target.detach()

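        # Run the split backward (input grads, then weight grads) once per
        # input; parameter gradients accumulate across iterations, just as
        # repeated loss.backward() calls accumulate in the reference loop.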
        # Forward, then backward of loss with respect to inputs
        for x in inputs:
            out = mod(x)
            loss = loss_fn(out, target)
            dinputs, param_groups = stage_backward_input(
                stage_outputs=(loss,),
                output_grads=None,
                input_values=[x],
                weights=mod.parameters(),
            )

            # Backward of loss with respect to weights
            stage_backward_weight(mod.parameters(), param_groups)

        # Run reference
        for ref_x in ref_inputs:
            ref_out = ref_mod(ref_x)
            ref_loss = loss_fn(ref_out, ref_target)
            ref_loss.backward()

        # Check that accumulated parameter gradients match the reference
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise


if __name__ == "__main__":
    run_tests()