xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/TensorOperators.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/core/QScheme.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/eq.h>
#include <ATen/ops/eq_native.h>
#include <ATen/ops/ge.h>
#include <ATen/ops/ge_native.h>
#include <ATen/ops/gt.h>
#include <ATen/ops/gt_native.h>
#include <ATen/ops/le.h>
#include <ATen/ops/le_native.h>
#include <ATen/ops/lt.h>
#include <ATen/ops/lt_native.h>
#include <ATen/ops/ne.h>
#include <ATen/ops/ne_native.h>
#include <ATen/ops/resize_native.h>
#endif

namespace at {
namespace native {

/*
All comparator operators will be named "<aten op name>_quantized_cpu".
The 'out' variant inserts '_out' after the op name, e.g. "eq_out_quantized_cpu".

TODO: This is an inefficient implementation that uses `.dequantize`.
      Need a more efficient implementation.
*/
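
// A rough usage sketch of the dequantize-based strategy (illustrative only;
// these functions are reached through the dispatcher rather than called
// directly): a comparison on a per-tensor quantized input is just the same
// comparison on its dequantized values.
//
//   at::Tensor x = at::rand({2, 3});
//   at::Tensor qx = at::quantize_per_tensor(
//       x, /*scale=*/0.1, /*zero_point=*/0, at::kQUInt8);
//   // eq_quantized_cpu(qx, 0.5) computes:
//   at::Tensor mask = at::eq(qx.dequantize(), 0.5);  // mask.dtype() == at::kBool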

#define DEFINE_COMPARATOR(at_op) \
Tensor& at_op##_out_quantized_cpu(const Tensor& self, \
                                const Scalar& other, Tensor& out) { \
  TORCH_CHECK(out.dtype() == at::ScalarType::Bool, \
              "The 'out' tensor must have dtype 'torch.bool'"); \
  auto self_dq = self.dequantize(); \
  return at:: at_op##_out(out, self_dq, other); \
} \
Tensor at_op##_quantized_cpu(const Tensor& self, const Scalar& other) { \
  auto self_dq = self.dequantize(); \
  return at:: at_op(self_dq, other); \
} \
Tensor& at_op##_out_quantized_cpu(const Tensor& self, \
                                const Tensor& other, Tensor& out) { \
  /* We infer size to make sure the tensors are compatible. */\
  infer_size_dimvector(self.sizes(), other.sizes()); \
  TORCH_CHECK(out.dtype() == at::ScalarType::Bool, \
              "The 'out' tensor must have dtype 'torch.bool'"); \
  auto self_dq = self.dequantize(); \
  auto other_dq = other.dequantize(); \
  return at:: at_op##_out(out, self_dq, other_dq); \
} \
Tensor at_op##_quantized_cpu(const Tensor& self, const Tensor& other) { \
  /* We infer size to make sure the tensors are compatible. */\
  infer_size_dimvector(self.sizes(), other.sizes()); \
  auto self_dq = self.dequantize(); \
  auto other_dq = other.dequantize(); \
  return at:: at_op(self_dq, other_dq); \
}
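
// For reference, DEFINE_COMPARATOR(ne) generates four overloads along these
// lines (a sketch with the dtype checks elided, not a literal expansion):
//
//   Tensor& ne_out_quantized_cpu(const Tensor& self, const Scalar& other,
//                                Tensor& out) {
//     return at::ne_out(out, self.dequantize(), other);
//   }
//   Tensor ne_quantized_cpu(const Tensor& self, const Scalar& other) {
//     return at::ne(self.dequantize(), other);
//   }
//   // ...plus the same pair taking `const Tensor& other`, which also
//   // dequantizes `other` and validates broadcastability via
//   // infer_size_dimvector(self.sizes(), other.sizes()).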

#define AT_FORALL_OPERATORS(_) \
_(ne)                          \
_(eq)                          \
_(ge)                          \
_(le)                          \
_(gt)                          \
_(lt)                          \

AT_FORALL_OPERATORS(DEFINE_COMPARATOR)

#undef AT_FORALL_OPERATORS
#undef DEFINE_COMPARATOR

const Tensor& quantized_resize_cpu_(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because if storage is resized, new elements are uninitialized
  globalContext().alertNotDeterministic("quantized_resize_cpu_");
  TORCH_CHECK(
      !optional_memory_format.has_value(),
      "Unsupported memory format for quantized tensor resize ",
      optional_memory_format.value());
  auto qscheme = self.quantizer()->qscheme();
  TORCH_CHECK(
      qscheme == QScheme::PER_TENSOR_AFFINE ||
          qscheme == QScheme::PER_TENSOR_SYMMETRIC,
      "Can only resize quantized tensors with per-tensor schemes!");
  auto* self_ = self.unsafeGetTensorImpl();
  // NOLINTNEXTLINE(bugprone-argument-comment)
  resize_impl_cpu_(self_, size, /*strides=*/std::nullopt);
  return self;
}
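
// A minimal usage sketch (illustrative; on a QuantizedCPU tensor, resize_ is
// expected to route here through the dispatcher). Growing the storage leaves
// the new elements uninitialized, hence the nondeterministic alert above.
//
//   at::Tensor qx = at::quantize_per_tensor(
//       at::rand({2, 2}), /*scale=*/0.1, /*zero_point=*/0, at::kQUInt8);
//   qx.resize_({4, 4});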

}}  // at::native