#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/variable.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#endif

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>

#include <c10/util/irange.h>

namespace torch::autograd {

// NB: This code duplicates existing logic in torch/autograd/__init__.py and
// torch._C._EngineBase.run_backward in torch/csrc/autograd/python_engine.cpp.
// It is a purely C++ API for autograd, with no dependency on Python, so it can
// be exposed in the PyTorch C++ API and in TorchScript. If either this file or
// the Python implementation changes, the two must be kept logically in sync.
// TODO: Make the Python API above just call this C++ API.
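
// Normalizes user-supplied grad_outputs into the gradient list that is fed to
// the engine. When grad_outputs is empty, an implicit gradient of ones is
// created for every output that requires grad (which must then be a real
// scalar). Otherwise the two lists must have the same length; undefined
// entries get the same implicit ones, and defined entries are checked for a
// real/complex dtype match with their output before being passed through.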
static variable_list _make_grads(
    const variable_list& outputs,
    const variable_list& grad_outputs) {
  size_t num_tensors = outputs.size();
  size_t num_gradients = grad_outputs.size();
  variable_list new_grads;
  new_grads.reserve(num_tensors);
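  // If no explicit grad_outputs were given, fall back to an implicit gradient
  // of ones for each (real, scalar) output that requires grad; this is the
  // common `loss.backward()` case.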
  if (grad_outputs.empty()) {
    for (const Variable& output : outputs) {
      if (output.requires_grad()) {
        TORCH_CHECK(
            output.numel() == 1,
            "grad can be implicitly created only for scalar outputs");
        TORCH_CHECK(
            c10::isFloatingType(output.scalar_type()),
            "grad can be computed only for real scalar outputs but got ",
            output.scalar_type());
        new_grads.emplace_back(
            at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
      }
    }
  } else {
    TORCH_CHECK(
        num_tensors == num_gradients,
        "got ",
        num_tensors,
        " tensors and ",
        num_gradients,
        " gradients");
    for (const auto i : c10::irange(outputs.size())) {
      const Variable& output = outputs[i];
      const Variable& grad_output = grad_outputs[i];
      if (!grad_output.defined()) {
        if (output.requires_grad()) {
          TORCH_CHECK(
              output.numel() == 1,
              "grad can be implicitly created only for scalar outputs");
          TORCH_CHECK(
              c10::isFloatingType(output.scalar_type()),
              "grad can be computed only for real scalar outputs but got ",
              output.scalar_type());
          new_grads.emplace_back(
              at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
        }
      } else {
        TORCH_CHECK(
            grad_output.is_complex() == output.is_complex(),
            "For complex Tensors, both grad_output and output are required ",
            "to have the same dtype. Mismatch in dtype: grad_output[",
            grad_output,
            "] has a dtype of ",
            grad_output.scalar_type(),
            " and output[",
            output,
            "] has a dtype of ",
            output.scalar_type(),
            ".");
        // grad_output is defined, so just append it to new_grads
        new_grads.emplace_back(grad_output);
      }
    }
  }
  return new_grads;
}
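
// Shared implementation behind backward() and grad(). Builds the root edges
// from the outputs' gradient edges, optionally builds the output edges at
// which the engine should capture gradients for the requested inputs, and then
// runs the autograd engine. With accumulate_grad=true, gradients are
// accumulated into the leaves' .grad fields (backward() semantics); with
// accumulate_grad=false, the gradients for `inputs` are returned instead
// (grad() semantics).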
static variable_list run_backward(
    const variable_list& outputs,
    const variable_list& grad_outputs,
    bool keep_graph,
    bool create_graph,
    const variable_list& inputs,
    bool allow_unused,
    bool accumulate_grad) {
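  // Collect the gradient edge of every output tensor; these are the roots from
  // which the engine starts the backward traversal. An output that neither
  // requires grad nor has a grad_fn has no gradient edge and is rejected.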
  size_t num_tensors = outputs.size();
  edge_list roots;
  roots.reserve(num_tensors);
  for (const auto i : c10::irange(num_tensors)) {
    const Variable& output = outputs[i];
    auto gradient_edge = impl::gradient_edge(output);
    TORCH_CHECK(
        gradient_edge.function,
        "element ",
        i,
        " of tensors does not require grad and does not have a grad_fn");
    roots.push_back(std::move(gradient_edge));
  }

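  // If specific inputs were requested (the grad() path, or backward() with an
  // explicit inputs list), build the edges at which the engine should capture
  // gradients. Leaf tensors fall back to their grad accumulator; an input that
  // is unreachable from the roots gets a placeholder Identity node.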
  edge_list output_edges;
  if (!inputs.empty()) {
    size_t num_inputs = inputs.size();
    output_edges.reserve(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      const Variable& input = inputs[i];
      const auto output_nr = input.output_nr();
      auto grad_fn = input.grad_fn();
      if (!grad_fn) {
        grad_fn = impl::try_get_grad_accumulator(input);
      }
      if (accumulate_grad) {
        input.retain_grad();
      }
      TORCH_CHECK(
          input.requires_grad(),
          "element ",
          i,
          " of the input tensors does not require grad");
      if (!grad_fn) {
        // See NOTE [ Autograd Unreachable Input ] for details
        output_edges.emplace_back(std::make_shared<Identity>(), 0);
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }

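  // Dispatch to the engine: traverse the graph from `roots`, seeded with
  // `grad_outputs`, and capture gradients at `output_edges` when inputs were
  // requested.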
  variable_list grad_inputs = Engine::get_default_engine().execute(
      roots,
      grad_outputs,
      keep_graph,
      create_graph,
      accumulate_grad,
      output_edges);
  // Check whether grad_inputs contains undefined tensors, based on the
  // allow_unused flag.
  if (!inputs.empty() && !allow_unused) {
    size_t num_inputs = inputs.size();
    for (const auto i : c10::irange(num_inputs)) {
      TORCH_CHECK(
          grad_inputs[i].defined(),
          "element ",
          i,
          " of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
    }
  }
  return grad_inputs;
}

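// Runs backward() semantics: gradients of `tensors` are accumulated into the
// .grad fields of the graph leaves (and of any explicitly listed `inputs`).
// If retain_graph is not given, it defaults to create_graph. Unused inputs are
// always tolerated on this path (allow_unused is hard-coded to true).
//
// Usage sketch (illustrative only; variable names are made up):
//
//   auto x = torch::randn({2, 2}, torch::requires_grad());
//   auto loss = (x * x).sum();               // scalar output
//   torch::autograd::backward({loss});       // implicit grad of ones
//   // x.grad() now holds d(loss)/dx == 2 * x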
void backward(
    const variable_list& tensors,
    const variable_list& grad_tensors,
    std::optional<bool> retain_graph,
    bool create_graph,
    const variable_list& inputs) {
  variable_list gradients = _make_grads(tensors, grad_tensors);
  if (!retain_graph) {
    retain_graph = create_graph;
  }
  run_backward(
      tensors,
      gradients,
      retain_graph.value(),
      create_graph,
      inputs,
      /*allow_unused=*/true,
      /*accumulate_grad=*/true);
}

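// Runs grad() semantics: returns the gradients of `outputs` with respect to
// `inputs` without touching any .grad field. The result is ordered like
// `inputs`; entries for inputs that were not used in the graph are undefined
// tensors and require allow_unused=true.
//
// Usage sketch (illustrative only; variable names are made up):
//
//   auto x = torch::randn({3}, torch::requires_grad());
//   auto y = (x * x).sum();
//   auto grads = torch::autograd::grad({y}, {x});
//   // grads[0] holds d(y)/dx == 2 * x; x.grad() is left untouched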
variable_list grad(
    const variable_list& outputs,
    const variable_list& inputs,
    const variable_list& grad_outputs,
    std::optional<bool> retain_graph,
    bool create_graph,
    bool allow_unused) {
  variable_list gradients = _make_grads(outputs, grad_outputs);
  if (!retain_graph) {
    retain_graph = create_graph;
  }
  return run_backward(
      outputs,
      gradients,
      retain_graph.value(),
      create_graph,
      inputs,
      allow_unused,
      /*accumulate_grad=*/false);
}

namespace forward_ad {

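// Forward-mode AD dual levels. enter_dual_level() allocates the next level
// index and exit_dual_level() releases it, along with the tangents registered
// at that level. These are the C++ counterparts of the Python-side
// torch.autograd.forward_ad.enter_dual_level/exit_dual_level helpers.
//
// Usage sketch (illustrative only):
//
//   auto level = torch::autograd::forward_ad::enter_dual_level();
//   // ... build dual tensors and run forward-mode computations ...
//   torch::autograd::forward_ad::exit_dual_level(level);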
uint64_t enter_dual_level() {
  return ForwardADLevel::get_next_idx();
}

void exit_dual_level(uint64_t level) {
  ForwardADLevel::release_idx(level);
}

} // namespace forward_ad

} // namespace torch::autograd