#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/InferSize.h>
#include <ATen/TensorUtils.h>
#include <torch/library.h>
#include <c10/util/irange.h>
#include <c10/util/strides.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_copy.h>
#include <ATen/ops/to_native.h>
#include <ATen/ops/lift.h>
#include <ATen/ops/lift_fresh.h>
#include <ATen/ops/lift_fresh_copy.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/as_strided_copy.h>
#include <ATen/ops/empty_strided_native.h>
#include <ATen/ops/_unsafe_view.h>
#endif

// <utility> is needed unconditionally (std::move and std::nullopt are used below).
#include <utility>

namespace {
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  // NB: auto_functionalize handles the case where outputs do not have alias info.
  // This error message therefore suggests that users modify their custom op until
  // auto_functionalize works, instead of pointing them at the raw
  // functionalization API (which is a bit difficult to use).
  // If you're here and want to try the raw functionalization kernel approach,
  // see https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa
  TORCH_CHECK(
    !schema.hasAnyAliasInfo(),
    "Found a custom (non-ATen) operator whose output has alias annotations: ",
    op.schema(),
    ". We only support functionalizing operators whose outputs do not have alias ",
    "annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas ",
    "'Tensor' is a Tensor without one; the '(a)' is the alias annotation). ",
    "The alias annotation specifies that the output ",
    "Tensor shares storage with an input that has the same annotation. ",
    "Please check whether ",
    "(1) the output needs to be an output (if not, don't return it), ",
    "(2) the output shares storage with any input (if not, ",
    "delete the alias annotation), or ",
    "(3) the output indeed shares storage with an input (if so, add a ",
    ".clone() before returning it to prevent storage sharing, and then ",
    "delete the alias annotation). ",
    "Otherwise, please file an issue on GitHub.");
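  // Illustrative schemas (hypothetical op names, not from this file):
  //   "my_ns::my_op(Tensor x) -> Tensor"         // OK: no alias annotations
  //   "my_ns::my_op(Tensor(a) x) -> Tensor(a)"   // rejected: output aliases input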
  const auto num_arguments = schema.arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);

  auto any_functional_inputs = false;
  auto any_tensor_inputs = false;
  for (uint64_t idx = 0; idx < num_arguments; ++idx) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      any_tensor_inputs = true;
      const auto& t = ivalue.toTensor();
      if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(t);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isTensorList()) {
      any_tensor_inputs = true;
      auto tensors = ivalue.toTensorList();
      if (at::functionalization::impl::isFunctionalTensor(tensors)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(tensors);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isOptionalTensorList()) {
      any_tensor_inputs = true;
      auto opt_tensors = ivalue.toOptionalTensorList();
      if (at::functionalization::impl::isFunctionalTensor(opt_tensors)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(opt_tensors);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
        (*stack)[arguments_begin + idx] = t_new;
      }
    }
  }
  // we should wrap the output if any inputs were wrapped,
  // OR if we're hitting a factory function (with no tensor inputs)
  auto should_wrap_outputs = !any_tensor_inputs || any_functional_inputs;
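  // e.g. a factory call like at::empty({2, 2}) has no tensor inputs, so
  // should_wrap_outputs is true and its outputs get wrapped below, letting
  // later mutations of the result be functionalized.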
  {
    at::AutoDispatchSkipFunctionalize guard;
    op.callBoxed(stack);
  }
  const auto num_returns = schema.returns().size();
  const auto returns_begin = stack->size() - num_returns;
  auto returns = torch::jit::last(stack, num_returns);

  for (const auto idx : c10::irange(num_returns)) {
    const auto& ivalue = returns[idx];
    if (ivalue.isTensor() && should_wrap_outputs) {
      const auto& t = ivalue.toTensor();
      if (!t.defined()) continue;
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isTensorList() && should_wrap_outputs) {
      auto tensors = ivalue.toTensorList();
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList() && should_wrap_outputs) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
      (*stack)[returns_begin + idx] = t_new;
    }
  }
}
} // namespace

// resize_() is special because:
// - when we resize to a larger size, it acts as a mutation
// - when we resize to a smaller size, it acts as a view
// See Note [resize_ in Functionalization] for more details
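// Illustrative behavior (hypothetical shapes), for a functional tensor t of size (2, 2):
//   t.resize_({4, 4});  // grows storage: handled below as a mutation (storage may be replaced)
//   t.resize_({1, 1});  // shrinks: handled below as an as_strided-style view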
static const at::Tensor& resize__functionalization(c10::DispatchKeySet dispatchKeySet [[maybe_unused]], const at::Tensor& self, at::IntArrayRef size, std::optional<at::MemoryFormat> memory_format) {
  // First unwrap the tensor arguments
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    at::functionalization::impl::sync(self);
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }
  // Case 1: arguments are not functional tensors, so we no-op and redispatch.
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    self_.resize_(size, memory_format);
    return self;
  }

  // Case 2: actually functionalize resize_()
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::resize(self_, size, memory_format);
  }

  auto itemsize = self.dtype().itemsize();
  auto storage_offset = self.storage_offset();
  auto new_size_bytes = at::detail::computeStorageNbytesContiguous(size, itemsize, storage_offset);
  auto needs_resize_storage = new_size_bytes > self.storage().nbytes();
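  // e.g. (assumed values) size = {4, 4}, float itemsize = 4, storage_offset = 0
  // => 4 * 4 * 4 = 64 contiguous bytes are needed; if the current storage holds
  // fewer, this is a storage-growing resize and is handled as a mutation below.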

  if (needs_resize_storage) {
    // If resize_() actually increases the size of the storage, then we need to tell FunctionalTensorWrapper about it.
    // See Note [resize_ in Functionalization]
    auto func_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
    func_impl->maybe_replace_storage(tmp_output);
    // See the note - we're guaranteed at this point that "self" is *not* a view (and has no outstanding views),
    // so we don't need to treat the output of resize_() as a view tensor.
    return self;
  }

  // Otherwise, we know that we're resizing to a smaller size.
  // resize_() is effectively a view operator.
  // The output of resizing is equivalent to taking a slice of a larger tensor.
  // We have to emulate this "slicing" with an as_strided call.
  auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      if (reapply_views) {
        return base.as_strided(size, c10::contiguous_strides(size));
      } else {
        return at::as_strided_copy(base, size, c10::contiguous_strides(size));
      }
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
    },
    /*has_symbolic_inputs=*/false
  );
  at::functionalization::impl::mutate_view_meta(self, view_meta);
  return self;
}


static at::Tensor lift_functionalize(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static at::Tensor lift_fresh_functionalize(const at::Tensor & self) {
  // See Note [Exporting and compiling a graph with lift_fresh_copy]
  if (at::functionalization::impl::isFunctionalTensor(self)) {
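    // Returning a view (rather than wrapping again) yields a fresh
    // FunctionalTensorWrapper that still aliases self, avoiding double-wrapping.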
    return self.view_as(self);
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static at::Tensor lift_fresh_functionalize_copy(const at::Tensor & self) {
  // Note [Exporting and compiling a graph with lift_fresh_copy]
  // If self is already a functional tensor, don't wrap it twice.
  // In theory this could be useful if we want to nest functionalization with itself,
  // but that isn't really a use case today.
  // Needed for https://github.com/pytorch/pytorch/issues/105327
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    // Note [Composite Functionalization under PreDispatch mode]
    // When we are tracing under PreDispatch, the PreDispatch key will be
    // in the local include TLS. As a result, when we redispatch here,
    // we would end up hitting the PreDispatch stack first. So, we should
    // directly redispatch to the Functionalize key manually.
    static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::clone", "").typed<at::Tensor(const at::Tensor &, std::optional<at::MemoryFormat>)>();
    return op.redispatch(c10::DispatchKeySet({c10::DispatchKey::Functionalize}), self, std::nullopt);
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh_copy(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static bool device_opted_into_functionalization(c10::Device self_device, std::optional<c10::Device> tgt_device) {
  // If the target device is empty, then the output tensor should be on the same device as the input
  auto real_tgt_device = tgt_device.has_value() ? tgt_device.value() : self_device;
  return real_tgt_device.type() == c10::DeviceType::XLA || real_tgt_device.type() == c10::DeviceType::Lazy;
}
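// e.g. device_opted_into_functionalization(cpu, xla) is true, while
// device_opted_into_functionalization(cpu, std::nullopt) is false, since the
// empty target falls back to the CPU input device.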

// NB: we only need this kernel because the to.dtype/to.dtype_layout overloads
// call _to_copy() directly, skipping the boxed fallback above.
// We should probably get rid of this though.
static at::Tensor _to_copy_functionalize(
    const at::Tensor & self,
    std::optional<at::ScalarType> dtype,
    std::optional<at::Layout> layout,
    std::optional<at::Device> device,
    std::optional<bool> pin_memory,
    bool non_blocking,
    std::optional<at::MemoryFormat> memory_format) {
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    // sync any pending updates
    at::functionalization::impl::sync(self);
    // pass the unwrapped tensor to the backend
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::_to_copy(self_, dtype, layout, device, pin_memory, non_blocking, memory_format);

  // Special case: if the Functionalize key is not in TLS, we assume that we're running
  // on a lazy backend (LTC).
  // In that case, if we're copying to a non-functionalize-enabled device,
  // then the functionalization pass should "end". We need to sync any updates on the input
  // tensor, but we shouldn't wrap the output.
  if (!c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {
    if (!device_opted_into_functionalization(self.device(), device)) {
      return out;
    }
  }
  return at::functionalization::impl::to_functional_tensor(out);
}


// Why is _unsafe_view special-cased here?
// Basically just to satisfy autograd's debug asserts.
// The situation:
// - _unsafe_view's autograd kernel has debug asserts to confirm
//   that the input and output alias storage.
// - _unsafe_view's schema in native_functions.yaml
//   does not contain alias annotations, so it advertises as non-aliasing.
// - functionalization will then treat _unsafe_view like a non-aliasing op.
//   Specifically, autograd will redispatch to functionalization's
//   boxed fallback kernel, which creates a new FunctionalTensorWrapper output
//   that does **not** alias storage with the input, tripping the assert.
// The kernel written here just manually reifies the aliasing relationship.
//
// Another way to handle this would be to fix _unsafe_view's alias annotations
// in native_functions.yaml, but I think this would be a pessimization.
// The idea with _unsafe_view is that you're guaranteed that the input
// is a temporary, and you don't actually have to worry about propagating
// mutations between the input and output.
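// Illustrative (hypothetical) trigger for the assert described above:
//   auto y = at::_unsafe_view(x, {-1});
// Routed through the boxed fallback, y would be a freshly wrapped
// FunctionalTensorWrapper whose storage does not alias x's storage.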
static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymIntArrayRef size) {
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    return at::_unsafe_view_symint(self, size);
  }

  auto self_ = at::functionalization::impl::from_functional_tensor(self);
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::_unsafe_view_symint(self_, size);
  }

  bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });

  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return at::_unsafe_view_symint(base, size);
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
    },
    /*has_symbolic_inputs=*/has_symbolic_inputs
  );

  auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
  // See Note [Propagating strides in the functionalization pass]
  // (for _unsafe_view, we just manually apply the shape-inference rule here instead of calling _unsafe_view's meta function)
  auto inferred_size = at::infer_size_dv(size, self.sym_numel());
  auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
  TORCH_INTERNAL_ASSERT(stride.has_value());
  out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
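  // e.g. (assumed shapes) viewing a contiguous (2, 3) tensor as size {6} gives
  // inferred_size = {6} and stride = {1}, the contiguous strides that
  // computeStride recovers from the old sizes/strides.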
  return out;
}

static at::Tensor& set__functionalize(at::Tensor& self, const at::Tensor& src) {
  // error case
  TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(self) || !at::functionalization::impl::isFunctionalTensor(src),
    "set__functionalize: Tried to mutate a non-functional tensor with a functional tensor, which is not allowed");
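  // e.g. this check fires for x.set_(y) where x is a plain tensor and y is a
  // FunctionalTensorWrapper: the mutation of x could not be tracked.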

  // nop case
  if (!at::functionalization::impl::isFunctionalTensor(self) && !at::functionalization::impl::isFunctionalTensor(src)) {
    at::AutoDispatchSkipFunctionalize guard;
    return self.set_(src);
  }

  TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(src),
    "set__functionalize: We do not currently support x.set_(y) where y is not a FunctionalTensor. Please file an issue");

  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(src));
  auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
  auto src_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(src);
  // See Note [Ordering of resize_() and set_()]
  TORCH_CHECK(!self_impl->was_inductor_storage_resized(),
    "storage_resize_() followed by set_() in torch.compile is not supported today");
  self_impl->set__impl(src_impl);
  return self;
}

TORCH_LIBRARY_IMPL(_, Functionalize, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}

TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
  m.impl("resize_", TORCH_FN(resize__functionalization));
  m.impl("lift", TORCH_FN(lift_functionalize));
  m.impl("lift_fresh", TORCH_FN(lift_fresh_functionalize));
  m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy));
  m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
  m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize));
  // The overloads of set_() that take in a storage should never
  // appear with torch.compile, because dynamo graph breaks on them.
  m.impl("set_.source_Tensor", TORCH_FN(set__functionalize));
}
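
// Illustrative flow (hypothetical custom op): a call like
//   at::Tensor y = my_ns::my_op(functional_x);
// that has no Functionalize kernel of its own routes through
// functionalizeFallback above: inputs are synced and unwrapped, the op runs
// below the Functionalize key, and outputs are re-wrapped as
// FunctionalTensorWrappers.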