xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/cpp_hook.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/csrc/autograd/cpp_hook.h>
3 #include <torch/csrc/autograd/custom_function.h>
4 #include <torch/csrc/autograd/variable.h>
5 
6 #include <utility>
7 
8 namespace {
9 using torch::autograd::Variable;
check_single_result(const at::TensorBase & value,const at::TensorBase & result,const std::string & hook_name)10 void check_single_result(
11     const at::TensorBase& value,
12     const at::TensorBase& result,
13     const std::string& hook_name) {
14   if (!value.defined()) {
15     throw std::runtime_error(
16         "can't replace a empty gradient with a non-empty value");
17   }
18   torch::autograd::check_variable_result(value, result, hook_name);
19 }
20 } // namespace
21 
22 namespace torch::autograd {
23 
CppFunctionTensorPreHook(std::shared_ptr<hooks_list> hooks,size_t value_idx)24 CppFunctionTensorPreHook::CppFunctionTensorPreHook(
25     std::shared_ptr<hooks_list> hooks,
26     size_t value_idx)
27     : hooks_(std::move(hooks)), value_idx_(value_idx) {}
28 
// Runs every registered hook, in order, on values[value_idx_] and returns a
// copy of `values` whose value_idx_ entry holds the final value.
//
// Empty (removed) hooks are skipped. A hook that returns an undefined result
// leaves the current value as it is; a defined result is validated with
// check_single_result (using the hook's index as its name) and then becomes
// the value handed to the next hook.
variable_list CppFunctionTensorPreHook::operator()(
    const variable_list& values) {
  auto current = values[value_idx_];
  // The hook count is read once, before any hook runs.
  const auto hook_count = hooks_->size();
  for (size_t idx = 0; idx < hook_count; ++idx) {
    auto& hook = (*hooks_)[idx];
    if (hook) { // an empty slot means the hook was removed
      auto replacement = hook(current);
      if (replacement.defined()) { // undefined: don't change gradient
        check_single_result(current, replacement, std::to_string(idx));
        current = std::move(replacement);
      }
    }
  }
  variable_list updated = values;
  updated[value_idx_] = current;
  return updated;
}
50 
CppFunctionSingleTensorPreHook(std::function<at::TensorBase (const at::TensorBase &)> hook,size_t value_idx)51 CppFunctionSingleTensorPreHook::CppFunctionSingleTensorPreHook(
52     std::function<at::TensorBase(const at::TensorBase&)> hook,
53     size_t value_idx)
54     : hook_(std::move(hook)), value_idx_(value_idx) {}
55 
// Invokes the stored hook on values[value_idx_] and returns an unmodified copy
// of `values`.
//
// The hook must not return a defined tensor; this is enforced with
// TORCH_INTERNAL_ASSERT.
variable_list CppFunctionSingleTensorPreHook::operator()(
    const variable_list& values) {
  auto hook_output = hook_(values[value_idx_]);
  TORCH_INTERNAL_ASSERT(
      !hook_output.defined(),
      "CppFunctionSingleTensorPreHook currently only supports hooks that don't return");
  // The input list is handed back as a copy with no entry replaced.
  return values;
}
66 
67 } // namespace torch::autograd
68