# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy

from model_registry import MLPModule

import torch
from torch.distributed.pipelining._backward import (
    stage_backward,
    stage_backward_input,
    stage_backward_weight,
)
from torch.testing._internal.common_utils import run_tests, TestCase


d_hid = 512
batch_size = 256


class StageBackwardTests(TestCase):
    def test_stage_backward(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        x = torch.randn(batch_size, d_hid)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_x = x.detach().requires_grad_(x.requires_grad)
        ref_target = target.detach()

        # Forward and backward in stage manner
        out = mod(x)
        loss = loss_fn(out, target)
        grad_inputs = stage_backward(
            stage_output=loss,
            output_grads=None,
            input_values=(x,),
        )

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        torch.testing.assert_close(grad_inputs[0], ref_x.grad)

        # Every rank checks gradients
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    def test_stage_backward_input(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        x = torch.randn(batch_size, d_hid)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_x = x.detach().requires_grad_(x.requires_grad)
        ref_target = target.detach()

        # Forward, then backward of loss with respect to inputs
        out = mod(x)
        loss = loss_fn(out, target)
        dinputs, param_groups = stage_backward_input(
            stage_outputs=(loss,),
            output_grads=None,
            input_values=[x],
            weights=mod.parameters(),
        )

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        torch.testing.assert_close(x.grad, ref_x.grad)
        torch.testing.assert_close(dinputs[0], ref_x.grad)
        for name, p in mod.named_parameters():
            # Check that the weight gradients were not updated
            self.assertEqual(p.grad, None)

    def test_stage_backward_weight(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        x = torch.randn(batch_size, d_hid)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_x = x.detach().requires_grad_(x.requires_grad)
        ref_target = target.detach()

        # Forward, then backward of loss with respect to inputs
        out = mod(x)
        loss = loss_fn(out, target)
        dinputs, param_groups = stage_backward_input(
            stage_outputs=(loss,),
            output_grads=None,
            input_values=[x],
            weights=mod.parameters(),
        )

        # Backward of loss with respect to weights
        dweights = stage_backward_weight(mod.parameters(), param_groups)

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        # Every rank checks gradients
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    def test_stage_backward_weight_multiple_iters(self):
        # MLP as a stage module
        mod = MLPModule(d_hid)
        inputs = []
        for _ in range(10):
            x = torch.randn(batch_size, d_hid)
            inputs.append(x)
            # As in a pipeline stage, the inputs to this stage require gradients
            x.requires_grad_(True)

        target = torch.randn(batch_size, d_hid)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod)
        ref_inputs = []
        for x in inputs:
            ref_inputs.append(x.detach().requires_grad_(x.requires_grad))
        ref_target = target.detach()

        # Forward, then backward of loss with respect to inputs
        for x in inputs:
            out = mod(x)
            loss = loss_fn(out, target)
            dinputs, param_groups = stage_backward_input(
                stage_outputs=(loss,),
                output_grads=None,
                input_values=[x],
                weights=mod.parameters(),
            )

            # Backward of loss with respect to weights
            stage_backward_weight(mod.parameters(), param_groups)

        # Run reference
        for ref_x in ref_inputs:
            ref_out = ref_mod(ref_x)
            ref_loss = loss_fn(ref_out, ref_target)
            ref_loss.backward()

        # Every rank checks gradients
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise


if __name__ == "__main__":
    run_tests()