1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/ExpandUtils.h>
4 #include <ATen/native/Resize.h>
5 #include <ATen/quantized/Quantizer.h>
6 #include <c10/core/QScheme.h>
7
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/Functions.h>
10 #include <ATen/NativeFunctions.h>
11 #else
12 #include <ATen/ops/eq.h>
13 #include <ATen/ops/eq_native.h>
14 #include <ATen/ops/ge.h>
15 #include <ATen/ops/ge_native.h>
16 #include <ATen/ops/gt.h>
17 #include <ATen/ops/gt_native.h>
18 #include <ATen/ops/le.h>
19 #include <ATen/ops/le_native.h>
20 #include <ATen/ops/lt.h>
21 #include <ATen/ops/lt_native.h>
22 #include <ATen/ops/ne.h>
23 #include <ATen/ops/ne_native.h>
24 #include <ATen/ops/resize_native.h>
25 #endif
26
27 namespace at {
28 namespace native {
29
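/*
All comparator operators will be named "<aten op name>_quantized_cpu".
For the 'out' variant of an op, '_out' is inserted after the op name
(e.g. "eq_out_quantized_cpu").

TODO: This is an inefficient implementation: it dequantizes the quantized
inputs with `.dequantize()` and then runs the regular (non-quantized)
comparison on the result. Replace it with a more efficient implementation.
*/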
37
// Defines the four quantized-CPU kernels for the comparison op `OP`:
//   Tensor& OP##_out_quantized_cpu(const Tensor& self, const Scalar& other, Tensor& out)
//   Tensor  OP##_quantized_cpu    (const Tensor& self, const Scalar& other)
//   Tensor& OP##_out_quantized_cpu(const Tensor& self, const Tensor& other, Tensor& out)
//   Tensor  OP##_quantized_cpu    (const Tensor& self, const Tensor& other)
// Every kernel calls `.dequantize()` on its tensor inputs and forwards to the
// corresponding `at::OP` / `at::OP##_out` call. The `out` kernels require
// `out` to have dtype torch.bool. The Tensor-Tensor kernels first run
// `infer_size_dimvector` on the two input shapes to check they are compatible.
// NOTE: only /* */ comments may appear inside the macro body, because a `//`
// comment would swallow the trailing line-continuation backslash.
#define DEFINE_COMPARATOR(OP)                                                 \
  Tensor& OP##_out_quantized_cpu(                                             \
      const Tensor& self, const Scalar& other, Tensor& out) {                 \
    TORCH_CHECK(out.dtype() == at::ScalarType::Bool,                          \
                "The 'out' tensor must have dtype 'torch.bool'");             \
    const Tensor self_float = self.dequantize();                              \
    return at::OP##_out(out, self_float, other);                              \
  }                                                                           \
  Tensor OP##_quantized_cpu(const Tensor& self, const Scalar& other) {        \
    const Tensor self_float = self.dequantize();                              \
    return at::OP(self_float, other);                                         \
  }                                                                           \
  Tensor& OP##_out_quantized_cpu(                                             \
      const Tensor& self, const Tensor& other, Tensor& out) {                 \
    /* Shape compatibility is validated before anything else. */              \
    infer_size_dimvector(self.sizes(), other.sizes());                        \
    TORCH_CHECK(out.dtype() == at::ScalarType::Bool,                          \
                "The 'out' tensor must have dtype 'torch.bool'");             \
    const Tensor self_float = self.dequantize();                              \
    const Tensor other_float = other.dequantize();                            \
    return at::OP##_out(out, self_float, other_float);                        \
  }                                                                           \
  Tensor OP##_quantized_cpu(const Tensor& self, const Tensor& other) {        \
    /* Shape compatibility is validated before dequantizing. */               \
    infer_size_dimvector(self.sizes(), other.sizes());                        \
    const Tensor self_float = self.dequantize();                              \
    const Tensor other_float = other.dequantize();                            \
    return at::OP(self_float, other_float);                                   \
  }
67
68 #define AT_FORALL_OPERATORS(_) \
69 _(ne) \
70 _(eq) \
71 _(ge) \
72 _(le) \
73 _(gt) \
74 _(lt) \
75
AT_FORALL_OPERATORS(DEFINE_COMPARATOR) const76 AT_FORALL_OPERATORS(DEFINE_COMPARATOR)
77
78 #undef AT_FORALL_OPERATORS
79 #undef DEFINE_COMPARATOR
80
81 const Tensor& quantized_resize_cpu_(
82 const Tensor& self,
83 IntArrayRef size,
84 std::optional<MemoryFormat> optional_memory_format) {
85 // See Note [Writing Nondeterministic Operations]
86 // Nondeterministic because if storage is resized, new elements are uninitialized
87 globalContext().alertNotDeterministic("quantized_resize_cpu_");
88 TORCH_CHECK(
89 !optional_memory_format.has_value(),
90 "Unsupported memory format for quantized tensor resize ",
91 optional_memory_format.value());
92 auto qscheme = self.quantizer()->qscheme();
93 TORCH_CHECK(
94 qscheme == QScheme::PER_TENSOR_AFFINE ||
95 qscheme == QScheme::PER_TENSOR_SYMMETRIC,
96 "Can only resize quantized tensors with per-tensor schemes!");
97 auto* self_ = self.unsafeGetTensorImpl();
98 // NOLINTNEXTLINE(bugprone-argument-comment)
99 resize_impl_cpu_(self_, size, /*strides=*/std::nullopt);
100 return self;
101 }
102
103 }} // at::native
104