// aten/src/ATen/FunctionalStorageImpl.cpp
#include <ATen/FunctionalStorageImpl.h>

#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>
#include <vector>

namespace at::functionalization {

ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
  if (out_idx == this->out_index) return *this;
  return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
}
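
// Example (illustrative sketch, not part of this file; `split_meta` is a hypothetical
// ViewMeta built for a multi-output view such as split()):
//
//   ViewMeta split_meta = /* ViewMeta describing output 0 of a 3-way split */;
//   ViewMeta second_chunk_meta = split_meta.to_out_idx(1);
//   // second_chunk_meta reuses split_meta's forward_fn/reverse_fn, but carries
//   // out_index == 1, so its reverse_fn scatters a mutation back into chunk 1.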

// Note [Functionalization: Alias Removal Part 2]
// See Note [Functionalization: Alias Removal] for more details.
// This function applies a single update from one of the views to the StorageImpl.
// We start out with <original_base> and <mutated_view>, and our goal is to end up with <mutated_base>.
// Consider this program:
//
// base = ...
// a = base.view1()
// b = a.view2()
// c = b.view3()
// c.add_(3)
//
// Then the functionalization pass will queue an update as follows:
//
// update.new_val = c  # the updated value of c
// update.view_metas = [view1_meta, view2_meta, view3_meta]
//
// Syncing any of a, b or c will eventually call apply_update() on the storage, and the following will run:
//
// tmp_values = [base, a, b]  # NB: c is not necessary
// t = update.new_val
// t = view3_inverse(b, t, 0)  # 0 is the output index; these are all single-output views, so it's 0
// t = view2_inverse(a, t, 0)
// t = view1_inverse(base, t, 0)  # t now represents the updated storage.
// storage.base_ = t
static const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) {
  at::Tensor t = update.new_val;
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  if (update.view_metas.empty()) return t;

  std::vector<at::Tensor> tmp_values({base});
  tmp_values.reserve(update.view_metas.size());
  for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
    at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
    // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided.
    // All of these ops require additional information to recover the sizes of the original tensor.
    // If we needed to, we could apply this optimization and only bother computing tmp_values
    // for the view ops that actually require it.
    tmp_values.push_back(std::move(next_view));
  }
  for (int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
    int64_t out_idx = update.view_metas[i].out_index;
    // Each view inverse is implemented in ViewInverses.cpp.
    t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
  }
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  return t;
}
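
// Example (illustrative sketch; assumes the recorded view is a single select() call,
// whose functional inverse scatters the mutated slice back, e.g. via at::select_scatter;
// `mutated_row` is a hypothetical name for update.new_val):
//
//   at::Tensor base = at::zeros({2, 3});
//   at::Tensor row  = base.select(0, 1);   // a view of row 1
//   // After row.add_(3) is functionalized, replaying the queued update amounts to:
//   //   t = at::select_scatter(base, mutated_row, /*dim=*/0, /*index=*/1);
//   // i.e. the mutated view is written back into a fresh value for the whole base.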

static c10::SymInt get_nbytes(const Tensor& value) {
  // The functionalization story when wrapping tensors that don't have storage
  // is a bit wonky, but fortunately for some models (e.g., dlrm) we never
  // actually perform mutations on these tensors, so we never really get
  // called out on it.  For now, functionalization still creates "storages"
  // for these tensors (which is wrong), but we don't give them any space.
  // A more proper fix would be to have a SparseFunctionalTensorWrapper that
  // models sparse tensors correctly.
  if (value.is_sparse() || at::sparse_csr::is_sparse_compressed(value)) {
    return 0;
  }
  if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
    // Today, the two implementations of SymInt are Python (proxy tensor)
    // and lazy tensor (LTC/XLA).
    // LTC hasn't implemented SymInt support yet, though.
    // Once it does, we should remove this check.
    if (value.key_set().has(c10::DispatchKey::Python)) {
      return value.storage().sym_nbytes();
    }
    return at::detail::computeStorageNbytes(value.sym_sizes(), value.sym_strides(), value.dtype().itemsize(), value.sym_storage_offset());
  }
  // XLA storage objects also do not properly track nbytes.
  return at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset());
}
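
// Example (illustrative arithmetic; as a sketch, computeStorageNbytes sizes the storage
// roughly as itemsize * (storage_offset + 1 + sum_i (sizes[i] - 1) * strides[i]),
// and 0 if any dimension has size 0):
//
//   // Contiguous 2x3 float tensor: offset 0, strides {3, 1}
//   //   elements needed = 0 + 1 + (2 - 1) * 3 + (3 - 1) * 1 = 6
//   //   nbytes          = 6 * sizeof(float)                 = 24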

FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
  : c10::StorageImpl(
      c10::StorageImpl::use_byte_size_t(),
      get_nbytes(base),
      DataPtr{nullptr, base.device()},
      GetAllocator(kMeta),
      /*resizable=*/true
    ),
    base_(base)
{
  // SparseTensorImpl has no storage, so we cannot query its nbytes.
  // (original_storage_size_ is only used for storage resizing in FSDP anyway, which does not apply to sparse.)
  // The same goes for XLA.
  if (base.unsafeGetTensorImpl()->has_storage() && base.device().type() != c10::DeviceType::XLA) {
    original_storage_size_ = base.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl()->sym_nbytes();
  } else {
    original_storage_size_ = -1;
  }
  curr_storage_size_ = original_storage_size_;
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
}

void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
  TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");

  if (metas.size() > 1) {
    for (size_t i = 1; i < metas.size(); ++i) {
      // Skipping this check for XLA. It would be good to add it back, but it is failing XLA CI.
      TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided,
          "During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i,
          " was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today, "
          "so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you "
          "can insert a graph break right before the mutation with torch._dynamo.graph_break(). If you would like this behavior to "
          "work properly, please comment on https://github.com/pytorch/pytorch/issues/104505.");
    }
  }
  updates_.push_back({updated_val, metas});
  generation_++;
}
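
// Example (illustrative sketch of a caller; the real call sites live in
// FunctionalTensorWrapper, and `storage`/`updated_c` are hypothetical names):
//
//   // Functionalizing c.add_(3) from the Note above computes the new (unwrapped)
//   // value of c and records it against the shared storage:
//   //   storage->add_update(updated_c, /*metas=*/{view1_meta, view2_meta, view3_meta});
//   // Each call bumps generation_, which lets aliases detect that they are stale.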

bool FunctionalStorageImpl::apply_updates() {
  // N.B.: none of the tensors used in this function should be FunctionalTensorWrappers at this point.
  // The only reason we currently need the TLS exclude guard here is because of functorch's DynamicLayer stack.
  // It adds the Functionalize key into TLS before redispatching to the functionalization kernels,
  // which means that we need to explicitly exclude it here before doing any other work underneath the pass.
  at::AutoDispatchSkipFunctionalize guard;
  bool any_updates = !updates_.empty();
  for (auto& update_data : updates_) {
    base_ = apply_update(update_data, base_);
  }
  updates_.clear();
  return any_updates;
}
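
// Example (illustrative sketch of how a sync might drive this; the accessor names
// other than apply_updates() are assumptions, and the real logic lives in
// FunctionalTensorWrapper):
//
//   // if (storage->generation() > wrapper_generation) {
//   //   bool changed = storage->apply_updates();  // replay queued mutations into base_
//   //   if (changed) { /* regenerate this alias's value from the new base_ */ }
//   // }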

} // namespace at::functionalization