// torch/csrc/autograd/VariableTypeManual.cpp
#include <ATen/RedispatchFunctions.h>
#include <ATen/TracerMode.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/autograd/generated/ViewFuncs.h>
#include <torch/library.h>
#include <optional>

#include <utility>

using namespace at;
using namespace torch::autograd::generated;
using torch::autograd::as_view;
using torch::autograd::CreationMeta;

namespace torch {

namespace autograd::VariableType {

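// Helper that enumerates one DeprecatedTypeProperties entry per
// (Backend, ScalarType) pair for the given backends.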
static std::vector<at::DeprecatedTypeProperties*> allTypesForBackends(
    at::ArrayRef<at::Backend> backends) {
  std::vector<DeprecatedTypeProperties*> res;
  res.reserve(backends.size());
  for (auto p : backends) {
    for (const auto s :
         c10::irange(static_cast<int64_t>(ScalarType::NumOptions))) {
      auto& type = getDeprecatedTypeProperties(
          static_cast<Backend>(p), static_cast<ScalarType>(s));
      res.emplace_back(&type);
    }
  }
  return res;
}

std::vector<at::DeprecatedTypeProperties*> allCPUTypes() {
  return allTypesForBackends({Backend::CPU, Backend::SparseCPU});
}

std::vector<at::DeprecatedTypeProperties*> allCUDATypes() {
  at::globalContext().lazyInitCUDA();
  return allTypesForBackends({Backend::CUDA, Backend::SparseCUDA});
}

std::vector<at::DeprecatedTypeProperties*> allXPUTypes() {
  return allTypesForBackends({Backend::XPU, Backend::SparseXPU});
}

std::vector<at::DeprecatedTypeProperties*> allPrivateUser1Types() {
  at::globalContext().lazyInitPrivateUse1();
  return allTypesForBackends(
      {Backend::PrivateUse1, Backend::SparsePrivateUse1});
}

namespace {
const Variable& checked_cast_variable(
    const Tensor& t,
    const char* name,
    int pos) {
  if (!t.defined()) {
    AT_ERROR(
        "Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
        "for argument #",
        pos,
        " '",
        name,
        "'");
  }
  return t;
}

Variable& checked_cast_variable(Tensor& t, const char* name, int pos) {
  if (!t.defined()) {
    AT_ERROR(
        "Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
        "for argument #",
        pos,
        " '",
        name,
        "'");
  }
  return t;
}
} // namespace

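// unpack() asserts that the argument is a defined Tensor and returns it as a
// Variable reference; unpack_opt() lets an undefined Tensor pass through.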
const Tensor& unpack(const Tensor& t, const char* name, int pos) {
  return checked_cast_variable(t, name, pos);
}

Tensor& unpack(Tensor& t, const char* name, int pos) {
  return checked_cast_variable(t, name, pos);
}

Tensor unpack_opt(const Tensor& t, const char* name, int pos) {
  if (!t.defined()) {
    return Tensor();
  }
  return unpack(t, name, pos);
}

std::vector<at::Tensor> unpack(
    const at::ITensorListRef& tl,
    const char* name,
    int pos) {
  std::vector<at::Tensor> ret;
  ret.reserve(tl.size());
  for (const auto& t : tl) {
    ret.push_back(t);
  }
  return ret;
}

namespace {

// Taken from codegened version
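// _fw_primal returns a Tensor that shares data with `self` but is detached
// from forward-mode AD at `level` (only level 0 is supported here). On the
// Python side this roughly backs
// torch.autograd.forward_ad.unpack_dual(dual).primal.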
Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor& self, int64_t level) {
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<Identity> grad_fn;
  if (compute_requires_grad(self)) {
    grad_fn = std::make_shared<Identity>();
    grad_fn->set_next_edges(collect_next_edges(self));
  }

  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::_fw_primal(
        ks & c10::after_autograd_keyset, self_, level);
  })();

  if (grad_fn) {
    set_history(flatten_tensor_args(result), grad_fn);
  }
  if (isFwGradDefined(self)) {
    // Modified from original codegen
    // We explicitly want to ignore the forward grad at the given level
    TORCH_CHECK(level == 0, "Invalid level given to _fw_primal");
    // End modified from original codegen
  }
  return result;
}

// NB: We need a manual variable type kernel so that set_fw_grad properly
// detects that _make_dual is not a forward-differentiable view
//
// This function can be used to create a dual Tensor that holds a tangent to
// compute forward mode gradients. Note that the dual Tensor's primal is a view
// of the given primal and the given tangent is used as-is. This function is
// backward differentiable.
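// Illustrative Python usage (a hedged sketch; fwAD.make_dual reaches this
// kernel via torch._VF._make_dual):
//
//   import torch.autograd.forward_ad as fwAD
//   with fwAD.dual_level():
//       dual = fwAD.make_dual(primal, tangent)
//       primal_out, tangent_out = fwAD.unpack_dual(dual)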
Tensor _make_dual(
    c10::DispatchKeySet ks,
    const Tensor& primal,
    const Tensor& tangent,
    int64_t level) {
  TORCH_CHECK(
      !primal._fw_grad(level).defined(),
      "Making a dual Tensor based on a Tensor that "
      "already has a forward gradient at the same level ",
      level,
      " is not supported.");
  auto& primal_ = unpack(primal, "primal", 0);
  auto& tangent_ = unpack(tangent, "tangent", 1);
  std::shared_ptr<ViewBackward0> grad_fn;
  if (compute_requires_grad(primal_)) {
    grad_fn = std::make_shared<ViewBackward0>();
    grad_fn->self_sym_sizes = primal_.sym_sizes().vec();
    grad_fn->set_next_edges(collect_next_edges(primal_));
  }

  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::_make_dual(
        ks & c10::after_autograd_keyset, primal_, tangent_, level);
  })();

  if (grad_fn) {
    set_history(flatten_tensor_args(result), grad_fn);
  }

  TORCH_CHECK(level == 0, "Invalid level given to _make_dual");
  result._set_fw_grad(tangent_, level, /* is_inplace_op */ false);
  return result;
}

// We don't have an outplace copy, so this can't be generated automatically
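// Forward-grad semantics implemented below: after self.copy_(src), self's
// tangent is overwritten with src's tangent (zeroed when src has none, and
// broadcast to self's shape when the sizes differ).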
Tensor& copy_(
    c10::DispatchKeySet ks,
    Tensor& self,
    const Tensor& src,
    bool non_blocking) {
  // TODO: once copy is exposed in Declarations.yaml we may be able to bind
  // it automatically
  auto& self_ = unpack(self, "self", 0);
  auto& src_ = unpack(src, "src", 1);
  std::shared_ptr<CopyBackwards> grad_fn;
  auto requires_grad = compute_requires_grad(self, src);
  requires_grad &= isDifferentiableType(self.scalar_type());
  check_inplace(self, requires_grad);
  if (requires_grad) {
    grad_fn = std::make_shared<CopyBackwards>();
    grad_fn->set_next_edges(collect_next_edges(self, src));
    grad_fn->src_options = src.options();
  }
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::copy_(
        ks & c10::after_autograd_keyset, self_, src_, non_blocking);
  }
  rebase_history(self, std::move(grad_fn));

  if (isDifferentiableType(self.scalar_type()) &&
      (isFwGradDefined(self) || isFwGradDefined(src))) {
    auto self_fw_grad = generated::details::toNonOptFwGrad(self);
    auto src_fw_grad = generated::details::toNonOptFwGrad(src);
    Tensor new_fw_grad;
    if (self_fw_grad.defined()) {
      if (src_fw_grad.defined()) {
        new_fw_grad = self_fw_grad.copy_(src_fw_grad);
      } else {
        new_fw_grad = self_fw_grad.fill_(0);
      }
    } else {
      if (!self.is_same_size(src_fw_grad)) {
        new_fw_grad = src_fw_grad.broadcast_to(self.sizes());
      } else {
        new_fw_grad = src_fw_grad.clone();
      }
    }
    self._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ true);
  }

  return self;
}

const Tensor& resize_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    SymIntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  auto& self_ = unpack(self, "self", 0);
  if (self.requires_grad()) {
    AT_ERROR("cannot resize variables that require grad");
  }
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::resize__symint(
        ks & c10::after_autograd_keyset, self_, size, optional_memory_format);
  }

  if (self._fw_grad(/* level */ 0).defined()) {
    AT_ERROR("cannot resize variables that have a forward grad");
  }

  return self;
}

const Tensor& resize_as_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    const Tensor& the_template,
    std::optional<MemoryFormat> optional_memory_format) {
  auto& self_ = unpack(self, "self", 0);
  auto& the_template_ = unpack(the_template, "the_template", 1);
  if (self.requires_grad()) {
    AT_ERROR("cannot resize variables that require grad");
  }
  {
    at::AutoDispatchBelowAutograd mode;
    at::redispatch::resize_as_(
        ks & c10::after_autograd_keyset,
        self_,
        the_template_,
        optional_memory_format);
  }

  // Handle fw grad
  if (self._fw_grad(/* level */ 0).defined()) {
    AT_ERROR("cannot resize variables that have a forward grad");
  }

  return self;
}

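// detach() returns a shallow copy that shares storage with `self` but is cut
// out of the autograd graph. Illustrative Python:
//   y = x.detach()  # y.requires_grad == False
//   y.add_(1)       # mutates the same storage that x sees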
Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
  auto& self_ = unpack(self, "self", 0);
  RECORD_FUNCTION("detach", std::vector<c10::IValue>({self}));
  auto result = ([&]() {
    at::AutoDispatchBelowAutograd guard;
    return at::redispatch::detach(ks & c10::after_autograd_keyset, self_);
  })();
  namedinference::propagate_names(result, self);

  // Detach the forward grads by not setting anything on the result

  return result;
}

Tensor& detach_(c10::DispatchKeySet ks, Tensor& self) {
  RECORD_FUNCTION("detach_", std::vector<c10::IValue>({self}));
  if (self.is_view()) {
    // See NOTE [ View + Inplace detection ]
    AT_ERROR(
        "Can't detach views in-place. Use detach() instead. "
        "If you are using DistributedDataParallel (DDP) for training, "
        "and gradient_as_bucket_view is set to True, gradients are "
        "views of DDP buckets, and hence detach_() cannot be called "
        "on these gradients. To fix this error, please refer to the "
        "Optimizer.zero_grad() function in torch/optim/optimizer.py "
        "as the solution.");
  }
  // I think the choice here is conservative.  In principle, doing
  // an in-place detach should give us the ability to just clear
  // the autograd meta.  But this function ONLY resets requires_grad,
  // grad_fn and output_nr; there's other metadata like debug name
  // and hooks which aren't cleared.  Is this function supposed to
  // clear those too? I'm not too sure, so I'm leaving it be for now.
  auto autograd_meta = impl::materialize_autograd_meta(self);
  autograd_meta->set_requires_grad(false, self.unsafeGetTensorImpl());
  autograd_meta->grad_fn_.reset();
  autograd_meta->output_nr_ = 0;
  autograd_meta->fw_grad_.reset();

  return self;
}

// Ops in the following registration list are registered as
//   (1) CompositeImplicitAutograd kernels
//   (2) Autograd kernels
//   (3) CompositeExplicitAutograd kernels and additionally Autograd kernels
// The reason for (3) is that ops that also use dispatch (e.g. register
// CPU/CUDA/QuantizedCPU kernels) will skip picking up CompositeImplicitAutograd
// kernels for Autograd, so we register them to both CompositeExplicitAutograd
// and Autograd instead. See
// https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword
// for more details.
// Invariants:
// - Ops registered to CompositeImplicitAutograd or CompositeExplicitAutograd
//   below must match `MANUAL_BACKEND` set in tools/autograd/gen_variable_type.py,
//   and they must have manual_kernel_registration=True in native_functions.yaml.
// - Ops registered to DispatchKey::Autograd below must be included in
//   `MANUAL_AUTOGRAD` in tools/autograd/gen_variable_type.py.

TORCH_LIBRARY_IMPL(aten, Autograd, m) {
  m.impl(
      "resize_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::resize_)));
  m.impl(
      "resize_as_",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::resize_as_)));
  m.impl(
      "detach",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach)));
  m.impl(
      "detach_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach_)));
  m.impl(
      "copy_",
      torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::copy_)));
  m.impl(
      "_fw_primal",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::_fw_primal)));
  m.impl(
      "_make_dual",
      torch::dispatch(
          DispatchKey::Autograd, TORCH_FN(VariableType::_make_dual)));
}

} // namespace
} // namespace autograd::VariableType

namespace ADInplaceOrView {
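// CREATION_META_DEFINITION picks the CreationMeta stamped on views created by
// these kernels: INFERENCE_MODE when inference mode is active, otherwise
// DEFAULT under grad mode and NO_GRAD_MODE under no-grad mode.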
#define CREATION_META_DEFINITION                            \
  InferenceMode::is_enabled()                               \
      ? CreationMeta::INFERENCE_MODE                        \
      : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT \
                                    : CreationMeta::NO_GRAD_MODE)

static Tensor& copy_(
    c10::DispatchKeySet ks,
    Tensor& self,
    const Tensor& src,
    bool non_blocking) {
  {
    at::AutoDispatchBelowADInplaceOrView guard;
    at::redispatch::copy_(
        ks & c10::after_ADInplaceOrView_keyset, self, src, non_blocking);
  }
  torch::autograd::increment_version(self);
  return self;
}

static const Tensor& resize_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    SymIntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  // Hold sizes to verify if we actually resize `self`.
  // Explicitly copy data, since resizing can move original data
  // and make references invalid.
  auto org_size = self.sym_sizes().vec();
  {
    at::AutoDispatchBelowADInplaceOrView guard;
    at::redispatch::resize__symint(
        ks & c10::after_ADInplaceOrView_keyset,
        self,
        size,
        optional_memory_format);
  }
  // If `self` was resized, increment the version.
  if (org_size != size) {
    torch::autograd::increment_version(self);
  }
  return self;
}

static const Tensor& resize_as_(
    c10::DispatchKeySet ks,
    const Tensor& self,
    const Tensor& the_template,
    std::optional<MemoryFormat> optional_memory_format) {
  // Hold sizes to verify if we actually resize `self`.
  // Explicitly copy data, since resizing can move original data
  // and make references invalid.
  auto org_size = self.sym_sizes().vec();
  {
    at::AutoDispatchBelowADInplaceOrView guard;
    at::redispatch::resize_as_(
        ks & c10::after_ADInplaceOrView_keyset,
        self,
        the_template,
        optional_memory_format);
  }

  // If `self` was resized, increment the version.
  if (org_size != the_template.sym_sizes()) {
    torch::autograd::increment_version(self);
  }
  return self;
}

static Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
  auto out = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::_ops::detach::redispatch(
        ks & c10::after_ADInplaceOrView_keyset, self);
  })();
  // NB: we can't make detach() a normal view operator because the codegen
  // generates allow_tensor_metadata_change = True for them. In the future we
  // should have an option for this in the codegen.
  auto result = as_view(
      /* base */ self,
      /* output */ out,
      /* is_bw_differentiable */ false,
      /* is_fw_differentiable */ false,
      /* view_func */ nullptr,
      /* rev_view_func */ nullptr,
      /* creation_meta */ CreationMeta::DEFAULT,
      /*allow_tensor_metadata_change=*/false);

  return result;
}

static Tensor _fw_primal(
    c10::DispatchKeySet ks,
    const Tensor& self,
    int64_t level) {
  auto tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::alias(self);
  })();
  std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
  std::function<at::Tensor(const at::Tensor&)> rev_func = nullptr;
  if (!self.unsafeGetTensorImpl()->support_as_strided()) {
    func = std::make_unique<ViewViewFunc>(self.sym_sizes());
    rev_func = [=](const at::Tensor& input_view) {
      TORCH_INTERNAL_ASSERT(
          false,
          "Reverse view_func for _fw_primal() is not currently supported");
      return Tensor();
    };
  }
  auto result = as_view(
      /* base */ self,
      /* output */ tmp,
      /* is_bw_differentiable */ true,
      /* is_fw_differentiable */ false,
      /* view_func */ std::move(func),
      /* rev_view_func */ std::move(rev_func),
      /* creation_meta */ CREATION_META_DEFINITION);

  return result;
}

// NB: This does not redispatch any further
static Tensor _make_dual(
    c10::DispatchKeySet ks,
    const Tensor& primal,
    const Tensor& tangent,
    int64_t level) {
  auto tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::alias(primal);
  })();
  std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
  std::function<at::Tensor(const at::Tensor&)> rev_func = nullptr;
  if (!primal.unsafeGetTensorImpl()->support_as_strided()) {
    func = std::make_unique<ViewViewFunc>(primal.sym_sizes());
    rev_func = [=](const at::Tensor& input_view) {
      TORCH_INTERNAL_ASSERT(
          false,
          "Reverse view_func for _make_dual() is not currently supported");
      return Tensor();
    };
  }
  auto result = as_view(
      /* base */ primal,
      /* output */ tmp,
      /* is_bw_differentiable */ true,
      /* is_fw_differentiable */ false,
      /* view_func */ std::move(func),
      /* rev_view_func */ std::move(rev_func),
      /* creation_meta */ CREATION_META_DEFINITION);

  return result;
}

namespace {
TORCH_LIBRARY_IMPL(aten, ADInplaceOrView, m) {
  m.impl(
      "copy_",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::copy_)));
  m.impl(
      "detach",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::detach)));
  m.impl(
      "resize_",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::resize_)));
  m.impl(
      "resize_as_",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::resize_as_)));
  m.impl(
      "_fw_primal",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_fw_primal)));
  m.impl(
      "_make_dual",
      torch::dispatch(
          DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_make_dual)));
}
} // namespace
} // namespace ADInplaceOrView
} // namespace torch