1 #include <c10/util/irange.h>
2 #include <torch/csrc/autograd/cpp_hook.h>
3 #include <torch/csrc/autograd/custom_function.h>
4 #include <torch/csrc/autograd/variable.h>
5
6 #include <utility>
7
8 namespace {
9 using torch::autograd::Variable;
check_single_result(const at::TensorBase & value,const at::TensorBase & result,const std::string & hook_name)10 void check_single_result(
11 const at::TensorBase& value,
12 const at::TensorBase& result,
13 const std::string& hook_name) {
14 if (!value.defined()) {
15 throw std::runtime_error(
16 "can't replace a empty gradient with a non-empty value");
17 }
18 torch::autograd::check_variable_result(value, result, hook_name);
19 }
20 } // namespace
21
22 namespace torch::autograd {
23
CppFunctionTensorPreHook(std::shared_ptr<hooks_list> hooks,size_t value_idx)24 CppFunctionTensorPreHook::CppFunctionTensorPreHook(
25 std::shared_ptr<hooks_list> hooks,
26 size_t value_idx)
27 : hooks_(std::move(hooks)), value_idx_(value_idx) {}
28
operator ()(const variable_list & values)29 variable_list CppFunctionTensorPreHook::operator()(
30 const variable_list& values) {
31 auto value = values[value_idx_];
32 for (const auto i : c10::irange(hooks_->size())) {
33 auto& hook = (*hooks_)[i];
34 if (!hook) {
35 // hook was removed
36 continue;
37 }
38 auto res = hook(value);
39 if (!res.defined()) {
40 // Don't change gradient
41 continue;
42 }
43 check_single_result(value, res, std::to_string(i));
44 value = std::move(res);
45 }
46 variable_list results(values);
47 results[value_idx_] = value;
48 return results;
49 }
50
CppFunctionSingleTensorPreHook(std::function<at::TensorBase (const at::TensorBase &)> hook,size_t value_idx)51 CppFunctionSingleTensorPreHook::CppFunctionSingleTensorPreHook(
52 std::function<at::TensorBase(const at::TensorBase&)> hook,
53 size_t value_idx)
54 : hook_(std::move(hook)), value_idx_(value_idx) {}
55
operator ()(const variable_list & values)56 variable_list CppFunctionSingleTensorPreHook::operator()(
57 const variable_list& values) {
58 const auto& value = values[value_idx_];
59 auto res = hook_(value);
60 TORCH_INTERNAL_ASSERT(
61 !res.defined(),
62 "CppFunctionSingleTensorPreHook currently only supports hooks that don't return");
63 variable_list results(values);
64 return results;
65 }
66
67 } // namespace torch::autograd
68