#include <ATen/ATen.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/native/nested/NestedTensorUtils.h>

namespace at::native {

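// Merge the caller-supplied TensorOptions into self's options and validate
// them for nested empty_like: the memory format, if given, must be Preserve
// or Contiguous, and a memory format may only be combined with a strided
// layout.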
static TensorOptions verify_empty_parameters(
    const at::Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  TensorOptions options_ = TensorOptions()
                               .dtype(dtype)
                               .layout(layout)
                               .device(device)
                               .pinned_memory(pin_memory)
                               .memory_format(optional_memory_format);

  TensorOptions options = self.options().merge_in(options_);
  auto memory_format =
      options_.memory_format_opt().value_or(MemoryFormat::Preserve);
  TORCH_CHECK(
      memory_format == MemoryFormat::Preserve ||
          memory_format == MemoryFormat::Contiguous,
      "empty_like_nested only supports memory format Preserve or Contiguous, but got ",
      memory_format,
      " instead.");

  TORCH_CHECK(
      !(options.layout() != kStrided && optional_memory_format.has_value()),
      "memory format option is only supported by strided tensors");
  return options;
}

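// empty_like for nested tensors. With MemoryFormat::Contiguous a fresh
// contiguous buffer sized from the nested sizes is allocated; with
// MemoryFormat::Preserve the sizes, strides, and storage offsets of self are
// cloned so the result keeps self's (possibly non-contiguous) layout.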
Tensor empty_like_nested(
    const Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  auto options = verify_empty_parameters(
      self, dtype, layout, device, pin_memory, optional_memory_format);
  auto self_nt = get_nested_tensor_impl(self);
  auto memory_format =
      options.memory_format_opt().value_or(MemoryFormat::Preserve);
  if (memory_format == MemoryFormat::Contiguous) {
    auto nested_size = self_nt->get_nested_sizes().clone();
    int64_t buffer_size = get_numel_from_nested_size_tensor(nested_size);
    Tensor new_buffer = at::empty({buffer_size}, options);
    auto tensor = wrap_buffer(new_buffer, nested_size);
    return tensor;
  }
  // The fall through path must be Preserve
  TORCH_CHECK(
      memory_format == MemoryFormat::Preserve,
      "memory format option is only supported by strided tensors");
  // Since we clone sizes, strides, and offsets it should be safe to use
  // get_unsafe_storage_as_tensor for the call to empty_like.
  Tensor new_buffer =
      at::empty_like(self_nt->get_unsafe_storage_as_tensor(), options);
  auto nested_size = self_nt->get_nested_sizes().clone();
  auto nested_strides = self_nt->get_nested_strides().clone();
  auto offsets = self_nt->get_storage_offsets().clone();
  auto tensor = wrap_buffer(new_buffer, nested_size, nested_strides, offsets);
  return tensor;
}

// Take a Device that may not have device_index set (i.e., having it as -1
// representing the current device) and return the corresponding Device
// according to the actual device at the time of this function call. No-op
// if the device_index is set.
static inline Device ensure_has_index(Device device) {
  if (device.is_cpu() || device.has_index()) {
    return device;
  }
  const c10::impl::DeviceGuardImplInterface* impl =
      c10::impl::getDeviceGuardImpl(device.type());
  return impl->getDevice();
}

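// _to_copy for nested tensors: allocate the destination with at::empty_like
// and copy self's underlying buffer into it. The output buffer is pinned only
// for a non-blocking CUDA-to-CPU copy with a strided layout.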
Tensor _to_copy_nested(
    const Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    bool non_blocking,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !layout.has_value() || self.layout() == layout.value(),
      "to(options) doesn't support converting to a different layout, "
      "but got self.layout being ",
      self.layout(),
      " and options.layout set as ",
      layout.value());
  auto options =
      TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
          pin_memory);

  if (options.has_device()) {
    options = options.device(ensure_has_index(options.device()));
  }
  // memory_format is handled separately due to MemoryFormat::Preserve logic
  options = self.options().merge_in(options).memory_format(std::nullopt);
  auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);

  bool pin_out =
      (non_blocking && self.is_cuda() && options.device().is_cpu() &&
       (options.layout() == c10::kStrided));

  Tensor r;
  r = at::empty_like(self, dtype, layout, device, pin_out, memory_format);
  get_nested_tensor_impl(r)->get_buffer().copy_(
      get_nested_tensor_impl(self)->get_buffer(), non_blocking);
  return r;
}

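// In-place copy_ for nested tensors. The nested sizes of self and src must
// match exactly; the copy is then performed directly on the underlying
// buffers.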
Tensor& copy_nested_(Tensor& self, const Tensor& src, bool non_blocking) {
  const auto* nt_self = get_nested_tensor_impl(self);
  const auto* nt_src = get_nested_tensor_impl(src);
  TORCH_CHECK(
      at::equal(
          nt_self->get_nested_sizes(), nt_src->get_nested_sizes()),
      "copy_ only supports tensors that are the same size for Nested implementations");
  nt_self->get_buffer().copy_(nt_src->get_buffer(), non_blocking);
  return self;
}

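// clone for nested tensors. Preserve (or Contiguous on an already-contiguous
// input) clones the buffer together with the size/stride/offset metadata;
// cloning a non-contiguous input into Contiguous format allocates a dense
// buffer and copies each constituent through unbind.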
Tensor clone_nested(
    const Tensor& self,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  auto memory_format = optional_memory_format.value_or(c10::MemoryFormat::Preserve);
  auto self_ptr = get_nested_tensor_impl(self);
  if (memory_format == c10::MemoryFormat::Preserve ||
      (memory_format == c10::MemoryFormat::Contiguous && self.is_contiguous())) {
    const Tensor& buffer = self_ptr->get_unsafe_storage_as_tensor(),
                  sizemat = self_ptr->get_nested_sizes(),
                  stridemat = self_ptr->get_nested_strides();
    const auto& offsets = self_ptr->get_storage_offsets();
    // TODO: The size and the stride do not necessarily need to be cloned,
    //       but it is more conservative.
    //       This is something we could revisit once we land a more
    //       efficient implementation of nested_sizes_ and nested_strides.
    return wrap_buffer(buffer.clone(), sizemat.clone(), stridemat.clone(), offsets.clone());
  }
  // actually, memory format is contiguous and self is noncontiguous
  else if (memory_format == c10::MemoryFormat::Contiguous) {
    const Tensor& self_buffer = self_ptr->get_unsafe_storage_as_tensor(),
                  sizemat = self_ptr->get_nested_sizes();
    Tensor output_buffer = at::empty(self.numel(), self_buffer.options());
    Tensor output = wrap_buffer(output_buffer, sizemat);
    std::vector<Tensor> self_unbind = self.unbind(),
                        output_unbind = output.unbind();
    for (const int64_t i: c10::irange(self_ptr->size(0))) {
      output_unbind[i].copy_(self_unbind[i]);
    }
    return output;
  } else {
    TORCH_CHECK(
        false,
        "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ",
        memory_format);
  }
}

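// Unbind a nested tensor along dim 0 into per-constituent views of the
// underlying buffer, built from the stored sizes, strides, and storage
// offsets.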
std::vector<at::Tensor> NestedTensor_unbind(
    const at::Tensor& self,
    int64_t dim) {
  TORCH_CHECK(
      dim == 0,
      "NestedTensor can only be unbound along dimension 0 ",
      "got dimension ",
      dim,
      " instead.");
  auto self_ptr = get_nested_tensor_impl(self);
  int64_t ntensors = self_ptr->size(0);
  std::vector<at::Tensor> result_tensors(ntensors);
  if (ntensors == 0) {
    return result_tensors;
  }
  // This returns a differentiable view of self as a regular tensor
  auto buffer = self.values();
  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
                           strides = NestedTensor_get_strides(self_ptr);
  int64_t* offsets_ptr = self_ptr->get_storage_offsets().data_ptr<int64_t>();
  for (const int64_t i: c10::irange(ntensors)) {
    result_tensors[i] = buffer.as_strided(sizes[i], strides[i], offsets_ptr[i]);
  }
  return result_tensors;
}

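// narrow along dim 0 of a contiguous nested tensor. Returns a view whose
// metadata rows are narrowed to [start, start + length) and whose buffer is
// narrowed to begin at the storage offset of the first selected constituent.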
// NOLINTNEXTLINE(performance-unnecessary-value-param)
Tensor narrow_nested_symint(const at::Tensor& self, int64_t dim, SymInt start, SymInt length) {
  TORCH_CHECK(dim == 0, "narrow(): only dim=0 supported for nested tensors, but got: ", dim);
  TORCH_SYM_CHECK(length.sym_ge(0), "narrow(): length must be non-negative");
  auto cur_size = self.sym_size(dim);
  TORCH_CHECK_INDEX(
      ((-cur_size).sym_le(start).sym_and(start.sym_le(cur_size))).expect_true(__FILE__, __LINE__),
      "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ",
      start, ")");
  if (start < 0) {
    start = start + cur_size;
  }
  TORCH_SYM_CHECK(
      start.sym_le(cur_size - length),
      "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
  auto* nt_impl = get_nested_tensor_impl(self);
  TORCH_CHECK(self.is_contiguous(), "narrow(): only contiguous nested tensors supported");
  auto buffer = nt_impl->get_unsafe_storage_as_tensor();
  auto nested_sizes = nt_impl->get_nested_sizes();
  auto nested_strides = nt_impl->get_nested_strides();
  auto storage_offsets = nt_impl->get_storage_offsets();
  auto storage_offsets_ptr = storage_offsets.data_ptr<int64_t>();

  auto start_int = start.guard_int(__FILE__, __LINE__);
  auto length_int = length.guard_int(__FILE__, __LINE__);
  auto buffer_offset = storage_offsets_ptr[start_int];

  nested_sizes = nested_sizes.narrow(0, start_int, length_int);
  nested_strides = nested_strides.narrow(0, start_int, length_int);
  storage_offsets = storage_offsets.narrow(0, start_int, length_int);

  return at::detail::make_tensor<NestedTensorImpl>(
      c10::TensorImpl::VIEW,
      buffer.narrow(0, buffer_offset, buffer.numel() - buffer_offset),
      nested_sizes,
      nested_strides,
      storage_offsets);
}

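// Alias of a nested tensor: a view sharing the same buffer, nested sizes,
// nested strides, and storage offsets.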
Tensor alias_nested(const Tensor& self) {
  auto* nt_impl = get_nested_tensor_impl(self);
  auto buffer = nt_impl->get_unsafe_storage_as_tensor();
  const auto& nested_sizes = nt_impl->get_nested_sizes();
  const auto& nested_strides = nt_impl->get_nested_strides();
  const auto& storage_offsets = nt_impl->get_storage_offsets();
  return at::detail::make_tensor<NestedTensorImpl>(
      c10::TensorImpl::VIEW,
      std::move(buffer),
      nested_sizes,
      nested_strides,
      storage_offsets);
}

} // namespace at::native