1 #include <ATen/core/NestedIntSymNodeImpl.h>
2 #include <c10/core/SymNodeImpl.h>
3 #include <c10/util/Exception.h>
4
5 namespace c10 {
6
7 namespace {
_eq(const char * op,c10::SymNodeImpl * lhs,c10::SymNodeImpl * rhs)8 bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
9 TORCH_INTERNAL_ASSERT(lhs->is_nested_int());
10 std::optional<int64_t> c = rhs->nested_int();
11 return (
12 c.has_value() && lhs->nested_int() == *c &&
13 lhs->nested_int_coeff() == rhs->nested_int_coeff());
14 }
_ge(const char * op,c10::SymNodeImpl * lhs,c10::SymNodeImpl * rhs)15 bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
16 if (auto mb_si = lhs->nested_int()) {
17 if (auto mb_si2 = rhs->nested_int()) {
18 if (*mb_si == *mb_si2) {
19 return lhs->nested_int_coeff() >= rhs->nested_int_coeff();
20 }
21 TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
22 }
23 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
24 if (rhs->constant_int() && *rhs->constant_int() <= 2) {
25 return true;
26 }
27 TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
28 } else if (rhs->nested_int()) {
29 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
30 if (lhs->constant_int() && *lhs->constant_int() < 2) {
31 return false;
32 }
33 TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
34 }
35 TORCH_INTERNAL_ASSERT(false, "expect at least one nested int");
36 }
37 } // namespace
38
// `this == other`, decided eagerly by _eq() and returned as a constant
// boolean SymNode.
c10::SymNode NestedIntSymNodeImpl::eq(const c10::SymNode& other) {
  const bool is_equal = _eq("eq", this, other.get());
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(is_equal));
}
43
// `this != other`: the logical negation of the eager _eq() check, returned
// as a constant boolean SymNode.
c10::SymNode NestedIntSymNodeImpl::ne(const c10::SymNode& other) {
  const bool is_equal = _eq("ne", this, other.get());
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(!is_equal));
}
48
// `this >= other`, evaluated eagerly by _ge(); _ge() raises when the relation
// cannot be decided.
c10::SymNode NestedIntSymNodeImpl::ge(const c10::SymNode& other) {
  const bool this_ge_other = _ge("ge", this, other.get());
  return SymNode(
      c10::make_intrusive<ConstantSymNodeImpl<bool>>(this_ge_other));
}
53
// `this > other` is rewritten as `!(other >= this)`, so only _ge() is needed
// (called with the operands swapped).
c10::SymNode NestedIntSymNodeImpl::gt(const c10::SymNode& other) {
  const bool other_ge_this = _ge("gt", other.get(), this);
  return SymNode(
      c10::make_intrusive<ConstantSymNodeImpl<bool>>(!other_ge_this));
}
58
// `this < other` is rewritten as `!(this >= other)` on top of _ge().
c10::SymNode NestedIntSymNodeImpl::lt(const c10::SymNode& other) {
  const bool this_ge_other = _ge("lt", this, other.get());
  return SymNode(
      c10::make_intrusive<ConstantSymNodeImpl<bool>>(!this_ge_other));
}
63
// `this <= other` is rewritten as `other >= this`, i.e. _ge() with the
// operands swapped.
c10::SymNode NestedIntSymNodeImpl::le(const c10::SymNode& other) {
  const bool other_ge_this = _ge("le", other.get(), this);
  return SymNode(
      c10::make_intrusive<ConstantSymNodeImpl<bool>>(other_ge_this));
}
68
// Multiplies this nested int by a constant integer.
//
// The result keeps the same nested-int id (val_) and has coefficient
// coeff_ * other. Raises if `other` is itself a nested int, or if `other`
// does not report a constant_int().
//
// NOTE(review): the product coeff_ * c is not overflow-checked, and zero or
// negative factors are accepted unchanged - confirm callers rely on neither.
c10::SymNode NestedIntSymNodeImpl::mul(const c10::SymNode& other) {
  TORCH_CHECK(
      !other->nested_int(), "nested int cannot be multiplied by nested int");
  const std::optional<int64_t> c = other->constant_int();
  // A message-less TORCH_CHECK only reports the generic
  // "Expected c.has_value() to be true" text, which does not tell the user
  // which operand was rejected or why.
  TORCH_CHECK(
      c.has_value(), "nested int can only be multiplied by a constant int");
  return SymNode(c10::make_intrusive<NestedIntSymNodeImpl>(val_, coeff_ * *c));
}
75
clone()76 c10::SymNode NestedIntSymNodeImpl::clone() {
77 return SymNode(c10::make_intrusive<NestedIntSymNodeImpl>(val_, coeff_));
78 }
79
80 } // namespace c10
81