// aten/src/ATen/native/nested/NestedTensorFactories.cpp
#include <ATen/ATen.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/native/nested/NestedTensorUtils.h>


namespace at::native {

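// Merges the caller-supplied options into self's options and checks that the
// requested memory format is valid for nested tensors: only Preserve and
// Contiguous are accepted, and a memory format may only be combined with a
// strided layout.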
static TensorOptions verify_empty_parameters(
    const at::Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  TensorOptions options_ = TensorOptions()
                               .dtype(dtype)
                               .layout(layout)
                               .device(device)
                               .pinned_memory(pin_memory)
                               .memory_format(optional_memory_format);

  TensorOptions options = self.options().merge_in(options_);
  auto memory_format =
      options_.memory_format_opt().value_or(MemoryFormat::Preserve);
  TORCH_CHECK(
      memory_format == MemoryFormat::Preserve || memory_format == MemoryFormat::Contiguous,
      "empty_like_nested only supports memory format Preserve or Contiguous, but got ",
      memory_format,
      " instead.");

  TORCH_CHECK(
      !(options.layout() != kStrided && optional_memory_format.has_value()),
      "memory format option is only supported by strided tensors");
  return options;
}

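// empty_like for nested tensors. Contiguous allocates a fresh flat buffer
// sized from a clone of self's nested sizes; Preserve allocates a buffer like
// self's storage and clones self's sizes, strides, and offsets so the
// existing layout carries over.
//
// Usage sketch (hypothetical caller; assumes `nt` is an existing nested tensor):
//   at::Tensor out = at::empty_like(nt);  // dispatches here and preserves nt's layout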
Tensor empty_like_nested(
    const Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  auto options = verify_empty_parameters(
      self, dtype, layout, device, pin_memory, optional_memory_format);
  auto self_nt = get_nested_tensor_impl(self);
  auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Preserve);
  if (memory_format == MemoryFormat::Contiguous) {
    auto nested_size = self_nt->get_nested_sizes().clone();
    int64_t buffer_size = get_numel_from_nested_size_tensor(nested_size);
    Tensor new_buffer = at::empty({buffer_size}, options);
    auto tensor = wrap_buffer(new_buffer, nested_size);
    return tensor;
  }
  // The fall-through path must be Preserve.
  TORCH_CHECK(
      memory_format == MemoryFormat::Preserve,
      "memory format option is only supported by strided tensors");
  // Since we clone sizes, strides, and offsets it should be safe to use
  // get_unsafe_storage_as_tensor for the call to empty_like.
  Tensor new_buffer =
      at::empty_like(self_nt->get_unsafe_storage_as_tensor(), options);
  auto nested_size = self_nt->get_nested_sizes().clone();
  auto nested_strides = self_nt->get_nested_strides().clone();
  auto offsets = self_nt->get_storage_offsets().clone();
  auto tensor = wrap_buffer(new_buffer, nested_size, nested_strides, offsets);
  return tensor;
}

// Take a Device that may not have device_index set (i.e., having it as -1
// representing the current device) and return the corresponding Device
// according to the actual device at the time of this function call.  No-op
// if the device_index is set.
static inline Device ensure_has_index(Device device) {
  if (device.is_cpu() || device.has_index()) {
    return device;
  }
  const c10::impl::DeviceGuardImplInterface* impl =
      c10::impl::getDeviceGuardImpl(device.type());
  return impl->getDevice();
}

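// _to_copy for nested tensors (the allocate-and-copy step behind Tensor::to).
// Builds the destination options, resolves an unset device index, allocates an
// empty_like nested result, and copies the flat buffer; the output buffer is
// pinned for a non-blocking CUDA -> CPU copy with a strided layout.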
Tensor _to_copy_nested(
    const Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    bool non_blocking,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !layout.has_value() || self.layout() == layout.value(),
      "to(options) doesn't support converting to a different layout, "
      "but got self.layout being ",
      self.layout(),
      " and options.layout set as ",
      layout.value());
  auto options =
      TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
          pin_memory);

  if (options.has_device()) {
    options = options.device(ensure_has_index(options.device()));
  }
  // memory_format is handled separately due to MemoryFormat::Preserve logic
  options = self.options().merge_in(options).memory_format(std::nullopt);
  auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);

  bool pin_out =
      (non_blocking && self.is_cuda() && options.device().is_cpu() &&
       (options.layout() == c10::kStrided));

  Tensor r;
  r = at::empty_like(self, dtype, layout, device, pin_out, memory_format);
  get_nested_tensor_impl(r)->get_buffer().copy_(
      get_nested_tensor_impl(self)->get_buffer(), non_blocking);
  return r;
}

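// In-place copy_ between nested tensors: the nested sizes must match exactly,
// after which the copy happens directly on the flat buffers.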
Tensor& copy_nested_(Tensor& self, const Tensor& src, bool non_blocking) {
  const auto* nt_self = get_nested_tensor_impl(self);
  const auto* nt_src = get_nested_tensor_impl(src);
  TORCH_CHECK(
      at::equal(
          nt_self->get_nested_sizes(), nt_src->get_nested_sizes()),
      "copy_ only supports tensors that are the same size for Nested implementations");
  nt_self->get_buffer().copy_(nt_src->get_buffer(), non_blocking);
  return self;
}


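// clone for nested tensors. Preserve (or Contiguous when self is already
// contiguous) clones the buffer along with the size/stride/offset metadata;
// cloning a noncontiguous tensor to Contiguous allocates a dense buffer and
// copies component by component via unbind.
//
// Usage sketch (hypothetical caller; assumes `nt` is a noncontiguous nested tensor):
//   at::Tensor dense = nt.clone(c10::MemoryFormat::Contiguous);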
Tensor clone_nested(
    const Tensor& self,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  auto memory_format = optional_memory_format.value_or(c10::MemoryFormat::Preserve);
  auto self_ptr = get_nested_tensor_impl(self);
  if (memory_format == c10::MemoryFormat::Preserve ||
  (memory_format == c10::MemoryFormat::Contiguous && self.is_contiguous())) {
    const Tensor& buffer = self_ptr->get_unsafe_storage_as_tensor(),
        sizemat = self_ptr->get_nested_sizes(),
        stridemat = self_ptr->get_nested_strides();
    const auto& offsets = self_ptr->get_storage_offsets();
    // TODO: The size and the stride do not necessarily need to be cloned,
    //       but it is more conservative.
    //       This is something we could revisit once we land a more
    //       efficient implementation of nested_sizes_ and nested_strides.
    return wrap_buffer(buffer.clone(), sizemat.clone(), stridemat.clone(), offsets.clone());
  }
  // Otherwise the requested memory format is Contiguous and self is noncontiguous.
  else if (memory_format == c10::MemoryFormat::Contiguous) {
    const Tensor& self_buffer = self_ptr->get_unsafe_storage_as_tensor(),
        sizemat = self_ptr->get_nested_sizes();
    Tensor output_buffer = at::empty(self.numel(), self_buffer.options());
    Tensor output = wrap_buffer(output_buffer, sizemat);
    std::vector<Tensor> self_unbind = self.unbind(),
        output_unbind = output.unbind();
    for (const int64_t i: c10::irange(self_ptr->size(0))) {
      output_unbind[i].copy_(self_unbind[i]);
    }
    return output;
  } else {
    TORCH_CHECK(
        false,
        "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ",
        memory_format);
  }
}

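// unbind along dim 0: builds one strided view per component with as_strided
// over the flat buffer, using the stored per-component sizes, strides, and
// storage offsets.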
std::vector<at::Tensor> NestedTensor_unbind(
    const at::Tensor& self,
    int64_t dim) {
  TORCH_CHECK(
      dim == 0,
      "NestedTensor can only be unbound along dimension 0, ",
      "got dimension ",
      dim,
      " instead.");
  auto self_ptr = get_nested_tensor_impl(self);
  int64_t ntensors = self_ptr->size(0);
  std::vector<at::Tensor> result_tensors(ntensors);
  if (ntensors == 0) {
    return result_tensors;
  }
  // This returns a differentiable view of self as a regular tensor
  auto buffer = self.values();
  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
      strides = NestedTensor_get_strides(self_ptr);
  int64_t *offsets_ptr = self_ptr->get_storage_offsets().data_ptr<int64_t>();
  for (const int64_t i: c10::irange(ntensors)){
    result_tensors[i] = buffer.as_strided(sizes[i], strides[i], offsets_ptr[i]);
  }
  return result_tensors;
}

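// narrow along dim 0 for contiguous nested tensors: normalizes a negative
// start, bounds-checks start and length symbolically, then returns a view
// whose buffer and size/stride/offset metadata are narrowed to the selected
// components.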
// NOLINTNEXTLINE(performance-unnecessary-value-param)
Tensor narrow_nested_symint(const at::Tensor& self, int64_t dim, SymInt start, SymInt length) {
  TORCH_CHECK(dim == 0, "narrow(): only dim=0 supported for nested tensors, but got: ", dim);
  TORCH_SYM_CHECK(length.sym_ge(0), "narrow(): length must be non-negative");
  auto cur_size = self.sym_size(dim);
  TORCH_CHECK_INDEX(
      ((-cur_size).sym_le(start).sym_and(start.sym_le(cur_size))).expect_true(__FILE__, __LINE__),
      "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ",
      start, ")");
  if (start < 0) {
    start = start + cur_size;
  }
  TORCH_SYM_CHECK(start.sym_le(cur_size - length),
      "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
  auto *nt_impl = get_nested_tensor_impl(self);
  TORCH_CHECK(self.is_contiguous(), "narrow(): only contiguous nested tensors supported");
  auto buffer = nt_impl->get_unsafe_storage_as_tensor();
  auto nested_sizes = nt_impl->get_nested_sizes();
  auto nested_strides = nt_impl->get_nested_strides();
  auto storage_offsets = nt_impl->get_storage_offsets();
  auto storage_offsets_ptr = storage_offsets.data_ptr<int64_t>();

  auto start_int = start.guard_int(__FILE__, __LINE__);
  auto length_int = length.guard_int(__FILE__, __LINE__);
  auto buffer_offset = storage_offsets_ptr[start_int];

  nested_sizes = nested_sizes.narrow(0, start_int, length_int);
  nested_strides = nested_strides.narrow(0, start_int, length_int);
  storage_offsets = storage_offsets.narrow(0, start_int, length_int);

  return at::detail::make_tensor<NestedTensorImpl>(
      c10::TensorImpl::VIEW,
      buffer.narrow(0, buffer_offset, buffer.numel() - buffer_offset),
      nested_sizes,
      nested_strides,
      storage_offsets);
}

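// Returns a nested tensor view aliasing self's buffer and sharing its
// size/stride/offset metadata, without copying any data.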
Tensor alias_nested(const Tensor& self) {
  auto* nt_impl = get_nested_tensor_impl(self);
  auto buffer = nt_impl->get_unsafe_storage_as_tensor();
  const auto& nested_sizes = nt_impl->get_nested_sizes();
  const auto& nested_strides = nt_impl->get_nested_strides();
  const auto& storage_offsets = nt_impl->get_storage_offsets();
  return at::detail::make_tensor<NestedTensorImpl>(
      c10::TensorImpl::VIEW,
      std::move(buffer),
      nested_sizes,
      nested_strides,
      storage_offsets);
}

} // namespace at::native