# Owner(s): ["module: inductor"] from sympy import Symbol from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import sympy_subs class TestUtils(TestCase): def testSympySubs(self): # integer and nonnegetaive attributes are preserved. expr = Symbol("x") result = sympy_subs(expr, {expr: "y"}) self.assertEqual(result.name, "y") self.assertEqual(result.is_integer, None) self.assertEqual(result.is_nonnegative, None) expr = Symbol("x", integer=True, nonnegative=False) result = sympy_subs(expr, {expr: "y"}) self.assertEqual(result.name, "y") self.assertEqual(result.is_integer, True) self.assertEqual(result.is_nonnegative, False) # invalid replacement. expr = Symbol("x", integer=True) result = sympy_subs(expr, {Symbol("x"): Symbol("y")}) self.assertEqual(result.name, "x") # valid replacement since properties match. expr = Symbol("x", integer=True) result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")}) self.assertEqual(result.name, "y") # invalid replacement. expr = Symbol("x", integer=None) result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")}) self.assertEqual(result.name, "x") # replaced cant be string self.assertRaises(AssertionError, sympy_subs, expr, {"x": "y"}) # replaced can be an expression expr = Symbol("x") expr = abs(expr) self.assertEqual(expr.is_integer, None) self.assertEqual(expr.is_nonnegative, None) # replace abs(x) with y # propagte abs(x) sympy properties. result = sympy_subs(expr, {expr: Symbol("y")}) self.assertEqual(result.name, "y") self.assertEqual(result.is_integer, None) self.assertEqual(result.is_nonnegative, None) if __name__ == "__main__": run_tests()