1 // Copyright (c) Meta Platforms, Inc. and affiliates. 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <ATen/ATen.h> 9 10 #include <typeinfo> 11 torch_tensor_device_name(const at::Tensor & ten)12inline std::string torch_tensor_device_name(const at::Tensor& ten) { 13 return c10::DeviceTypeName(ten.device().type()); 14 } 15 16 #define TENSOR_NDIM_EQUALS(ten, dims) \ 17 TORCH_CHECK( \ 18 (ten).ndimension() == (dims), \ 19 "Tensor '" #ten "' must have " #dims \ 20 " dimension(s). " \ 21 "Found ", \ 22 (ten).ndimension()) 23 24 #define TENSOR_ON_CPU(x) \ 25 TORCH_CHECK( \ 26 !x.is_cuda(), \ 27 #x " must be a CPU tensor; it is currently on device ", \ 28 torch_tensor_device_name(x)) 29 30 #define TENSOR_ON_CUDA_GPU(x) \ 31 TORCH_CHECK( \ 32 x.is_cuda(), \ 33 #x " must be a CUDA tensor; it is currently on device ", \ 34 torch_tensor_device_name(x)) 35