#include <ATen/native/nested/NestedTensorMath.h>
#include <ATen/native/nested/NestedTensorUtils.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/nested/NestedTensorUtils.h>

namespace at::native {

Tensor bmm_nested(const Tensor& self, const Tensor& mat2) {
  TORCH_CHECK(self.dim() == 3, "batch1 must be a 3D tensor");
  TORCH_CHECK(mat2.dim() == 3, "batch2 must be a 3D tensor");

  int64_t ntensors = self.is_nested() ? get_nested_tensor_impl(self)->size(0) : self.size(0);
  int64_t ntensors2 = mat2.is_nested() ? get_nested_tensor_impl(mat2)->size(0) : mat2.size(0);

  TORCH_CHECK(ntensors == ntensors2,
      "Expected size for the 1st dimension of batch2 tensor to be: ", ntensors,
      " but got: ", ntensors2, ".");

  const Tensor& self_buffer = self.is_nested() ? get_nested_tensor_impl(self)->get_unsafe_storage_as_tensor() : self;
  const Tensor& mat2_buffer = mat2.is_nested() ? get_nested_tensor_impl(mat2)->get_unsafe_storage_as_tensor() : mat2;

  // create a contiguous output
  int64_t out_numel = 0;
  const Tensor& self_sizemat = self.is_nested() ?
      get_nested_tensor_impl(self)->get_nested_sizes() : get_nested_tensor_impl(mat2)->get_nested_sizes();

  Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes());
  int64_t* out_sizemat_ptr = out_sizemat.data_ptr<int64_t>();
  for (int64_t i = 0; i < ntensors; i++) {
    const IntArrayRef& self_shape = get_size_for_index(self, i);
    const IntArrayRef& mat2_shape = get_size_for_index(mat2, i);
    const int64_t& self_size0 = self_shape[0], & self_size1 = self_shape[1],
        & mat2_size0 = mat2_shape[0], & mat2_size1 = mat2_shape[1];
    TORCH_CHECK(self_size1 == mat2_size0,
        i, "-th nested matrices in batch cannot be multiplied (",
        self_size0, "x", self_size1, " and ",
        mat2_size0, "x", mat2_size1, ")");
    out_sizemat_ptr[0] = self_size0;
    out_sizemat_ptr[1] = mat2_size1;
    out_sizemat_ptr += 2;
    out_numel += self_size0 * mat2_size1;
  }
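  // Worked example (illustrative): for nested components self = [(2, 3), (4, 3)]
  // and mat2 = [(3, 5), (3, 5)], the loop above records output component sizes
  // (2, 5) and (4, 5) and accumulates out_numel = 2*5 + 4*5 = 30, so one flat
  // buffer of 30 elements holds both output matrices back to back.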
  Tensor out_buffer = self.is_nested() ? self_buffer.new_empty(out_numel) : mat2_buffer.new_empty(out_numel);
  Tensor output = wrap_buffer(out_buffer, out_sizemat);
  // call tensor mm
  // TODO: `padding nested tensor -> bmm -> remove padding` may be more efficient
  // until we have a specialized nested tensor bmm kernel
  // useful resource: `aten/src/ATen/native/cpu/LinearAlgebra.cpp/bmm_out_or_baddbmm_`
  // `aten/src/ATen/native/cuda/Blas.cpp/baddbmm_out_cuda_impl`
  std::vector<Tensor> output_unbind = output.unbind();
  for (int64_t i = 0; i < ntensors; i++) {
    at::mm_out(output_unbind[i],
        self_buffer.as_strided(get_size_for_index(self, i), get_stride_for_index(self, i), get_offset_for_index(self, i)),
        mat2_buffer.as_strided(get_size_for_index(mat2, i), get_stride_for_index(mat2, i), get_offset_for_index(mat2, i)));
  }
  return output;
}

static Tensor matmul_with_bmm_nested(const Tensor& self, const Tensor& mat2) {
  // Tensor self = self_.contiguous();
  // Tensor mat2 = mat2_.contiguous();
  // self [N, n_heads, *, head_dim]
  // mat2 [N, n_heads, head_dim, *]
  const auto self_ptr = get_nested_tensor_impl(self);
  const auto mat2_ptr = get_nested_tensor_impl(mat2);
  // metadata for self
  std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr);
  std::vector<IntArrayRef> self_strides = NestedTensor_get_strides(self_ptr);
  int64_t* self_offsets_ptr =
      self_ptr->get_storage_offsets().data_ptr<int64_t>();
  auto opt = self_ptr->get_nested_sizes().options();

  // metadata for mat2
  std::vector<IntArrayRef> mat2_sizes = NestedTensor_get_sizes(mat2_ptr);
  std::vector<IntArrayRef> mat2_strides = NestedTensor_get_strides(mat2_ptr);
  int64_t* mat2_offsets_ptr =
      mat2_ptr->get_storage_offsets().data_ptr<int64_t>();
  auto opt2 = mat2_ptr->get_nested_sizes().options();

  int64_t N = static_cast<int64_t>(self_sizes.size());
  int64_t n_heads = self_sizes[0][0];

  // viewed metadata for self
  auto self_new_sizes = at::empty({N * n_heads, 2}, opt);
  int64_t* self_new_sizes_ptr = self_new_sizes.mutable_data_ptr<int64_t>();

  auto self_new_strides = at::empty({N * n_heads, 2}, opt);
  int64_t* self_new_strides_ptr = self_new_strides.mutable_data_ptr<int64_t>();
  auto self_new_offsets = at::empty({N * n_heads}, opt);
  int64_t* self_new_offsets_ptr = self_new_offsets.mutable_data_ptr<int64_t>();

  // viewed metadata for mat2
  auto mat2_new_sizes = at::empty({N * n_heads, 2}, opt2);
  int64_t* mat2_new_sizes_ptr = mat2_new_sizes.mutable_data_ptr<int64_t>();

  auto mat2_new_strides = at::empty({N * n_heads, 2}, opt2);
  int64_t* mat2_new_strides_ptr = mat2_new_strides.mutable_data_ptr<int64_t>();
  auto mat2_new_offsets = at::empty({N * n_heads}, opt);
  int64_t* mat2_new_offsets_ptr = mat2_new_offsets.mutable_data_ptr<int64_t>();

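  // The loop below flattens each (component i, head j) pair into its own 2-D
  // view. Illustrative example: if component i has sizes (n_heads, T_i, head_dim)
  // with contiguous strides (T_i * head_dim, head_dim, 1), head j gets
  // sizes (T_i, head_dim), strides (head_dim, 1), and
  // offset self_offsets_ptr[i] + j * self_stride_i[0], i.e. j * T_i * head_dim
  // past the component's base offset. (T_i denotes the ragged dimension.)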
  for (int64_t i = 0; i < N; i++) {
    const IntArrayRef& self_size_i = self_sizes[i];
    const IntArrayRef& self_stride_i = self_strides[i];
    int64_t self_offset = self_offsets_ptr[i];

    const IntArrayRef& mat2_size_i = mat2_sizes[i];
    const IntArrayRef& mat2_stride_i = mat2_strides[i];
    int64_t mat2_offset = mat2_offsets_ptr[i];
    for (int64_t j = 0; j < n_heads; j++) {
      auto idx = (i * n_heads + j) * 2;
      self_new_sizes_ptr[idx] = self_size_i[1];
      self_new_sizes_ptr[idx + 1] = self_size_i[2];
      self_new_strides_ptr[idx] = self_stride_i[1];
      self_new_strides_ptr[idx + 1] = self_stride_i[2];
      auto offset_idx = i * n_heads + j;
      self_new_offsets_ptr[offset_idx] = self_offset;
      self_offset += self_stride_i[0];

      mat2_new_sizes_ptr[idx] = mat2_size_i[1];
      mat2_new_sizes_ptr[idx + 1] = mat2_size_i[2];
      mat2_new_strides_ptr[idx] = mat2_stride_i[1];
      mat2_new_strides_ptr[idx + 1] = mat2_stride_i[2];
      mat2_new_offsets_ptr[offset_idx] = mat2_offset;
      mat2_offset += mat2_stride_i[0];
    }
  }

  // view self as [N * n_heads, *, head_dim] (collapse first 2 dims)
  auto viewed_self = create_nested_view_tensor(
      self, self_new_sizes, self_new_strides, self_new_offsets);

  // view mat2 as [N * n_heads, head_dim, *] (collapse first 2 dims)
  auto viewed_mat2 = create_nested_view_tensor(
      mat2, mat2_new_sizes, mat2_new_strides, mat2_new_offsets);

  // output [N * n_heads, *, *]
  auto bmm_output = at::bmm(viewed_self, viewed_mat2);

  // generate metadata for viewing output as [N, n_heads, *, *]
  // output of bmm should be contiguous so stride calculations should hold
  auto out_new_sizes = at::empty({N, 3}, opt);
  auto out_new_strides = at::empty({N, 3}, opt);
  auto out_new_offsets = at::empty({N}, opt);
  int64_t* out_new_offsets_ptr = out_new_offsets.mutable_data_ptr<int64_t>();

  int64_t* out_new_sizes_ptr = out_new_sizes.data_ptr<int64_t>();
  int64_t* out_new_strides_ptr = out_new_strides.data_ptr<int64_t>();

  int64_t out_offset = 0;
  for (int64_t i = 0; i < N; i++) {
    out_new_offsets_ptr[i] = out_offset;
    const IntArrayRef& self_size_i = self_sizes[i];
    const IntArrayRef& mat2_size_i = mat2_sizes[i];
    auto idx = i * 3;
    out_new_sizes_ptr[idx] = n_heads;
    out_new_sizes_ptr[idx + 1] = self_size_i[1];
    out_new_sizes_ptr[idx + 2] = mat2_size_i[2];
    out_new_strides_ptr[idx] = self_size_i[1] * mat2_size_i[2];
    out_new_strides_ptr[idx + 1] = mat2_size_i[2];
    out_new_strides_ptr[idx + 2] = 1;
    out_offset += n_heads * (self_size_i[1] * mat2_size_i[2]);
  }

  auto viewed_out = create_nested_view_tensor(
      bmm_output, out_new_sizes, out_new_strides, out_new_offsets);

  return viewed_out;
}

// nt: NT of shape (B, *, C, D)
// other: dense tensor of shape (D, E)
// output: NT of shape (B, *, C, E)
static Tensor matmul_nested_with_broadcasted_dense(
    const Tensor& nt,
    const Tensor& other) {
  // View nt buffer as 3D jagged for matmul
  auto* nt_impl = get_nested_tensor_impl(nt);
  auto jagged = nt_impl->get_buffer().view({-1, nt.size(2), nt.size(3)});
  auto new_buffer = at::matmul(jagged, other);

  // Wrap result into nested tensor
  const auto E = other.size(-1);
  const auto component_dim = nt.dim() - 1;
  auto new_sizes = nt_impl->get_nested_sizes().clone();
  auto new_sizes_ptr = new_sizes.data_ptr<int64_t>();
  for (const auto i : c10::irange(nt.size(0))) {
    new_sizes_ptr[i * component_dim + 2] = E;
  }
  return at::detail::make_tensor<NestedTensorImpl>(
      new_buffer.view(-1), new_sizes);
}
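// Worked example (illustrative) for the broadcasted-dense path above: an NT
// with component sizes (5, C, D) and (7, C, D) stores a contiguous buffer of
// (5 + 7) * C * D elements. Viewing that buffer as (-1, C, D) gives a regular
// (12, C, D) tensor, a single dense matmul with the (D, E) operand produces
// (12, C, E), and cloning the nested sizes with the last column rewritten to E
// yields output components of sizes (5, C, E) and (7, C, E).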

// Note [nested tensor matmul]
// This is really a generalized batched matmul dedicated to nested tensors,
// where `self` and `mat2` have the same number (>= 3) of dimensions.
// The last 2 dimensions will be considered as matrix dimensions,
// so they should be matrix-multiplicable.
// The leading dimensions are considered as batch dimensions,
// and since nested tensor does not support broadcasting for now,
// for each batch dimension `self` and `mat2` must have the same size.
// TODO: Support full matmul semantics some day
Tensor matmul_nested(const Tensor& self, const Tensor& mat2) {
  // special case of NT (B, *, C, D) with broadcasted dense (D, E)
  if (self.is_nested() && self.is_contiguous() && !mat2.is_nested() &&
      self.dim() == 4 && mat2.dim() == 2 &&
      get_nested_tensor_impl(self)->opt_size(2).has_value() &&
      get_nested_tensor_impl(self)->opt_size(3).has_value() &&
      self.size(3) == mat2.size(0)) {
    return matmul_nested_with_broadcasted_dense(self, mat2);
  }
  if (self.is_nested() && !mat2.is_nested()) {
    AT_ERROR(
        "Expected both to be nested, but got a nested self and non-nested other");
  } else if (!self.is_nested() && mat2.is_nested()) {
    AT_ERROR(
        "Expected both to be nested, but got a non-nested self and nested other");
  }
  // to_padded_tensor only supports contiguous inputs
  auto self_contig = self.contiguous();
  auto mat2_contig = mat2.contiguous();
  // dispatcher should have guaranteed that at least one is nested
  const auto self_ptr = get_nested_tensor_impl(self_contig);
  const auto mat2_ptr = get_nested_tensor_impl(mat2_contig);
  int64_t self_dim = self_ptr->dim(), mat2_dim = mat2_ptr->dim();
  TORCH_CHECK(
      self_dim >= 3,
      "matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: ",
      self_dim);
  TORCH_CHECK(
      mat2_dim >= 3,
      "matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: ",
      mat2_dim);
  TORCH_CHECK(
      self_dim == mat2_dim, "matmul: both inputs must have the same rank");
  int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0);
  TORCH_CHECK(
      ntensors == ntensors2,
      "matmul: Expected size for the 1st dimension of 2nd input tensor to be: ",
      ntensors,
      " but got: ",
      ntensors2,
      ".");
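  // The nested size matrices have shape (ntensors, self_dim - 1): the leading
  // (self_dim - 3) columns hold the per-component batch dimensions and the
  // last two columns hold the matrix dimensions. Illustrative example: for 4-D
  // inputs self ~ [N, n_heads, T_i, D] and mat2 ~ [N, n_heads, D, S_i] (T_i,
  // S_i being the ragged dims), the narrow() below extracts the n_heads
  // column, which must match exactly between the inputs since broadcasting is
  // not supported.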
  // Ensure batch dimensions have the same sizes (no broadcasting).
  const auto& self_sizes = self_ptr->get_nested_sizes();
  const auto& mat2_sizes = mat2_ptr->get_nested_sizes();
  const auto& self_batch_sizes = self_sizes.narrow(1, 0, self_dim - 3);
  const auto& mat2_batch_sizes = mat2_sizes.narrow(1, 0, mat2_dim - 3);
  TORCH_CHECK(
      at::equal(self_batch_sizes, mat2_batch_sizes),
      "matmul: For nested tensors, batch dimensions must have the same sizes, ",
      "no broadcasting is currently performed. Got batch shapes for self ",
      self_batch_sizes,
      " and batch shapes for mat2 ",
      mat2_batch_sizes);
  // Ensure last dim of self and second last dim of mat2 have the same size
  const auto& self_dim_size = self_sizes.select(1, -1);
  const auto& mat2_dim_size = mat2_sizes.select(1, -2);
  TORCH_CHECK(
      at::equal(self_dim_size, mat2_dim_size),
      "matmul: Nested tensors cannot be matrix multiplied, last dimension of self has sizes ",
      self_dim_size,
      ", second last dimension of mat2 has sizes ",
      mat2_dim_size);

  // use bmm inference-only fast path for [N, n_heads, *, head_dim] [N, n_heads,
  // head_dim, *]
  if (self.is_cuda() && self_dim == 4 && self.is_contiguous() &&
      mat2_dim == 4 && mat2.is_contiguous() &&
      !(GradMode::is_enabled() &&
        (self.requires_grad() || mat2.requires_grad()))) {
    const auto& self_opt_head_dim = self_ptr->opt_size(1);
    const auto& mat2_opt_head_dim = mat2_ptr->opt_size(1);
    if (self_opt_head_dim.has_value() && mat2_opt_head_dim.has_value() &&
        self_opt_head_dim.value() == mat2_opt_head_dim.value()) {
      return matmul_with_bmm_nested(self, mat2);
    }
  }

  // Construct output size from input sizes
  Tensor output_sizes = self_sizes.clone();
  // The last entry in every row of output_sizes should be the last column of
  // mat2_sizes
  output_sizes.index_put_(
      {at::indexing::Slice(), -1}, mat2_sizes.select(1, -1).clone());

  // Fall back to a dense matmul on zero-padded inputs, then strip the padding
  // from the result using the computed output sizes.
  auto self_padded = self_contig.to_padded_tensor(0.);
  auto mat2_padded = mat2_contig.to_padded_tensor(0.);
  auto output_padded = at::matmul(self_padded, mat2_padded);
  auto output_nested = nested_from_padded_generic(output_padded, output_sizes);
  return output_nested;
}

Tensor& matmul_out_nested(
    const Tensor& tensor1,
    const Tensor& tensor2,
    Tensor& result) {
  // TODO: this is a very quick and dirty implementation
  // should improve it to avoid the intermediate memory usage
  Tensor function_result = at::matmul(tensor1, tensor2);
  auto function_result_ptr = get_nested_tensor_impl(function_result);
  // TODO: this is to reproduce function_result_ptr->opt_sizes_
  // if an accessor is provided in the future, can replace this
  std::vector<int64_t> sizes;
  for (int64_t i = 0; i < function_result_ptr->dim(); i++) {
    std::optional<int64_t> opt_size = function_result_ptr->opt_size(i);
    if (opt_size.has_value()) {
      sizes.push_back(*opt_size);
    } else {
      sizes.push_back(-1);
    }
  }
  result.reshape(sizes);
  result.copy_(function_result);
  return result;
}

} // namespace at::native