xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/dispatch/backend_fallback_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/NativeFunctions.h>
5 #include <ATen/Functions.h>
6 #include <ATen/core/dispatch/Dispatcher.h>
7 #include <ATen/core/op_registration/op_registration.h>
8 #include <c10/util/irange.h>
9 #include <torch/library.h>
10 
11 using namespace at;
12 
13 namespace {
14 
// This test file gives an example of a simple use case for "wrapper"
// and "mode" style tensor type ids.  In both cases, the implementation
// of the wrapper/mode simply passes the call through to the underlying JIT
// implementation (so the wrapper/mode doesn't actually do anything),
// but this could be used as a starting point to do more interesting things.
20 
// File-wide tally of fallback invocations; the fallbacks below bump it and
// the tests reset and inspect it.
int64_t override_call_count{0};
23 
24 // Mode implementation
25 
generic_mode_fallback(const c10::OperatorHandle & op,torch::jit::Stack * stack)26 void generic_mode_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
27   override_call_count++;
28   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericMode);
29   op.callBoxed(stack);
30 }
31 
32 // Wrapper implementation
33 
34 struct GenericWrapperTensorImpl : public c10::TensorImpl {
GenericWrapperTensorImpl__anon902be9020111::GenericWrapperTensorImpl35   explicit GenericWrapperTensorImpl(at::Tensor rep)
36     : TensorImpl(
37         c10::DispatchKeySet(c10::DispatchKey::TESTING_ONLY_GenericWrapper),
38         rep.dtype(),
39         rep.device()
40         // TODO: propagate size!
41       )
42     , rep_(std::move(rep)) {}
43 
44   at::Tensor rep_;
45 };
46 
generic_wrapper_fallback(const c10::OperatorHandle & op,torch::jit::Stack * stack)47 void generic_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
48   override_call_count++;
49 
50   auto num_arguments = op.schema().arguments().size();
51   auto num_returns = op.schema().returns().size();
52 
53   // Unwrap all arguments
54   auto args = torch::jit::pop(*stack, num_arguments);
55   for (const auto i : c10::irange(num_arguments)) {
56     // TODO: Handle tensor list
57     if (args[i].isTensor()) {
58       auto* impl = args[i].unsafeToTensorImpl();
59       if (impl->key_set().has(DispatchKey::TESTING_ONLY_GenericWrapper)) {
60         auto* wrapper = static_cast<GenericWrapperTensorImpl*>(impl);
61         torch::jit::push(*stack, wrapper->rep_);  // no move!
62       } else {
63         torch::jit::push(*stack, std::move(args[i]));
64       }
65     } else {
66       torch::jit::push(*stack, std::move(args[i]));
67     }
68   }
69 
70   op.callBoxed(stack);
71 
72   // Rewrap outputs
73   auto rets = torch::jit::pop(*stack, num_returns);
74   for (const auto i : c10::irange(num_returns)) {
75     // TODO: Handle tensor list
76     if (rets[i].isTensor()) {
77       torch::jit::push(*stack, at::detail::make_tensor<GenericWrapperTensorImpl>(std::move(rets[i]).toTensor()));  // yes move!
78     } else {
79       torch::jit::push(*stack, std::move(rets[i]));
80     }
81   }
82 }
83 
84 #ifndef ATEN_CPU_STATIC_DISPATCH
// Registers generic_mode_fallback as the catch-all fallback for the
// TESTING_ONLY_GenericMode key, turns that key on, and checks how many times
// the fallback ran while creating a tensor and calling batch_norm on it.
TEST(BackendFallbackTest, TestBackendFallbackWithMode) {
  auto mode_lib = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
  mode_lib.fallback(torch::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());

  // Include the mode key for the rest of this test.
  c10::impl::IncludeDispatchKeyGuard mode_enabled(DispatchKey::TESTING_ONLY_GenericMode);

  override_call_count = 0;
  Tensor input = ones({5, 5}, kDouble);
  Tensor normalized = batch_norm(input, {}, {}, {}, {}, true, 0.1, 1e-05, false);
  // Presumably one hit for ones() and one for batch_norm().
  ASSERT_EQ(override_call_count, 2);
}
96 
// Registers generic_wrapper_fallback as the catch-all fallback for the
// TESTING_ONLY_GenericWrapper key and checks that calling batch_norm on a
// wrapped tensor runs the fallback exactly once.
TEST(BackendFallbackTest, TestBackendFallbackWithWrapper) {
  auto wrapper_lib = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericWrapper);
  wrapper_lib.fallback(torch::CppFunction::makeFromBoxedFunction<&generic_wrapper_fallback>());

  override_call_count = 0;
  // The inner ones() tensor is plain; only the outer tensor carries the wrapper key.
  Tensor wrapped = at::detail::make_tensor<GenericWrapperTensorImpl>(ones({5, 5}, kDouble));
  Tensor normalized = batch_norm(wrapped, {}, {}, {}, {}, true, 0.1, 1e-05, false);
  ASSERT_EQ(override_call_count, 1);
}
106 
// Combines a per-operator kernel with a fallthrough fallback on the same key:
// aten::mul.Tensor gets generic_mode_fallback explicitly, while every other
// operator falls through the TESTING_ONLY_GenericMode key.
TEST(BackendFallbackTest, TestFallthroughBackendFallback) {
  // Explicit kernel for mul.Tensor only.
  auto aten_lib = MAKE_TORCH_LIBRARY_IMPL(aten, TESTING_ONLY_GenericMode);
  aten_lib.impl("mul.Tensor", torch::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());

  // Everything else registered under this key is a fallthrough.
  auto fallthrough_lib = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
  fallthrough_lib.fallback(torch::CppFunction::makeFallthrough());

  c10::impl::IncludeDispatchKeyGuard mode_enabled(DispatchKey::TESTING_ONLY_GenericMode);

  override_call_count = 0;
  // zeros() has no explicit kernel, so it falls through and is not counted.
  Tensor operand = zeros({5, 5}, kDouble);
  ASSERT_EQ(override_call_count, 0);
  // mul() hits the explicitly registered kernel and is counted.
  Tensor product = mul(operand, operand);
  ASSERT_EQ(override_call_count, 1);
}
124 #endif // ATEN_CPU_STATIC_DISPATCH
125 
126 }
127