#include <ATen/FunctionalStorageImpl.h>

#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>
#include <vector>

namespace at::functionalization {

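// Returns a ViewMeta that reuses this ViewMeta's forward/reverse functions and
// flags, but reports `out_idx` as its output index. If the index already
// matches, *this is returned unchanged.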
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
  if (out_idx == this->out_index) return *this;
  return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
}

// Note [Functionalization: Alias Removal Part 2]
// See Note [Functionalization: Alias Removal] for more details.
// This function applies a single update from one of the views to the StorageImpl.
// We start out with <original_base> and <mutated_view>, and our goal is to end up with <mutated_base>.
// Consider this program:
//
// base = ...
// a = base.view1()
// b = a.view2()
// c = b.view3()
// c.add_(3)
//
// Then the functionalization pass will queue an update as follows:
//
// update.new_val = c # the updated value of c
// update.view_metas = [view1_meta, view2_meta, view3_meta]
//
// Syncing any of a, b or c will eventually call apply_update() on the storage, and the following will run:
//
// tmp_values = [base, a, b] # NB: c is not necessary
// t = update.new_val
// t = view3_inverse(b, t, 0) # 0 is output index, these are all single output views so it's 0
// t = view2_inverse(a, t, 0)
// t = view1_inverse(base, t, 0) # t now represents the updated storage.
// storage.base_ = t
static const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) {
  at::Tensor t = update.new_val;
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  if (update.view_metas.empty()) return t;

  std::vector<at::Tensor> tmp_values({base});
  tmp_values.reserve(update.view_metas.size());
  for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
    at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
    // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided.
    // All of these ops require additional information to recover the sizes of the original tensor.
    // If we needed to, we could apply this optimization and only bother computing tmp_values
    // for the view ops that actually require it.
    tmp_values.push_back(std::move(next_view));
  }
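  // Walk the view chain in reverse, applying each view's inverse so that the
  // mutation captured in `t` is propagated back toward the base.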
  for (int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
    int64_t out_idx = update.view_metas[i].out_index;
    // Each view inverse is implemented in ViewInverses.cpp.
    t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
  }
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  return t;
}

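// Best-effort computation of the number of bytes this functional "storage"
// should report for `value`. See the caveats below for tensors without real
// storage (sparse/compressed) and for tensors with symbolic sizes/strides.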
static c10::SymInt get_nbytes(const Tensor& value) {
  // The functionalization story when wrapping tensors that don't have storage
  // is a bit wonky, but fortunately for some models (e.g., dlrm) we never
  // actually perform mutations on these tensors, so you never really get
  // called out on it. For now, functionalization still creates "storages"
  // for these tensors (which is wrong), but we don't give them any space.
  // A more proper fix would be to have a SparseFunctionalTensorWrapper that
  // models sparse correctly.
  if (value.is_sparse() || at::sparse_csr::is_sparse_compressed(value)) {
    return 0;
  }
  if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
    // Today, the two implementations of SymInt are Python (proxy tensor)
    // and lazy tensor (LTC/XLA).
    // LTC hasn't implemented SymInt support yet, though.
    // Once it does, we should remove this check.
    if (value.key_set().has(c10::DispatchKey::Python)) {
      return value.storage().sym_nbytes();
    }
    return at::detail::computeStorageNbytes(value.sym_sizes(), value.sym_strides(), value.dtype().itemsize(), value.sym_storage_offset());
  }
  // XLA storage objects also do not properly track nbytes.
  return at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset());
}

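// The functional "storage" does not own any real data: it is backed by a null
// DataPtr and the meta allocator, and is only used to track sizes and pending
// updates. The actual data lives in `base_` (and in the tensors queued in `updates_`).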
FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
  : c10::StorageImpl(
      c10::StorageImpl::use_byte_size_t(),
      get_nbytes(base),
      DataPtr{nullptr, base.device()},
      GetAllocator(kMeta),
      /*resizable=*/true
    ),
    base_(base)
{
  // SparseTensorImpl has no storage, so we cannot query its nbytes.
  // (original_storage_size is only used for storage resizing in fsdp anyway, which does not apply to sparse)
  // Same for XLA
  if (base.unsafeGetTensorImpl()->has_storage() && base.device().type() != c10::DeviceType::XLA) {
    original_storage_size_ = base.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl()->sym_nbytes();
  } else {
    original_storage_size_ = -1;
  }
  curr_storage_size_ = original_storage_size_;
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
}

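// Queues a pending update: `updated_val` is the new value of some view of this
// storage, and `metas` describes the chain of views leading from the base to
// that view. The update is applied lazily, when apply_updates() runs.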
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
  TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");

  if (metas.size() > 1) {
    for (size_t i = 1; i < metas.size(); ++i) {
      // Skipping this check for XLA. It would be good to add it back, but it is currently failing XLA CI.
      TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided,
        "During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i,
        " was an as_strided() call. as_strided() is non-compositional, and therefore cannot be functionalized properly today, "
        "so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you "
        "can insert a graph break right before the mutation with torch._dynamo.graph_break(). If you would like this behavior to "
        "work properly, please comment on https://github.com/pytorch/pytorch/issues/104505.");
    }
  }
  updates_.push_back({updated_val, metas});
  generation_++;
}

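// Applies all queued updates to `base_` in the order they were added, then clears
// the queue. Returns true if there was at least one update to apply. This runs
// when one of the aliases is synced (see the Note above apply_update()).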
bool FunctionalStorageImpl::apply_updates() {
  // N.B.: none of the tensors used in this function should be FunctionalTensorWrappers at this point.
  // The only reason we currently need the TLS exclude guard here is because of functorch's DynamicLayer stack.
  // It adds the Functionalize key into TLS before redispatching to the functionalization kernels,
  // which means that we need to explicitly exclude it here before doing any other work underneath the pass.
  at::AutoDispatchSkipFunctionalize guard;
  bool any_updates = !updates_.empty();
  for (auto& update_data: updates_) {
    base_ = apply_update(update_data, base_);
  }
  updates_.clear();
  return any_updates;
}

} // namespace at::functionalization