#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Functions.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/util/irange.h>
#include <torch/library.h>

using namespace at;

namespace {

// This test file gives an example of a simple use case for "wrapper"-
// and "mode"-style tensor type ids. In both cases, the implementation
// of the wrapper/mode simply passes the call through to the underlying
// JIT implementation (so the wrapper/mode doesn't actually do anything),
// but this could be used as a starting point to do more interesting things.
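//
// Roughly: a "mode" is enabled thread-locally (here via
// c10::impl::IncludeDispatchKeyGuard) and intercepts every dispatcher call
// made on that thread, while a "wrapper" carries its dispatch key on an
// individual tensor, so only operations involving that tensor are
// intercepted.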

// Global counter for ease of testing
static int64_t override_call_count = 0;

// Mode implementation

void generic_mode_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  override_call_count++;
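  // Exclude our mode key while redispatching, so the inner call goes to the
  // real kernel instead of recursing back into this fallback.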
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericMode);
  op.callBoxed(stack);
}

// Wrapper implementation

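// A wrapper tensor holds an underlying tensor (rep_) and advertises only the
// TESTING_ONLY_GenericWrapper key, so every operation on it is routed to the
// boxed fallback below.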
struct GenericWrapperTensorImpl : public c10::TensorImpl {
  explicit GenericWrapperTensorImpl(at::Tensor rep)
    : TensorImpl(
        c10::DispatchKeySet(c10::DispatchKey::TESTING_ONLY_GenericWrapper),
        rep.dtype(),
        rep.device()
        // TODO: propagate size!
      )
    , rep_(std::move(rep)) {}

  at::Tensor rep_;
};

void generic_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  override_call_count++;
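  // Overall pattern: unwrap tensor arguments, redispatch to the real kernel,
  // then rewrap tensor returns so callers still see wrapper tensors.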

  auto num_arguments = op.schema().arguments().size();
  auto num_returns = op.schema().returns().size();

  // Unwrap all arguments
  auto args = torch::jit::pop(*stack, num_arguments);
  for (const auto i : c10::irange(num_arguments)) {
    // TODO: Handle tensor list
    if (args[i].isTensor()) {
      auto* impl = args[i].unsafeToTensorImpl();
      if (impl->key_set().has(DispatchKey::TESTING_ONLY_GenericWrapper)) {
        auto* wrapper = static_cast<GenericWrapperTensorImpl*>(impl);
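        // Push a copy rather than moving: the wrapper still owns rep_ and may
        // be used again after this call.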
        torch::jit::push(*stack, wrapper->rep_); // no move!
      } else {
        torch::jit::push(*stack, std::move(args[i]));
      }
    } else {
      torch::jit::push(*stack, std::move(args[i]));
    }
  }

  op.callBoxed(stack);

  // Rewrap outputs
  auto rets = torch::jit::pop(*stack, num_returns);
  for (const auto i : c10::irange(num_returns)) {
    // TODO: Handle tensor list
    if (rets[i].isTensor()) {
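      // Moving is fine here: rets[i] is about to be discarded, and the new
      // wrapper takes ownership of the returned tensor.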
      torch::jit::push(*stack, at::detail::make_tensor<GenericWrapperTensorImpl>(std::move(rets[i]).toTensor())); // yes move!
    } else {
      torch::jit::push(*stack, std::move(rets[i]));
    }
  }
}

#ifndef ATEN_CPU_STATIC_DISPATCH
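// These tests exercise the dynamic dispatcher, so they are compiled out when
// ATen is built with static CPU dispatch.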
TEST(BackendFallbackTest, TestBackendFallbackWithMode) {
  auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());

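  // Including the key in the thread-local set turns the mode on: every
  // dispatcher call on this thread now hits TESTING_ONLY_GenericMode first.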
  c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericMode);

  override_call_count = 0;
  Tensor a = ones({5, 5}, kDouble);
  Tensor b = batch_norm(a, {}, {}, {}, {}, true, 0.1, 1e-05, false);
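  // Two calls hit the fallback: ones() and batch_norm(). Anything they
  // redispatch to runs under ExcludeDispatchKeyGuard and is not counted.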
  ASSERT_EQ(override_call_count, 2);
}

TEST(BackendFallbackTest, TestBackendFallbackWithWrapper) {
  auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericWrapper);
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&generic_wrapper_fallback>());

  override_call_count = 0;
  Tensor a = at::detail::make_tensor<GenericWrapperTensorImpl>(ones({5, 5}, kDouble));
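  // Only the outer batch_norm call sees the wrapper key; the ops it
  // redispatches to receive the unwrapped tensor, hence a count of 1.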
  Tensor b = batch_norm(a, {}, {}, {}, {}, true, 0.1, 1e-05, false);
  ASSERT_EQ(override_call_count, 1);
}

TEST(BackendFallbackTest, TestFallthroughBackendFallback) {
  auto m = MAKE_TORCH_LIBRARY_IMPL(aten, TESTING_ONLY_GenericMode);
  m.impl("mul.Tensor", torch::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());

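  // A fallthrough fallback makes the mode key transparent by default; only
  // ops given an explicit impl (mul.Tensor above) are intercepted.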
  auto gm = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
  gm.fallback(torch::CppFunction::makeFallthrough());

  c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericMode);

  override_call_count = 0;
  // Doesn't trigger, as we fall through
  Tensor a = zeros({5, 5}, kDouble);
  ASSERT_EQ(override_call_count, 0);
  // Does trigger, because we explicitly set it
  Tensor b = mul(a, a);
  ASSERT_EQ(override_call_count, 1);
}
#endif // ATEN_CPU_STATIC_DISPATCH

} // namespace