1# Owner(s): ["oncall: jit"] 2 3import unittest 4 5import numpy as np 6 7import torch 8from torch.testing import FileCheck 9from torch.testing._internal.common_utils import IS_MACOS 10from torch.testing._internal.jit_utils import JitTestCase 11 12 13if __name__ == "__main__": 14 raise RuntimeError( 15 "This test file is not meant to be run directly, use:\n\n" 16 "\tpython test/test_jit.py TESTNAME\n\n" 17 "instead." 18 ) 19 20 21class TestPythonIr(JitTestCase): 22 def test_param_strides(self): 23 def trace_me(arg): 24 return arg 25 26 t = torch.zeros(1, 3, 16, 16) 27 traced = torch.jit.trace(trace_me, t) 28 value = list(traced.graph.param_node().outputs())[0] 29 real_strides = list(t.stride()) 30 type_strides = value.type().strides() 31 self.assertEqual(real_strides, type_strides) 32 33 def test_permute_inputs_binding(self): 34 @torch.jit.script 35 def foo(i, j, k): 36 pass 37 38 g = foo.graph 39 40 idxs = [] 41 for i, inp in enumerate(g.inputs()): 42 inp.setDebugName(f"inp{i}") 43 idxs.append(i) 44 45 permuted_idxs = list(np.random.permutation(idxs)) 46 g.permuteInputs(permuted_idxs) 47 for i, inp in enumerate(g.inputs()): 48 self.assertEqual(f"inp{permuted_idxs[i]}", inp.debugName()) 49 50 @unittest.skipIf(IS_MACOS, "Failing on MacOS only") 51 def test_python_ir_utils(self): 52 @torch.jit.script 53 def foo(inp): 54 x = inp + 1 55 y = x / 2 56 z = y * y 57 return z 58 59 add_node = foo.graph.findNode("aten::add") 60 div_node = foo.graph.findNode("aten::div") 61 62 with foo.graph.insert_point_guard(add_node): 63 with foo.graph.insert_point_guard(div_node): 64 foo.graph.insertConstant("goodbye") 65 foo.graph.insertConstant("hello") 66 with foo.graph.insert_point_guard(foo.graph.findNode("aten::mul")): 67 foo.graph.insertConstant("hello") 68 FileCheck().check("hello").check("goodbye").check("hello").run(foo.graph) 69 70 self.assertTrue(add_node.matches(add_node.schema())) 71 self.assertFalse(add_node.matches(div_node.schema())) 72 73 def test_python_ir_utils_graph(self): 74 @torch.jit.script 75 def unrolled_mul(x: torch.Tensor, y: int): 76 out = x 77 for _ in range(y - 1): 78 out = out + x 79 return out 80 81 @torch.jit.script 82 def foo(x): 83 return x * 4 84 85 g = foo.graph 86 muls = g.findAllNodes("aten::mul") 87 scalar_muls = filter( 88 lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls 89 ) 90 mul_constant_int = filter( 91 lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls 92 ) 93 for mul in mul_constant_int: 94 with g.insert_point_guard(mul): 95 outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs())) 96 assert len(outputs) == len(list(mul.outputs())) 97 for new_out, old_out in zip(outputs, g.outputs()): 98 old_out.replaceAllUsesWith(new_out) 99 mul.destroy() 100 101 FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph) 102 self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4) 103