// aten/src/ATen/FunctionalizeFallbackKernel.cpp
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/InferSize.h>
#include <ATen/TensorUtils.h>
#include <torch/library.h>
#include <c10/util/irange.h>
#include <c10/util/strides.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_copy.h>
#include <ATen/ops/to_native.h>
#include <ATen/ops/lift.h>
#include <ATen/ops/lift_fresh.h>
#include <ATen/ops/lift_fresh_copy.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/as_strided_copy.h>
#include <ATen/ops/empty_strided_native.h>
#include <ATen/ops/_unsafe_view.h>

#include <utility>
#endif

namespace {
  void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) {
    const auto& schema = op.schema();
    // NB: auto_functionalize handles the case where outputs do not have alias info.
    // This error message therefore suggests that users modify their custom op to the
    // point where auto_functionalize works, instead of asking them to try the raw
    // functionalization API (because that is a bit difficult to use).
    // If you're here and want to try the raw functionalization kernel approach,
    // see https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa
    TORCH_CHECK(
      !schema.hasAnyAliasInfo(),
      "Found a custom (non-ATen) operator whose output has alias annotations: ",
      op.schema(),
      ". We only support functionalizing operators whose outputs do not have alias ",
      "annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation, whereas ",
      "'Tensor' is a Tensor without one; the '(a)' is the alias annotation). ",
      "The alias annotation specifies that the output ",
      "Tensor shares storage with an input that has the same annotation. ",
      "Please check: ",
      "(1) does the output need to be an output? If not, don't return it. ",
      "(2) If the output doesn't share storage with any inputs, then ",
      "delete the alias annotation. ",
      "(3) If the output does share storage with an input, then add a ",
      ".clone() before returning it to prevent storage sharing, and then ",
      "delete the alias annotation. ",
      "Otherwise, please file an issue on GitHub.");
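    // Illustrative (hypothetical) schemas, not from the original file:
    //   "my_ns::bad_op(Tensor(a) x) -> Tensor(a)"  // output aliases input x: rejected by the check above
    //   "my_ns::good_op(Tensor x) -> Tensor"       // no alias annotations: handled by this fallback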
    const auto num_arguments = schema.arguments().size();
    const auto arguments_begin = stack->size() - num_arguments;
    auto arguments = torch::jit::last(stack, num_arguments);

    auto any_functional_inputs = false;
    auto any_tensor_inputs = false;
    for (uint64_t idx = 0; idx < num_arguments; ++idx) {
      const auto& ivalue = arguments[idx];
      if (ivalue.isTensor()) {
        any_tensor_inputs = true;
        const auto& t = ivalue.toTensor();
        if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) {
          any_functional_inputs = true;
          at::functionalization::impl::sync(t);
          auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
          (*stack)[arguments_begin + idx] = t_new;
        }
      } else if (ivalue.isTensorList()) {
        any_tensor_inputs = true;
        auto tensors = ivalue.toTensorList();
        if (at::functionalization::impl::isFunctionalTensor(tensors)) {
          any_functional_inputs = true;
          at::functionalization::impl::sync(tensors);
          auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
          (*stack)[arguments_begin + idx] = t_new;
        }
      } else if (ivalue.isOptionalTensorList()) {
        any_tensor_inputs = true;
        auto opt_tensors = ivalue.toOptionalTensorList();
        if (at::functionalization::impl::isFunctionalTensor(opt_tensors)) {
          any_functional_inputs = true;
          at::functionalization::impl::sync(opt_tensors);
          auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
          (*stack)[arguments_begin + idx] = t_new;
        }
      }
    }
    // We should wrap the outputs if any inputs were wrapped,
    // OR if we're hitting a factory function (with no tensor inputs).
    auto should_wrap_outputs = !any_tensor_inputs || any_functional_inputs;
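    // E.g. (illustrative): a factory call like at::zeros has no tensor inputs, so
    // any_tensor_inputs stays false and the outputs still get wrapped; if the op had
    // tensor inputs and none were functional, the outputs are left unwrapped.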
    {
      at::AutoDispatchSkipFunctionalize guard;
      op.callBoxed(stack);
    }
    const auto num_returns = schema.returns().size();
    const auto returns_begin = stack->size() - num_returns;
    auto returns = torch::jit::last(stack, num_returns);

    for (const auto idx : c10::irange(num_returns)) {
      const auto& ivalue = returns[idx];
      if (ivalue.isTensor() && should_wrap_outputs) {
        const auto& t = ivalue.toTensor();
        if (!t.defined()) continue;
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
        (*stack)[returns_begin + idx] = t_new;
      } else if (ivalue.isTensorList() && should_wrap_outputs) {
        auto tensors = ivalue.toTensorList();
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
        (*stack)[returns_begin + idx] = t_new;
      } else if (ivalue.isOptionalTensorList() && should_wrap_outputs) {
        auto opt_tensors = ivalue.toOptionalTensorList();
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
        (*stack)[returns_begin + idx] = t_new;
      }
    }
  }
}

// resize_() is special because:
// - when we resize to a larger size, it acts as a mutation
// - when we resize to a smaller size, it acts as a view
// See Note [resize_ in Functionalization] for more details.
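// Illustrative example (not from the original comment): if x is a contiguous tensor of
// shape [4], x.resize_({8}) must allocate new storage (a mutation), while x.resize_({2})
// is equivalent to a view of the first two elements of the original storage.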
static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet [[maybe_unused]], const at::Tensor & self, at::IntArrayRef size, std::optional<at::MemoryFormat> memory_format) {
  // First unwrap the tensor arguments
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    at::functionalization::impl::sync(self);
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }
  // Case 1: arguments are not functional tensors, so we no-op and redispatch.
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
     at::AutoDispatchSkipFunctionalize guard;
     self_.resize_(size, memory_format);
     return self;
  }

  // Case 2: actually functionalize resize_()
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::resize(self_, size, memory_format);
  }

  auto itemsize = self.dtype().itemsize();
  auto storage_offset = self.storage_offset();
  auto new_size_bytes = at::detail::computeStorageNbytesContiguous(size, itemsize, storage_offset);
  auto needs_resize_storage = new_size_bytes > self.storage().nbytes();
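  // Illustrative arithmetic (assumed values): for a float tensor (itemsize 4) with a
  // storage_offset of 0 being resized to {2, 3}, new_size_bytes = (2 * 3 + 0) * 4 = 24;
  // the storage only needs to grow if that exceeds the current self.storage().nbytes().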

  if (needs_resize_storage) {
    // If resize_() actually increases the size of the storage, then we need to tell FunctionalTensorWrapper about it.
    // See Note [resize_ in Functionalization]
    auto func_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
    func_impl->maybe_replace_storage(tmp_output);
    // See the note: we're guaranteed at this point that "self" is *not* a view (and has no outstanding views),
    // so we don't need to treat the output of resize as a view tensor.
    return self;
  }

  // Otherwise, we know that we're resizing to a smaller size.
  // resize_() is effectively a view operator.
  // The output of resizing is equivalent to taking a slice of a larger tensor.
  // We have to emulate this "slicing" with an as_strided call.
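  // Illustrative example: shrinking a contiguous [4, 4] tensor to size {2, 2} is modeled
  // as base.as_strided({2, 2}, {2, 1}) (the strides come from c10::contiguous_strides),
  // i.e. a "slice" over the beginning of the original storage.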
  auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      if (reapply_views) {
        return base.as_strided(size, c10::contiguous_strides(size));
      } else {
        return at::as_strided_copy(base, size, c10::contiguous_strides(size));
      }
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
    },
    /*has_symbolic_inputs=*/false
  );
  at::functionalization::impl::mutate_view_meta(self, view_meta);
  return self;
}

static at::Tensor lift_functionalize(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static at::Tensor lift_fresh_functionalize(const at::Tensor & self) {
  // See Note [Exporting and compiling a graph with lift_fresh_copy]
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    return self.view_as(self);
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static at::Tensor lift_fresh_functionalize_copy(const at::Tensor & self) {
  // Note [Exporting and compiling a graph with lift_fresh_copy]
  // If self is already a functional tensor, don't wrap it twice.
  // In theory this could be useful if we want to nest functionalization with itself,
  // but that isn't really a use case today.
  // Needed for https://github.com/pytorch/pytorch/issues/105327
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    // Note [Composite Functionalization under PreDispatch mode]
    // When we are tracing under PreDispatch, the PreDispatch key will be
    // in the local include TLS. As a result, when we redispatch here,
    // we will end up hitting the PreDispatch stack first. So, we should
    // redispatch to the Functionalize key manually.
    static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::clone", "").typed<at::Tensor(const at::Tensor &, std::optional<at::MemoryFormat>)>();
    return op.redispatch(c10::DispatchKeySet({c10::DispatchKey::Functionalize}), self, std::nullopt);
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh_copy(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static bool device_opted_into_functionalization(c10::Device self_device, std::optional<c10::Device> tgt_device) {
    // If no target device was provided, the output tensor should be on the same device as the input.
    auto real_tgt_device = tgt_device.has_value() ? tgt_device.value() : self_device;
    return real_tgt_device.type() == c10::DeviceType::XLA || real_tgt_device.type() == c10::DeviceType::Lazy;
}
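// E.g. (illustrative): with a CPU input and no target device this returns false, while a
// target device of DeviceType::XLA or DeviceType::Lazy returns true; only those backends
// are treated as opted into functionalization here.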

// Note: this is only needed because the to.dtype/to.dtype_layout overloads call this,
// so we skip the op above. We should probably get rid of this, though.
static at::Tensor _to_copy_functionalize(
        const at::Tensor & self,
        std::optional<at::ScalarType> dtype,
        std::optional<at::Layout> layout,
        std::optional<at::Device> device,
        std::optional<bool> pin_memory,
        bool non_blocking,
        std::optional<at::MemoryFormat> memory_format) {
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    // sync any pending updates
    at::functionalization::impl::sync(self);
    // pass the unwrapped tensor to the backend
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::_to_copy(self_, dtype, layout, device, pin_memory, non_blocking, memory_format);

  // Special case: if the Functionalize key is not in TLS, we assume that we're running
  // on a lazy backend (LTC).
  // In that case, if we're copying to a device that hasn't opted into functionalization,
  // then the functionalization pass should "end": we need to sync any updates on the input
  // tensor, but we shouldn't wrap the output.
  if (!c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {
    if (!device_opted_into_functionalization(self.device(), device)) {
      return out;
    }
  }
  return at::functionalization::impl::to_functional_tensor(out);
}

// Why is _unsafe_view special-cased here?
// Basically just to satisfy autograd's debug asserts.
// The situation:
// - _unsafe_view's autograd kernel has debug asserts to confirm
//   that the input and output alias storage.
// - _unsafe_view's schema in native_functions.yaml
//   does not contain alias annotations, so it advertises as non-aliasing.
// - functionalization will then treat _unsafe_view like a non-aliasing op.
//   Specifically, autograd will redispatch to functionalization's
//   boxed fallback kernel, which creates a new FunctionalTensorWrapper output
//   that does **not** alias storage with the input, tripping the assert.
// The kernel written here just manually reifies the aliasing relationship.
//
// Another way to handle this would be to fix _unsafe_view's alias annotations
// in native_functions.yaml, but I think this would be a pessimization.
// The idea with _unsafe_view is that you're guaranteed that the input
// is a temporary, and you don't actually have to worry about propagating
// mutations between the input and output.
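// Illustrative (not from the original comment): under the boxed fallback,
//   auto y = at::_unsafe_view(x, {-1});
// would return a FunctionalTensorWrapper whose storage does not alias x's, tripping
// autograd's debug assert; the kernel below instead records the aliasing relationship
// explicitly via a ViewMeta, so the output is treated as a view of x.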
static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymIntArrayRef size) {
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    return at::_unsafe_view_symint(self, size);
  }

  auto self_ = at::functionalization::impl::from_functional_tensor(self);
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::_unsafe_view_symint(self_, size);
  }

  bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });

  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return at::_unsafe_view_symint(base, size);
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
    },
    /*has_symbolic_inputs=*/has_symbolic_inputs
  );

  auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
  // See Note [Propagating strides in the functionalization pass]
  // (For _unsafe_view, we just apply the shape-inference rule manually here instead of
  // calling the meta function for _unsafe_view.)
  auto inferred_size = at::infer_size_dv(size, self.sym_numel());
  auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
  TORCH_INTERNAL_ASSERT(stride.has_value());
  out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
  return out;
}

static at::Tensor& set__functionalize(at::Tensor& self, const at::Tensor& src) {
  // error case
  TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(self) || !at::functionalization::impl::isFunctionalTensor(src),
    "set__functionalize: Tried to mutate a non-functional tensor with a functional tensor, which is not allowed");

  // no-op case
  if (!at::functionalization::impl::isFunctionalTensor(self) && !at::functionalization::impl::isFunctionalTensor(src)) {
    at::AutoDispatchSkipFunctionalize guard;
    return self.set_(src);
  }

  TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(src),
    "set__functionalize: We do not currently support x.set_(y) where y is not a FunctionalTensor. Please file an issue");

  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(src));
  auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
  auto src_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(src);
  // See Note [Ordering of resize_() and set_()]
  TORCH_CHECK(!self_impl->was_inductor_storage_resized(),
    "storage_resize_() followed by set_() in torch.compile is not supported today");
  self_impl->set__impl(src_impl);
  return self;
}

TORCH_LIBRARY_IMPL(_, Functionalize, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}

TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
  m.impl("resize_", TORCH_FN(resize__functionalization));
  m.impl("lift", TORCH_FN(lift_functionalize));
  m.impl("lift_fresh", TORCH_FN(lift_fresh_functionalize));
  m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy));
  m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
  m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize));
  // The overloads of set_() that take in a storage should never
  // appear with torch.compile, because dynamo graph-breaks on them.
  m.impl("set_.source_Tensor", TORCH_FN(set__functionalize));
}