xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/NestedIntSymNodeImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/NestedIntSymNodeImpl.h>

#include <limits>

#include <c10/core/SymNodeImpl.h>
#include <c10/util/Exception.h>
4 
5 namespace c10 {
6 
7 namespace {
_eq(const char * op,c10::SymNodeImpl * lhs,c10::SymNodeImpl * rhs)8 bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
9   TORCH_INTERNAL_ASSERT(lhs->is_nested_int());
10   std::optional<int64_t> c = rhs->nested_int();
11   return (
12       c.has_value() && lhs->nested_int() == *c &&
13       lhs->nested_int_coeff() == rhs->nested_int_coeff());
14 }
_ge(const char * op,c10::SymNodeImpl * lhs,c10::SymNodeImpl * rhs)15 bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
16   if (auto mb_si = lhs->nested_int()) {
17     if (auto mb_si2 = rhs->nested_int()) {
18       if (*mb_si == *mb_si2) {
19         return lhs->nested_int_coeff() >= rhs->nested_int_coeff();
20       }
21       TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
22     }
23     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
24     if (rhs->constant_int() && *rhs->constant_int() <= 2) {
25       return true;
26     }
27     TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
28   } else if (rhs->nested_int()) {
29     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
30     if (lhs->constant_int() && *lhs->constant_int() < 2) {
31       return false;
32     }
33     TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
34   }
35   TORCH_INTERNAL_ASSERT(false, "expect at least one nested int");
36 }
37 } // namespace
38 
// this == other, evaluated by _eq (which internally asserts that `this` is a
// nested int) and returned as a constant boolean node.
c10::SymNode NestedIntSymNodeImpl::eq(const c10::SymNode& other) {
  const bool same = _eq("eq", this, other.get());
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(same));
}
43 
// this != other: the negation of _eq, returned as a constant boolean node.
// "ne" is forwarded as the operator name.
c10::SymNode NestedIntSymNodeImpl::ne(const c10::SymNode& other) {
  const bool same = _eq("ne", this, other.get());
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(!same));
}
48 
// this >= other, decided by _ge (which raises when the relation is
// indeterminate) and returned as a constant boolean node.
c10::SymNode NestedIntSymNodeImpl::ge(const c10::SymNode& other) {
  const bool this_ge_other = _ge("ge", this, other.get());
  return SymNode(
      c10::make_intrusive<ConstantSymNodeImpl<bool>>(this_ge_other));
}
53 
// this > other, computed as !(other >= this): the operands are swapped
// before calling _ge. Returned as a constant boolean node.
c10::SymNode NestedIntSymNodeImpl::gt(const c10::SymNode& other) {
  const bool other_ge_this = _ge("gt", other.get(), this);
  return SymNode(
      c10::make_intrusive<ConstantSymNodeImpl<bool>>(!other_ge_this));
}
58 
// this < other, computed as !(this >= other) via _ge. Returned as a
// constant boolean node.
c10::SymNode NestedIntSymNodeImpl::lt(const c10::SymNode& other) {
  const bool this_ge_other = _ge("lt", this, other.get());
  return SymNode(
      c10::make_intrusive<ConstantSymNodeImpl<bool>>(!this_ge_other));
}
63 
// this <= other, computed as (other >= this): the operands are swapped
// before calling _ge. Returned as a constant boolean node.
c10::SymNode NestedIntSymNodeImpl::le(const c10::SymNode& other) {
  const bool other_ge_this = _ge("le", other.get(), this);
  return SymNode(
      c10::make_intrusive<ConstantSymNodeImpl<bool>>(other_ge_this));
}
68 
// Multiplies this nested int by a plain integer constant.
//
// The result keeps val_ and scales coeff_ by the constant.
//
// @param other  must be a constant int; nested ints and non-constant
//               operands are rejected.
// @return a new NestedIntSymNodeImpl with coefficient coeff_ * constant.
// Raises via TORCH_CHECK if `other` is a nested int, is not a constant int,
// or if the scaled coefficient would overflow int64_t.
// NOTE(review): zero/negative constants are accepted here; confirm whether
// coefficients are meant to stay positive.
c10::SymNode NestedIntSymNodeImpl::mul(const c10::SymNode& other) {
  TORCH_CHECK(
      !other->nested_int(), "nested int cannot be multiplied by nested int");
  const std::optional<int64_t> c = other->constant_int();
  TORCH_CHECK(
      c.has_value(), "nested int can only be multiplied by a constant int");
  const int64_t factor = *c;
  // NOTE(review): assumes coeff_ is int64_t, matching the constant's type.
  const int64_t coeff = coeff_;

  // Signed overflow is undefined behavior, so prove the product fits before
  // computing it. Every division below has a nonzero divisor and none of them
  // evaluates INT64_MIN / -1.
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
  bool overflows = false;
  if (coeff != 0 && factor != 0) {
    if (coeff > 0) {
      overflows = factor > 0 ? coeff > kMax / factor : factor < kMin / coeff;
    } else {
      overflows = factor > 0 ? coeff < kMin / factor : factor < kMax / coeff;
    }
  }
  TORCH_CHECK(
      !overflows,
      "nested int mul: coefficient overflow (",
      coeff,
      " * ",
      factor,
      ")");
  return SymNode(
      c10::make_intrusive<NestedIntSymNodeImpl>(val_, coeff * factor));
}
75 
clone()76 c10::SymNode NestedIntSymNodeImpl::clone() {
77   return SymNode(c10::make_intrusive<NestedIntSymNodeImpl>(val_, coeff_));
78 }
79 
80 } // namespace c10
81