xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/quantization/quantization_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
8 #include <ATen/ATen.h>
9 
10 #include <typeinfo>
11 
torch_tensor_device_name(const at::Tensor & ten)12 inline std::string torch_tensor_device_name(const at::Tensor& ten) {
13   return c10::DeviceTypeName(ten.device().type());
14 }
15 
// TENSOR_NDIM_EQUALS(ten, dims): TORCH_CHECK that the tensor expression `ten`
// reports exactly `dims` dimensions via ndimension(). The failure message
// quotes both macro arguments textually and appends the dimension count that
// was actually found. Both arguments are parenthesized so arbitrary
// expressions expand safely; `ten` may be evaluated more than once, so pass a
// side-effect-free expression.
#define TENSOR_NDIM_EQUALS(ten, dims)                               \
  TORCH_CHECK(                                                      \
      (dims) == (ten).ndimension(),                                 \
      "Tensor '" #ten "' must have " #dims " dimension(s). Found ", \
      (ten).ndimension())
23 
// TENSOR_ON_CPU(x): TORCH_CHECK that the tensor expression `x` is not a CUDA
// tensor. On failure the message quotes the expression text and appends the
// name of the device it is currently on.
// `x` is parenthesized (matching TENSOR_NDIM_EQUALS) so that expressions such
// as `*ptr` or `cond ? a : b` bind correctly to the member call; without the
// parentheses `!*ptr.is_cuda()` parses as `!*(ptr.is_cuda())`.
// `x` may be evaluated more than once, so pass a side-effect-free expression.
// NOTE(review): the condition is `!is_cuda()`, so tensors on other non-CPU
// devices also pass despite the "must be a CPU tensor" wording -- confirm
// whether `is_cpu()` is the intended check.
#define TENSOR_ON_CPU(x)                                      \
  TORCH_CHECK(                                                \
      !(x).is_cuda(),                                         \
      #x " must be a CPU tensor; it is currently on device ", \
      torch_tensor_device_name(x))
29 
// TENSOR_ON_CUDA_GPU(x): TORCH_CHECK that the tensor expression `x` is a CUDA
// tensor. On failure the message quotes the expression text and appends the
// name of the device it is currently on.
// `x` is parenthesized (matching TENSOR_NDIM_EQUALS) so that expressions such
// as `*ptr` or `cond ? a : b` bind correctly to the member call; without the
// parentheses `*ptr.is_cuda()` parses as `*(ptr.is_cuda())`.
// `x` may be evaluated more than once, so pass a side-effect-free expression.
#define TENSOR_ON_CUDA_GPU(x)                                  \
  TORCH_CHECK(                                                 \
      (x).is_cuda(),                                           \
      #x " must be a CUDA tensor; it is currently on device ", \
      torch_tensor_device_name(x))
35