1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport copy 4*da0073e9SAndroid Build Coastguard Workerimport io 5*da0073e9SAndroid Build Coastguard Workerimport os 6*da0073e9SAndroid Build Coastguard Workerimport sys 7*da0073e9SAndroid Build Coastguard Workerimport unittest 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch 10*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 11*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F 12*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd import Function, Variable 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 17*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 18*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 19*da0073e9SAndroid Build Coastguard Workerimport warnings 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker# Standard library 22*da0073e9SAndroid Build Coastguard Workerfrom collections import namedtuple 23*da0073e9SAndroid Build Coastguard Workerfrom itertools import chain 24*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, List, Optional, Tuple 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor 27*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import with_tf32_off 28*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 29*da0073e9SAndroid Build Coastguard Worker enable_profiling_mode_for_profiling_tests, 30*da0073e9SAndroid Build 
Coastguard Worker IS_SANDCASTLE, 31*da0073e9SAndroid Build Coastguard Worker skipIfCompiledWithoutNumpy, 32*da0073e9SAndroid Build Coastguard Worker skipIfCrossRef, 33*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 34*da0073e9SAndroid Build Coastguard Worker suppress_warnings, 35*da0073e9SAndroid Build Coastguard Worker TemporaryFileName, 36*da0073e9SAndroid Build Coastguard Worker) 37*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import ( 38*da0073e9SAndroid Build Coastguard Worker _tmp_donotuse_dont_inline_everything, 39*da0073e9SAndroid Build Coastguard Worker _trace, 40*da0073e9SAndroid Build Coastguard Worker enable_cpu_fuser, 41*da0073e9SAndroid Build Coastguard Worker JitTestCase, 42*da0073e9SAndroid Build Coastguard Worker make_global, 43*da0073e9SAndroid Build Coastguard Worker RUN_CUDA, 44*da0073e9SAndroid Build Coastguard Worker RUN_CUDA_MULTI_GPU, 45*da0073e9SAndroid Build Coastguard Worker) 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 49*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 50*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 51*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 52*da0073e9SAndroid Build Coastguard Worker "instead." 
53*da0073e9SAndroid Build Coastguard Worker ) 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("Not a suitable test for TorchDynamo") 57*da0073e9SAndroid Build Coastguard Workerclass TestTracer(JitTestCase): 58*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "requires CUDA") 59*da0073e9SAndroid Build Coastguard Worker def test_large_nbr_kernel_args(self): 60*da0073e9SAndroid Build Coastguard Worker class Recurrence(nn.Module): 61*da0073e9SAndroid Build Coastguard Worker def __init__(self, seq_len): 62*da0073e9SAndroid Build Coastguard Worker super().__init__() 63*da0073e9SAndroid Build Coastguard Worker self.seq_len = seq_len 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 66*da0073e9SAndroid Build Coastguard Worker input = input.transpose(0, 1) 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker # Main loop 69*da0073e9SAndroid Build Coastguard Worker output = [] 70*da0073e9SAndroid Build Coastguard Worker for i in range(self.seq_len): 71*da0073e9SAndroid Build Coastguard Worker b = input[i] * 2 72*da0073e9SAndroid Build Coastguard Worker output.append(b) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker output = torch.cat(output, 0).view(input.size(0), *output[0].size()) 75*da0073e9SAndroid Build Coastguard Worker output = output.transpose(0, 1) 76*da0073e9SAndroid Build Coastguard Worker return output 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker input_size = 8 79*da0073e9SAndroid Build Coastguard Worker batch_size = 2 80*da0073e9SAndroid Build Coastguard Worker seq_len = 130 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker rec = Recurrence(seq_len) 83*da0073e9SAndroid Build Coastguard Worker input = torch.rand(batch_size, 
seq_len, input_size) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker torch.cuda.set_device(0) 86*da0073e9SAndroid Build Coastguard Worker rec = rec.cuda() 87*da0073e9SAndroid Build Coastguard Worker input = input.cuda() 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker traced_rec = torch.jit.trace(rec, (input)) 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker def test_trace_legacy_ctor(self): 92*da0073e9SAndroid Build Coastguard Worker class MyModule(nn.Module): 93*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 94*da0073e9SAndroid Build Coastguard Worker return (x + 1, torch.FloatTensor([0])) 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker traced_rec = torch.jit.trace(MyModule(), torch.randn(2, 2)) 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker def test_simple(self): 99*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.4], requires_grad=True) 100*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([0.7], requires_grad=True) 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker def f(x, y): 103*da0073e9SAndroid Build Coastguard Worker return torch.sigmoid(torch.tanh(x * (x + y))) 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker self.checkTrace(f, (x, y)) 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker def test_trace_checking_with_global_name(self): 108*da0073e9SAndroid Build Coastguard Worker class MyClass(torch.nn.Module): 109*da0073e9SAndroid Build Coastguard Worker def forward(self, xs: List[Tensor]): 110*da0073e9SAndroid Build Coastguard Worker y = torch.cat(xs, dim=0) 111*da0073e9SAndroid Build Coastguard Worker return y 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker model = 
MyClass() 114*da0073e9SAndroid Build Coastguard Worker # Simulate these inputs being in the globals, like they would be if, 115*da0073e9SAndroid Build Coastguard Worker # e.g. they were defined outermost scope of a script 116*da0073e9SAndroid Build Coastguard Worker global input1, input2 117*da0073e9SAndroid Build Coastguard Worker input1 = torch.ones(2, 2) 118*da0073e9SAndroid Build Coastguard Worker input2 = torch.ones(2, 2) 119*da0073e9SAndroid Build Coastguard Worker m2 = torch.jit.trace(model, ((input1, input2),)) 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker def test_trace_aliased_parameter(self): 122*da0073e9SAndroid Build Coastguard Worker class M(nn.Module): 123*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 124*da0073e9SAndroid Build Coastguard Worker super().__init__() 125*da0073e9SAndroid Build Coastguard Worker self.x = nn.Parameter(x) 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker def forward(self, y): 128*da0073e9SAndroid Build Coastguard Worker return self.x + y 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker m = M(torch.rand(3, 4)) 131*da0073e9SAndroid Build Coastguard Worker r = torch.jit.trace(m, m.x) 132*da0073e9SAndroid Build Coastguard Worker t2 = torch.rand(3, 4) 133*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r(t2), m.x + t2) 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker def test_trace_nested_fn(self): 136*da0073e9SAndroid Build Coastguard Worker class TracedInlineDecision(torch.nn.Module): 137*da0073e9SAndroid Build Coastguard Worker def forward(self, x, flag): 138*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 139*da0073e9SAndroid Build Coastguard Worker def make_decision(flag, x): 140*da0073e9SAndroid Build Coastguard Worker if flag: 141*da0073e9SAndroid Build Coastguard Worker return x 142*da0073e9SAndroid Build Coastguard 
Worker else: 143*da0073e9SAndroid Build Coastguard Worker return torch.zeros_like(x) 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker x = torch.neg(x) 146*da0073e9SAndroid Build Coastguard Worker return make_decision(flag, x) 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker decision = TracedInlineDecision() 149*da0073e9SAndroid Build Coastguard Worker torch.jit.trace( 150*da0073e9SAndroid Build Coastguard Worker decision, 151*da0073e9SAndroid Build Coastguard Worker (torch.rand(3, 4), torch.tensor([True], dtype=torch.bool)), 152*da0073e9SAndroid Build Coastguard Worker check_trace=True, 153*da0073e9SAndroid Build Coastguard Worker ) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker def test_trace_single_tuple(self): 156*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(2.0) 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker def f2(x): 159*da0073e9SAndroid Build Coastguard Worker return (x,) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker jit_f2 = torch.jit.trace(f2, x) 162*da0073e9SAndroid Build Coastguard Worker assert f2(x) == jit_f2(x) # fails 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker def test_trace_out_operator_with_two_output(self): 165*da0073e9SAndroid Build Coastguard Worker example_input = torch.rand(2, 8) 166*da0073e9SAndroid Build Coastguard Worker out_1, out_2 = torch.cummax(example_input, 1) 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker def run_cummax(example_input, out_1, out_2): 169*da0073e9SAndroid Build Coastguard Worker output_1, output_2 = torch.cummax(example_input, 1, out=(out_1, out_2)) 170*da0073e9SAndroid Build Coastguard Worker return output_1, output_2 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker 
trace_model = torch.jit.trace(run_cummax, (example_input, out_1, out_2)) 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker def test_trace_namedtuple(self): 175*da0073e9SAndroid Build Coastguard Worker Point = namedtuple("point", ["x", "y"]) 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker def f(p): 178*da0073e9SAndroid Build Coastguard Worker if type(p) is tuple: 179*da0073e9SAndroid Build Coastguard Worker p = Point(*p) 180*da0073e9SAndroid Build Coastguard Worker return p.x + p.y 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker p = Point(torch.randn(1), torch.randn(1)) 183*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(f, (p,)) 184*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(p), traced(p)) 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker def test_trace_topk(self): 187*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 188*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 189*da0073e9SAndroid Build Coastguard Worker return x.topk(y, dim=1)[1] 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker mod = M() 192*da0073e9SAndroid Build Coastguard Worker inputs = (torch.randint(0, 10, (20, 20)), torch.tensor(17)) 193*da0073e9SAndroid Build Coastguard Worker traced_func = torch.jit.trace(mod, inputs) 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker test_inputs = (torch.randint(0, 9, (9, 9)), torch.tensor(8)) 196*da0073e9SAndroid Build Coastguard Worker eager_out = mod(*test_inputs) 197*da0073e9SAndroid Build Coastguard Worker traced_out = traced_func(*test_inputs) 198*da0073e9SAndroid Build Coastguard Worker self.assertNotWarn( 199*da0073e9SAndroid Build Coastguard Worker lambda: traced_func(*test_inputs), 200*da0073e9SAndroid Build Coastguard Worker "Shouldn't throw 
slicing related warn here", 201*da0073e9SAndroid Build Coastguard Worker ) 202*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_out, traced_out) 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker test_inputs = (torch.randint(0, 50, (50, 50)), torch.tensor(12)) 205*da0073e9SAndroid Build Coastguard Worker eager_out = mod(*test_inputs) 206*da0073e9SAndroid Build Coastguard Worker traced_out = traced_func(*test_inputs) 207*da0073e9SAndroid Build Coastguard Worker self.assertNotWarn( 208*da0073e9SAndroid Build Coastguard Worker lambda: traced_func(*test_inputs), 209*da0073e9SAndroid Build Coastguard Worker "Shouldn't throw slicing related warn here", 210*da0073e9SAndroid Build Coastguard Worker ) 211*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_out, traced_out) 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker def test_typeas_trace_check(self): 214*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([0.4], requires_grad=True) 215*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([0.7], requires_grad=True) 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker def f(x, y): 218*da0073e9SAndroid Build Coastguard Worker return x.type_as(y) 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Worker trace = torch.jit.trace(f, (a, b)) 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker def test_trace_index(self): 223*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.4], requires_grad=True) 224*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([0], dtype=torch.int64) 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 227*da0073e9SAndroid Build Coastguard Worker return x[y] 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker fn_traced = 
torch.jit.trace( 230*da0073e9SAndroid Build Coastguard Worker fn, 231*da0073e9SAndroid Build Coastguard Worker ( 232*da0073e9SAndroid Build Coastguard Worker x, 233*da0073e9SAndroid Build Coastguard Worker y, 234*da0073e9SAndroid Build Coastguard Worker ), 235*da0073e9SAndroid Build Coastguard Worker ) 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x, y), fn_traced(x, y)) 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker # Backwards tracing was broken for indexing by a constant, 240*da0073e9SAndroid Build Coastguard Worker # because it's internally implemented using as_strided, 241*da0073e9SAndroid Build Coastguard Worker # and we attempted to trace its derivative (which is not 242*da0073e9SAndroid Build Coastguard Worker # currently supported.) It currently works because 243*da0073e9SAndroid Build Coastguard Worker # slice() is now not marked as traceable. 244*da0073e9SAndroid Build Coastguard Worker def test_trace_index_constant(self): 245*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.4], requires_grad=True) 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker def fn(x): 248*da0073e9SAndroid Build Coastguard Worker return x[0] 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker def run(f): 251*da0073e9SAndroid Build Coastguard Worker y = f(x) 252*da0073e9SAndroid Build Coastguard Worker grad = torch.autograd.grad(y, x)[0].clone() 253*da0073e9SAndroid Build Coastguard Worker return y, grad 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker traced_fn = torch.jit.trace(fn, torch.ones(1)) 256*da0073e9SAndroid Build Coastguard Worker self.assertEqual(run(fn), run(traced_fn)) 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker def test_index_put(self): 259*da0073e9SAndroid Build Coastguard Worker ten = 
torch.zeros(3, 3) 260*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor( 261*da0073e9SAndroid Build Coastguard Worker [[True, True, True], [True, False, False], [True, True, False]] 262*da0073e9SAndroid Build Coastguard Worker ) 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker def test_fn(ten, mask): 265*da0073e9SAndroid Build Coastguard Worker ten[mask] = torch.ones(6) 266*da0073e9SAndroid Build Coastguard Worker return ten 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker traced_test_fn = torch.jit.trace(test_fn, (ten, mask)) 269*da0073e9SAndroid Build Coastguard Worker 270*da0073e9SAndroid Build Coastguard Worker ten = torch.rand(3, 3) 271*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask)) 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker def test_canonicalize_tensor_iterator(self): 274*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 4) 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker def f(x): 277*da0073e9SAndroid Build Coastguard Worker x = x + 2 278*da0073e9SAndroid Build Coastguard Worker x = x - 4 279*da0073e9SAndroid Build Coastguard Worker x = x * 6 280*da0073e9SAndroid Build Coastguard Worker x = x / 8 281*da0073e9SAndroid Build Coastguard Worker return x 282*da0073e9SAndroid Build Coastguard Worker 283*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(f, (x,)) 284*da0073e9SAndroid Build Coastguard Worker f(x) 285*da0073e9SAndroid Build Coastguard Worker graph = traced.graph_for(x) 286*da0073e9SAndroid Build Coastguard Worker # There should be 4 int constants for the right sides of operators, plus one 287*da0073e9SAndroid Build Coastguard Worker # for the alpha argument for add and sub 288*da0073e9SAndroid Build Coastguard Worker self.assertTrue(str(traced.graph_for(x)).count(": int = prim::Constant") == 
5) 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 291*da0073e9SAndroid Build Coastguard Worker def test_constant(self): 292*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2, requires_grad=True) 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker def f(x): 295*da0073e9SAndroid Build Coastguard Worker return x.matmul(torch.diag(torch.tensor([2.0, 2.0]))) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),)) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker def test_wrapped_number(self): 300*da0073e9SAndroid Build Coastguard Worker # Scalar's get converted to 'wrapped' tensors of default tensor type. 301*da0073e9SAndroid Build Coastguard Worker # Wrapped tensors behave differently in certain promotion operations: 302*da0073e9SAndroid Build Coastguard Worker # float_tensor * double -> float but wrapped_float * double -> double. 303*da0073e9SAndroid Build Coastguard Worker # This can cause issues in check-trace if not handled correctly in 304*da0073e9SAndroid Build Coastguard Worker # `aten::isclose()`. 
305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker def foobar(): 307*da0073e9SAndroid Build Coastguard Worker x = -10000.0 308*da0073e9SAndroid Build Coastguard Worker result = x * torch.ones(1, dtype=torch.float) 309*da0073e9SAndroid Build Coastguard Worker return result 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.trace(foobar, (), check_trace=True) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker def test_inplace_transplant(self): 314*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.0], requires_grad=True) 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker def fn(x): 317*da0073e9SAndroid Build Coastguard Worker y = x.clone() 318*da0073e9SAndroid Build Coastguard Worker y.add_(2) 319*da0073e9SAndroid Build Coastguard Worker y.add_(3) 320*da0073e9SAndroid Build Coastguard Worker return y 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker g, _ = torch.jit._get_trace_graph(fn, (x,)) 323*da0073e9SAndroid Build Coastguard Worker self.run_pass("dce", g) 324*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::clone", 1, exactly=True).check_count( 325*da0073e9SAndroid Build Coastguard Worker "aten::add_", 2, exactly=True 326*da0073e9SAndroid Build Coastguard Worker ).check_next("return").run(str(g)) 327*da0073e9SAndroid Build Coastguard Worker self.assertExportImport(g, (x,)) 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker def test_inplace_flags(self): 330*da0073e9SAndroid Build Coastguard Worker class InplaceFn(Function): 331*da0073e9SAndroid Build Coastguard Worker @staticmethod 332*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 333*da0073e9SAndroid Build Coastguard Worker ctx.mark_dirty(x) 334*da0073e9SAndroid Build Coastguard Worker return x.add_(1) 
335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker @staticmethod 337*da0073e9SAndroid Build Coastguard Worker def backward(ctx, go): 338*da0073e9SAndroid Build Coastguard Worker return go 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker class RegularFn(Function): 341*da0073e9SAndroid Build Coastguard Worker @staticmethod 342*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 343*da0073e9SAndroid Build Coastguard Worker return x.add(1) 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Worker @staticmethod 346*da0073e9SAndroid Build Coastguard Worker def backward(ctx, go): 347*da0073e9SAndroid Build Coastguard Worker return go 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.0], requires_grad=True) 350*da0073e9SAndroid Build Coastguard Worker 351*da0073e9SAndroid Build Coastguard Worker def fn(x): 352*da0073e9SAndroid Build Coastguard Worker y = RegularFn.apply(x) 353*da0073e9SAndroid Build Coastguard Worker y = InplaceFn.apply(y) 354*da0073e9SAndroid Build Coastguard Worker y = InplaceFn.apply(y) 355*da0073e9SAndroid Build Coastguard Worker y = RegularFn.apply(y) 356*da0073e9SAndroid Build Coastguard Worker return y 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker trace_graph, _ = torch.jit._get_trace_graph(fn, (x,), _force_outplace=True) 359*da0073e9SAndroid Build Coastguard Worker self.run_pass("dce", trace_graph) 360*da0073e9SAndroid Build Coastguard Worker ops = list(trace_graph.nodes()) 361*da0073e9SAndroid Build Coastguard Worker for op in ops: 362*da0073e9SAndroid Build Coastguard Worker self.assertTrue(op.hasAttribute("inplace")) 363*da0073e9SAndroid Build Coastguard Worker inplace_flags = [False, True, True, False] 364*da0073e9SAndroid Build Coastguard Worker for op, is_inplace in zip(ops, inplace_flags): 365*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(op.i("inplace"), is_inplace) 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker def test_inplace_check(self): 368*da0073e9SAndroid Build Coastguard Worker class MyInplaceFn(Function): 369*da0073e9SAndroid Build Coastguard Worker @staticmethod 370*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 371*da0073e9SAndroid Build Coastguard Worker x.add_(1) 372*da0073e9SAndroid Build Coastguard Worker self.mark_dirty(x) 373*da0073e9SAndroid Build Coastguard Worker return x 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker @staticmethod 376*da0073e9SAndroid Build Coastguard Worker def backward(self, grad): 377*da0073e9SAndroid Build Coastguard Worker return grad 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker def fn(x): 380*da0073e9SAndroid Build Coastguard Worker return MyInplaceFn.apply(x) 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5) 383*da0073e9SAndroid Build Coastguard Worker ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False) 384*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "inplace MyInplaceFn"): 385*da0073e9SAndroid Build Coastguard Worker ge(x) 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker def test_force_outplace_check_fill(self): 388*da0073e9SAndroid Build Coastguard Worker def f(x): 389*da0073e9SAndroid Build Coastguard Worker return torch.empty(x.shape).fill_(7) 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 15) 392*da0073e9SAndroid Build Coastguard Worker ft = torch.jit.trace(f, x, _force_outplace=True) 393*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(x), ft(x)) 394*da0073e9SAndroid Build Coastguard Worker 395*da0073e9SAndroid Build Coastguard Worker def 
test_force_outplace_check_zero(self): 396*da0073e9SAndroid Build Coastguard Worker def f(x): 397*da0073e9SAndroid Build Coastguard Worker return torch.empty(x.shape).zero_() 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 15) 400*da0073e9SAndroid Build Coastguard Worker ft = torch.jit.trace(f, x, _force_outplace=True) 401*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(x), ft(x)) 402*da0073e9SAndroid Build Coastguard Worker 403*da0073e9SAndroid Build Coastguard Worker def do_trace_size(self, requires_grad): 404*da0073e9SAndroid Build Coastguard Worker def fn(x): 405*da0073e9SAndroid Build Coastguard Worker return x.view(x.shape[1] * 2, x.size(0), 2) 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 2, 4, requires_grad=requires_grad) 408*da0073e9SAndroid Build Coastguard Worker y = torch.randn(4, 8, 4, requires_grad=requires_grad) 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker # Check that it behaves as expected 411*da0073e9SAndroid Build Coastguard Worker traced_fn = torch.jit.trace(fn, x) 412*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_fn(y), fn(y)) 413*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_fn(x), fn(x)) 414*da0073e9SAndroid Build Coastguard Worker 415*da0073e9SAndroid Build Coastguard Worker def test_trace_size(self): 416*da0073e9SAndroid Build Coastguard Worker self.do_trace_size(False) 417*da0073e9SAndroid Build Coastguard Worker 418*da0073e9SAndroid Build Coastguard Worker # test the different graph_executor path that happens when 419*da0073e9SAndroid Build Coastguard Worker # gradients are required and sizes are involved 420*da0073e9SAndroid Build Coastguard Worker def test_trace_size_with_grad(self): 421*da0073e9SAndroid Build Coastguard Worker self.do_trace_size(True) 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid 
def test_trace_numel(self):
    """Check that a traced ``x.numel()`` follows the runtime input, not the example."""

    def fn(x):
        return x.numel()

    example, other = torch.randn(2, 3, 4), torch.randn(4, 5, 6)

    traced_fn = torch.jit.trace(fn, example)
    # Probe with the differently-shaped tensor first, then the example itself.
    for probe in (other, example):
        self.assertEqual(traced_fn(probe), fn(probe))

def do_trace_arange(self, requires_grad):
    """Trace several ``torch.arange`` variants and compare each with eager mode.

    Every variant is traced on a (5, 3, 2) example and then evaluated on a
    differently-shaped (8, 2, 4) input as well as on the example itself.

    Args:
        requires_grad: whether both inputs are created with ``requires_grad``.
    """

    def arange(x):
        return torch.arange(x.shape[0])

    def arange_scalar(x):
        return torch.arange(12)

    def arange_start_end(x):
        return torch.arange(start=x.shape[0], end=x.shape[0] + 5)

    x = torch.randn(5, 3, 2, requires_grad=requires_grad)
    y = torch.randn(8, 2, 4, requires_grad=requires_grad)

    # Check that each traced variant behaves as expected on both inputs.
    for eager_fn in (arange, arange_scalar, arange_start_end):
        traced_fn = torch.jit.trace(eager_fn, x)
        for probe in (y, x):
            self.assertEqual(traced_fn(probe), eager_fn(probe))

def test_trace_arange(self):
    """Run the ``arange`` tracing checks with inputs that do not require grad."""
    self.do_trace_arange(False)

def test_trace_arange_with_grad(self):
    """Run the ``arange`` tracing checks with inputs that require grad.

    This exercises the different graph_executor path taken when gradients
    are required and sizes are involved.
    """
    self.do_trace_arange(True)

def test_trace_full_dynamic_shape(self):
    """Check that a trace of ``torch.full(x.shape)`` doesn't store the shape as a constant."""

    def full_with_shape_like(x):
        return torch.full(x.shape, 2.0)

    x = torch.randn(3, 4)
    ge = torch.jit.trace(full_with_shape_like, example_inputs=x)
    y = torch.randn(2, 7)
    # The output shape must track whichever input is passed in.
    for probe in (y, x):
        self.assertEqual(ge(probe).shape, probe.shape)

def test_trace_slice_setitem_dynamic_shape(self):
    """Check that the trace of a slice assignment doesn't store shapes as constants.

    Regression test for https://github.com/pytorch/pytorch/issues/43548.
    """

    def slice_setitem(x, y):
        x[:, 2] = y + 1
        return x

    example = torch.randn(3, 4)
    traced = torch.jit.trace(slice_setitem, (example, example[:, 0]))
    bigger = torch.randn(10, 5)
    # slice_setitem mutates its first argument, so each call gets its own copy.
    traced_result = traced(bigger.clone(), bigger[:, 0])
    eager_result = slice_setitem(bigger.clone(), bigger[:, 0])
    self.assertEqual(traced_result, eager_result)

# Suppression: we are intentionally slicing a tensor, we don't care that it
# will be constantified
@suppress_warnings
def do_trace_slice(self, requires_grad):
    """Trace size-dependent slicing/selection and compare each with eager mode.

    Both functions are traced on a (5, 6, 7) example and evaluated on a
    (7, 8, 9) input as well as on the example itself.

    Args:
        requires_grad: whether both inputs are created with ``requires_grad``.
    """

    def multi_slice(x):
        return tuple(x[: x.size(0) - i, i : x.size(2), i:3] for i in range(4))

    def slice_select(x):
        return tuple(x[:, i:, x.size(2) - 5] for i in range(4))

    x = torch.randn(5, 6, 7, requires_grad=requires_grad)
    y = torch.randn(7, 8, 9, requires_grad=requires_grad)

    # Check that it behaves as expected on both inputs.
    for eager_fn in (multi_slice, slice_select):
        traced_fn = torch.jit.trace(eager_fn, x)
        for probe in (y, x):
            self.assertEqual(traced_fn(probe), eager_fn(probe))

def test_trace_slice(self):
    """Run the slicing tracing checks with inputs that do not require grad."""
    self.do_trace_slice(False)

def test_trace_slice_with_grad(self):
    """Run the slicing tracing checks with inputs that require grad.

    This exercises the different graph_executor path taken when gradients
    are required and sizes are involved.
    """
    self.do_trace_slice(True)

def test_trace_casts(self):
    """Check that each dtype/device conversion traces to exactly one ``aten::to``."""
    casts = [
        lambda x: x.byte(),
        lambda x: x.float(),
        lambda x: x.cpu(),
        lambda x: x.to(device="cpu"),
        lambda x: x.to(dtype=torch.int64),
        lambda x: x.to(device="cpu", dtype=torch.float),
        lambda x: x.to(x),
    ]

    def count_to_nodes(trace):
        # Number of aten::to nodes in the traced graph.
        return len([node for node in trace.graph.nodes() if node.kind() == "aten::to"])

    for cast in casts:
        trace = torch.jit.trace(cast, torch.randn(2, 2))
        self.assertEqual(count_to_nodes(trace), 1)
        x = torch.randn(2, 2)
        self.assertEqual(trace(x), cast(x))

    def to_tensor(x, y):
        return x.to(y)

    to_tensor_trace = torch.jit.trace(
        to_tensor, (torch.randn(2, 2), torch.randn(1, 8))
    )
    self.assertEqual(count_to_nodes(to_tensor_trace), 1)
    x, y = torch.randn(2, 2), torch.randn(1, 10)
    self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))
@skipIfCompiledWithoutNumpy
@skipIfCrossRef
def test_trace_warn(self):
    """Check that tracing emits a TracerWarning for each Python-value escape in ``fn``.

    The six operations marked "Warning N" in ``fn`` must each produce a
    TracerWarning, in order, whose message contains the matching fragment.
    """

    def fn(x):
        int(x)  # Warning 1.
        y = x * 1
        if y:  # Warning 2.
            pass
        q = [x, x * 4]
        z = q[y]
        float(z)  # Warning 3.
        z.tolist()  # Warning 4.
        z.numpy()  # Warning 5.
        for _ in torch.ones(4, 4):  # Warning 6.
            pass
        return z + 4

    with warnings.catch_warnings(record=True) as warns:
        # Only the recorded warnings matter; the traced function is discarded.
        torch.jit.trace(fn, torch.tensor([1]))
    for warn in warns:
        self.assertIs(warn.category, torch.jit.TracerWarning)
    messages = [str(w.message) for w in warns]
    expected_fragments = [
        "a Python integer",
        "a Python boolean",
        "a Python float",
        "a Python list",
        "a NumPy array",
        "Iterating over",
    ]
    # Fail with the recorded messages rather than an IndexError if one is missing.
    self.assertGreaterEqual(len(messages), len(expected_fragments), messages)
    for fragment, message in zip(expected_fragments, messages):
        self.assertIn(fragment, message)

def test_trace_tuple(self):
    """Check that a nested tuple output is traced as nested TupleConstructs."""

    def fn(x, y):
        return x, (x * y[1], x * y[0])

    x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
    traced_fn = torch.jit.trace(fn, (x, y))
    self.assertEqual(traced_fn(x, y), fn(x, y))
    # should be a tuple nested within another tuple
    FileCheck().check_count("prim::TupleConstruct", 2, exactly=True).check_next(
        "return"
    ).run(str(traced_fn.graph))
    self.assertExportImport(traced_fn.graph, (x, y))

def test_trace_random(self):
    """Check that a traced ``torch.normal`` draws the same values as eager mode."""

    def f(mean, std):
        return torch.normal(mean, std)

    # check_trace=False: re-running a random op would not reproduce its output.
    traced = torch.jit.trace(
        f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False
    )
    mean, std = torch.zeros(5, 5), torch.ones(5, 5)
    # fork_rng restores the CPU RNG state on exit, so the traced call below
    # starts from the same RNG state as the eager call inside the block.
    with torch.random.fork_rng(devices=[]):
        output = f(mean, std)
    traced_output = traced(mean, std)
    self.assertEqual(output, traced_output)

def test_trace_tensor_factory(self):
    """Check that ``torch.ones`` inside a traced function is recorded, not constant-folded."""

    def run(**kwargs):
        """Trace ``x + torch.ones(2, 3, **kwargs)`` and verify the graph.

        ``inputs_require_grads`` (default True) is consumed here and forwarded
        to ``checkTrace``; all other kwargs go to the ``torch.ones`` call.
        """
        inputs_require_grads = kwargs.pop("inputs_require_grads", True)

        def fn(x):
            return x + torch.ones(2, 3, **kwargs)

        # ``out=`` belongs only to the factory call inside fn, not to the input.
        input_kwargs = kwargs.copy()
        input_kwargs.pop("out", None)
        example = torch.ones(2, 3, **input_kwargs)
        self.checkTrace(fn, (example,), inputs_require_grads=inputs_require_grads)
        # check we recorded 'ones' and did not just record a constant
        tfn = torch.jit.trace(fn, example)
        self.assertIn("ones", str(tfn.graph))

    run()
    run(dtype=torch.int, inputs_require_grads=False)
    run(out=torch.tensor([]))
    if RUN_CUDA:
        run(device="cuda:0")
    if RUN_CUDA_MULTI_GPU:
        run(device="cuda:1")

def test_trace_indexed_assignment(self):
    """Check tracing of an indexed assignment on a cloned tensor."""

    def stuff(x, y):
        x = x.clone()
        x[0] = y
        return x

    example = torch.rand(3, 4)
    self.checkTrace(stuff, (example, example[0] + 1))

# TODO: implement
@unittest.expectedFailure
def test_output_unflatten(self):
    """Check that outputs of traced functions retain the original structure and nesting"""

    def fn(x):
        return (
            x * 2,
            (
                x**2,
                x + 4,
                (x + 2,),
            ),
            x * 4,
        )

    self.checkTrace(fn, (torch.randn(2, 2),))

def test_input_flatten(self):
    """Check that inputs to traced functions are flattened"""

    def fn(x, t):
        y, z = t
        return x * y * z

    inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
    self.checkTrace(fn, inputs)

def test_input_dict_empty(self):
    """Check that passing an empty dict as the trace inputs raises RuntimeError."""

    def test(d):
        pass

    with self.assertRaises(RuntimeError):
        self.checkTrace(test, {})

def test_input_dict_remembers_keys(self):
    """Check that the trace remembers which keys were in a dict input"""

    class TestModule(torch.nn.Module):
        def forward(self, dict_input):
            return dict_input["x"]

    input_1 = {"x": torch.tensor(1)}
    m = TestModule()
    m_traced = torch.jit.trace(m, (input_1,))
    self.assertEqual(m_traced(input_1), torch.tensor(1))

    # should work to change the values and not the keys
    input_same_key_different_value = {"x": torch.tensor(2)}
    self.assertEqual(m_traced(input_same_key_different_value), torch.tensor(2))

    # error to use something that doesn't have `x`
    input_different_key = {"y": torch.tensor(3)}
    with self.assertRaises(RuntimeError):
        m_traced(input_different_key)

    # it's okay to have additional elements in the dictionary, so long as 'x' is there
    input_additional_key = {"x": torch.tensor(4), "y": torch.tensor(3)}
    self.assertEqual(m_traced(input_additional_key), torch.tensor(4))

def test_input_dict_insertion_order(self):
    """Check that dictionary access doesn't care about insertion order"""

    class TestModule(torch.nn.Module):
        def forward(self, dict_input):
            return dict_input["x"], dict_input["y"]

    input_x_then_y = {}
    input_x_then_y["x"] = torch.tensor(1)
    input_x_then_y["y"] = torch.tensor(2)

    m = TestModule()
    m_traced = torch.jit.trace(m, (input_x_then_y,))

    self.assertEqual(m_traced(input_x_then_y), (torch.tensor(1), torch.tensor(2)))

    input_y_then_x = {}
    input_y_then_x["y"] = torch.tensor(4)
    input_y_then_x["x"] = torch.tensor(3)

    self.assertEqual(m_traced(input_y_then_x), (torch.tensor(3), torch.tensor(4)))

def test_input_dict_recursive(self):
    """Check that a dict nested inside a dict input can be traced and re-fed."""

    class TestModule(torch.nn.Module):
        def forward(self, dict_input):
            return dict_input["x"][1]

    input_1 = {"x": {1: torch.tensor(1)}}
    m = TestModule()
    m_traced = torch.jit.trace(m, (input_1,))

    input_2 = {"x": {1: torch.tensor(2)}}
    self.assertEqual(m_traced(input_2), torch.tensor(2))

def test_input_dict_checkTrace_mut(self):
    """Check tracing a function that mutates a dict-held tensor in place."""

    def test(d):
        d["x"].tanh_()
        return d["x"]

    inputs = {"x": torch.rand(3, 4), "y": torch.rand(3, 4)}
    # Grads are off, presumably because tanh_ mutates the input tensor in place.
    self.checkTrace(test, (inputs,), inputs_require_grads=False)

def test_input_dict_unify(self):
    """Check tracing a dict input whose values have different dtypes (int32/float32)."""

    def test(d):
        return d["int"], d["float"]

    inputs = {
        "int": torch.ones((2, 2), dtype=torch.int32),
        "float": torch.ones((2, 2), dtype=torch.float32),
    }
    self.checkTrace(test, (inputs,), inputs_require_grads=False)

def test_input_tuple_of_dicts(self):
    """Check tracing a tuple of nested dicts; only the first element is used."""

    def test(t):
        d = t[0]
        return d["x"]["y"]

    inputs = {"x": {"y": torch.rand(2, 3)}}
    self.checkTrace(test, ((inputs, inputs),), allow_unused=True)

def test_input_dict_of_dicts(self):
    """Check tracing a dict of dicts where one nested dict is never read."""

    def test(d):
        return d["x"]["y"]

    nested_input = {"y": torch.rand(2, 3)}
    # Extra nested dict with a differently-shaped tensor; judging by the key
    # name it is there to force type unification — TODO confirm.
    unified_nested = {"y": torch.rand(3, 2)}
    inputs = {"x": nested_input, "force_unify": unified_nested}
    self.checkTrace(test, (inputs,), allow_unused=True)

def test_input_dict_of_lists(self):
    """Check tracing a dict whose value is a list of tensors."""

    def test(d):
        return d["x"][0]

    inputs = {"x": [torch.rand(3, 2)]}
    self.checkTrace(test, (inputs,))
def test_input_list_toplevel_flatten(self):
    """Check that a top-level list of example inputs is spread over the positional args."""

    def test(t1, t2):
        return torch.add(t1, t2)

    inputs = [torch.ones(2, 2), torch.rand(2, 2)]
    self.checkTrace(test, inputs)

def test_input_list_toplevel_flatten_direct(self):
    """Same as above, but tracing an ``nn.Module`` directly; only checks tracing succeeds."""

    class Test(torch.nn.Module):
        def forward(self, t1, t2):
            return torch.add(t1, t2)

    inputs = [torch.ones(2, 2), torch.rand(2, 2)]
    torch.jit.trace(Test(), inputs)

def test_input_list_of_tuples(self):
    """Check tracing a list-of-tuples input."""

    def test(l):
        return l[0][0]

    inputs = [(torch.ones(2, 2),)]
    self.checkTrace(test, (inputs,))

def test_input_dict_empty_list(self):
    """Check that a dict holding an empty list is rejected with a "List trace" error."""

    def test(d):
        pass

    inputs = {1: []}
    with self.assertRaisesRegex(RuntimeError, "List trace"):
        self.checkTrace(test, (inputs,))

def test_input_list_mixed_type(self):
    """Check that a list mixing a tensor and a tuple is rejected as not "consistent"."""

    def test(d):
        pass

    inputs = [torch.rand(2, 3), (torch.ones(2), torch.ones(2))]
    with self.assertRaisesRegex(RuntimeError, "consistent"):
        self.checkTrace(test, (inputs,))

def test_conv(self):
    """Check that a traced Conv2d graph reproduces the eager outputs."""
    x = torch.ones(20, 16, 50, 40)
    g, outputs, inputs = torch.jit._get_trace_graph(
        nn.Conv2d(16, 13, 3, bias=False), x, return_inputs=True
    )
    m = self.createFunctionFromGraph(g)
    self.assertEqual(outputs, m(*inputs))

def test_max_pool(self):
    """Check that ``F.max_pool2d`` is recorded as ``aten::max_pool2d`` and matches eager."""
    x = torch.rand(20, 16, 10, 10)

    def max_pool2d(x):
        return F.max_pool2d(x, 2) + 2

    # Pass a real 1-tuple; the previous ``(x)`` was just a parenthesized tensor.
    trace = torch.jit.trace(max_pool2d, (x,))
    graph = trace.graph_for(x)
    FileCheck().check("aten::max_pool2d(").run(graph)
    self.assertEqual(max_pool2d(x), trace(x))
840*da0073e9SAndroid Build Coastguard Worker def test_nested_inplace(self): 841*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 842*da0073e9SAndroid Build Coastguard Worker g, outputs, inputs = torch.jit._get_trace_graph( 843*da0073e9SAndroid Build Coastguard Worker lambda x: F.threshold(x, 0, 0, inplace=True), (x,), return_inputs=True 844*da0073e9SAndroid Build Coastguard Worker ) 845*da0073e9SAndroid Build Coastguard Worker m = self.createFunctionFromGraph(g) 846*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outputs, m(*inputs)) 847*da0073e9SAndroid Build Coastguard Worker FileCheck().check("threshold_").run(str(g)) 848*da0073e9SAndroid Build Coastguard Worker self.assertExportImport(g, (x,)) 849*da0073e9SAndroid Build Coastguard Worker 850*da0073e9SAndroid Build Coastguard Worker def test_repeated_input(self): 851*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 852*da0073e9SAndroid Build Coastguard Worker return a + b 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2) 855*da0073e9SAndroid Build Coastguard Worker inputs = set(ge.graph.inputs()) 856*da0073e9SAndroid Build Coastguard Worker # three instead of 2 because the export/import in checkTrace adds a 857*da0073e9SAndroid Build Coastguard Worker # `self` module argument 858*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(inputs) == 3) 859*da0073e9SAndroid Build Coastguard Worker 860*da0073e9SAndroid Build Coastguard Worker def test_repeated_output(self): 861*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 862*da0073e9SAndroid Build Coastguard Worker z = a + b 863*da0073e9SAndroid Build Coastguard Worker return z, z 864*da0073e9SAndroid Build Coastguard Worker 865*da0073e9SAndroid Build Coastguard Worker ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)]) 866*da0073e9SAndroid Build Coastguard Worker tuple_output = list(ge.graph.outputs())[0] 
867*da0073e9SAndroid Build Coastguard Worker tuple_inputs = list(tuple_output.node().inputs()) 868*da0073e9SAndroid Build Coastguard Worker self.assertTrue(tuple_inputs[0] == tuple_inputs[1]) 869*da0073e9SAndroid Build Coastguard Worker 870*da0073e9SAndroid Build Coastguard Worker def test_inplace_copy(self): 871*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 4, requires_grad=True) 872*da0073e9SAndroid Build Coastguard Worker 873*da0073e9SAndroid Build Coastguard Worker def f(x): 874*da0073e9SAndroid Build Coastguard Worker out = torch.zeros(x.size()) 875*da0073e9SAndroid Build Coastguard Worker out.copy_(x) 876*da0073e9SAndroid Build Coastguard Worker return out 877*da0073e9SAndroid Build Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker g, outputs, inputs = torch.jit._get_trace_graph(f, (x,), return_inputs=True) 879*da0073e9SAndroid Build Coastguard Worker self.run_pass("dce", g) 880*da0073e9SAndroid Build Coastguard Worker m = self.createFunctionFromGraph(g) 881*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outputs, m(*inputs)) 882*da0073e9SAndroid Build Coastguard Worker self.assertExportImport(g, (x,)) 883*da0073e9SAndroid Build Coastguard Worker 884*da0073e9SAndroid Build Coastguard Worker def test_inplace_copy_force_outplace(self): 885*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 4, requires_grad=True) 886*da0073e9SAndroid Build Coastguard Worker 887*da0073e9SAndroid Build Coastguard Worker def f(x): 888*da0073e9SAndroid Build Coastguard Worker out = torch.zeros(x.size()) 889*da0073e9SAndroid Build Coastguard Worker out.copy_(x) 890*da0073e9SAndroid Build Coastguard Worker return out 891*da0073e9SAndroid Build Coastguard Worker 892*da0073e9SAndroid Build Coastguard Worker g, outputs, inputs = torch.jit._get_trace_graph( 893*da0073e9SAndroid Build Coastguard Worker f, (x,), return_inputs=True, _force_outplace=True 894*da0073e9SAndroid Build Coastguard Worker ) 895*da0073e9SAndroid Build Coastguard Worker 
self.run_pass("dce", g) 896*da0073e9SAndroid Build Coastguard Worker m = self.createFunctionFromGraph(g) 897*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outputs, m(*inputs)) 898*da0073e9SAndroid Build Coastguard Worker self.assertExportImport(g, (x,)) 899*da0073e9SAndroid Build Coastguard Worker FileCheck().check("expand_as").run(str(g)) 900*da0073e9SAndroid Build Coastguard Worker 901*da0073e9SAndroid Build Coastguard Worker def test_shared_param(self): 902*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 903*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 904*da0073e9SAndroid Build Coastguard Worker super().__init__() 905*da0073e9SAndroid Build Coastguard Worker self.b = self.a = nn.Parameter(torch.randn(2, 2)) 906*da0073e9SAndroid Build Coastguard Worker 907*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 908*da0073e9SAndroid Build Coastguard Worker return x * self.a + self.b 909*da0073e9SAndroid Build Coastguard Worker 910*da0073e9SAndroid Build Coastguard Worker m = MyModule() 911*da0073e9SAndroid Build Coastguard Worker g, _ = torch.jit._get_trace_graph(m, (torch.randn(2, 2),)) 912*da0073e9SAndroid Build Coastguard Worker self.run_pass("dce", g) 913*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(g.inputs())), 2) 914*da0073e9SAndroid Build Coastguard Worker FileCheck().check("mul").check("add").run(str(g)) 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Worker def run_ge_tests(self, optimize, use_cuda): 917*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 918*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(optimize): 919*da0073e9SAndroid Build Coastguard Worker 920*da0073e9SAndroid Build Coastguard Worker def rand(*args): 921*da0073e9SAndroid Build Coastguard Worker t = torch.rand(*args).float() 922*da0073e9SAndroid Build Coastguard Worker if use_cuda: 
923*da0073e9SAndroid Build Coastguard Worker t = t.cuda() 924*da0073e9SAndroid Build Coastguard Worker return t 925*da0073e9SAndroid Build Coastguard Worker 926*da0073e9SAndroid Build Coastguard Worker self.checkTrace( 927*da0073e9SAndroid Build Coastguard Worker lambda a, b: a * b + b, [rand(1), rand(1)], [rand(2, 3), rand(2, 3)] 928*da0073e9SAndroid Build Coastguard Worker ) 929*da0073e9SAndroid Build Coastguard Worker # trivial identity 930*da0073e9SAndroid Build Coastguard Worker self.checkTrace(lambda a, b: (b, a), [rand(1), rand(1)]) 931*da0073e9SAndroid Build Coastguard Worker 932*da0073e9SAndroid Build Coastguard Worker def foo(a): 933*da0073e9SAndroid Build Coastguard Worker t = a * a 934*da0073e9SAndroid Build Coastguard Worker return t * t, 4 * t 935*da0073e9SAndroid Build Coastguard Worker 936*da0073e9SAndroid Build Coastguard Worker self.checkTrace(foo, [rand(1)]) 937*da0073e9SAndroid Build Coastguard Worker # unused input 938*da0073e9SAndroid Build Coastguard Worker self.checkTrace( 939*da0073e9SAndroid Build Coastguard Worker lambda a, b: a * a, [rand(1), rand(1)], allow_unused=True 940*da0073e9SAndroid Build Coastguard Worker ) 941*da0073e9SAndroid Build Coastguard Worker # test outputs that do not get used in grad 942*da0073e9SAndroid Build Coastguard Worker self.checkTrace(foo, [rand(1)], drop=1) 943*da0073e9SAndroid Build Coastguard Worker # test autograd fallback 944*da0073e9SAndroid Build Coastguard Worker self.checkTrace( 945*da0073e9SAndroid Build Coastguard Worker lambda a, b: a * b / (a - 2 * b) + b, [rand(1), rand(1)] 946*da0073e9SAndroid Build Coastguard Worker ) 947*da0073e9SAndroid Build Coastguard Worker 948*da0073e9SAndroid Build Coastguard Worker def test_ge_unoptimized(self): 949*da0073e9SAndroid Build Coastguard Worker self.run_ge_tests(False, False) 950*da0073e9SAndroid Build Coastguard Worker 951*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") 952*da0073e9SAndroid 
Build Coastguard Worker @enable_cpu_fuser 953*da0073e9SAndroid Build Coastguard Worker def test_ge_optimized(self): 954*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 955*da0073e9SAndroid Build Coastguard Worker self.run_ge_tests(True, False) 956*da0073e9SAndroid Build Coastguard Worker 957*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "requires CUDA") 958*da0073e9SAndroid Build Coastguard Worker def test_ge_cuda(self): 959*da0073e9SAndroid Build Coastguard Worker self.run_ge_tests(True, True) 960*da0073e9SAndroid Build Coastguard Worker 961*da0073e9SAndroid Build Coastguard Worker # more manual test of graph executor that can be used as a scratchpad 962*da0073e9SAndroid Build Coastguard Worker def test_ge(self): 963*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 964*da0073e9SAndroid Build Coastguard Worker return a * b / (a - b) + b 965*da0073e9SAndroid Build Coastguard Worker 966*da0073e9SAndroid Build Coastguard Worker V = Variable 967*da0073e9SAndroid Build Coastguard Worker a, b = V(torch.rand(1)), V(torch.rand(1)) 968*da0073e9SAndroid Build Coastguard Worker ge = torch.jit.trace(foo, (a, b)) 969*da0073e9SAndroid Build Coastguard Worker a, b = V(torch.rand(1), requires_grad=True), V( 970*da0073e9SAndroid Build Coastguard Worker torch.rand(1), requires_grad=True 971*da0073e9SAndroid Build Coastguard Worker ) 972*da0073e9SAndroid Build Coastguard Worker (r,) = ge(a, b) 973*da0073e9SAndroid Build Coastguard Worker da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True) 974*da0073e9SAndroid Build Coastguard Worker 975*da0073e9SAndroid Build Coastguard Worker l2 = da * db + db * db 976*da0073e9SAndroid Build Coastguard Worker g2result = torch.autograd.grad(l2, [da, db]) 977*da0073e9SAndroid Build Coastguard Worker 978*da0073e9SAndroid Build Coastguard Worker r = foo(a, b) 979*da0073e9SAndroid Build Coastguard Worker da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True) 
980*da0073e9SAndroid Build Coastguard Worker self.assertEqual(da, da2) 981*da0073e9SAndroid Build Coastguard Worker self.assertEqual(db, db2) 982*da0073e9SAndroid Build Coastguard Worker l3 = da2 * db2 + db2 * db2 983*da0073e9SAndroid Build Coastguard Worker g2result2 = torch.autograd.grad(l3, [da2, db2]) 984*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g2result, g2result2) 985*da0073e9SAndroid Build Coastguard Worker 986*da0073e9SAndroid Build Coastguard Worker def test_trace_annotation(self): 987*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(1)) 988*da0073e9SAndroid Build Coastguard Worker def foo(a): 989*da0073e9SAndroid Build Coastguard Worker return a + a + a 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5) 992*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(x), x + x + x) 993*da0073e9SAndroid Build Coastguard Worker 994*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "calls .cuda()") 995*da0073e9SAndroid Build Coastguard Worker # By default, on Ampere or later GPUs, nn.Linear computes float tensors at TF32 precision. 
996*da0073e9SAndroid Build Coastguard Worker # We want float tensors to be computed at full precision in order to use the default precision 997*da0073e9SAndroid Build Coastguard Worker @with_tf32_off 998*da0073e9SAndroid Build Coastguard Worker def test_traced_module_cuda(self): 999*da0073e9SAndroid Build Coastguard Worker class Model(nn.Module): 1000*da0073e9SAndroid Build Coastguard Worker def __init__(self, num_features, num_layers): 1001*da0073e9SAndroid Build Coastguard Worker super().__init__() 1002*da0073e9SAndroid Build Coastguard Worker self.num_layers = num_layers 1003*da0073e9SAndroid Build Coastguard Worker layers = [ 1004*da0073e9SAndroid Build Coastguard Worker [nn.Linear(num_features, num_features), nn.Sigmoid()] 1005*da0073e9SAndroid Build Coastguard Worker for _ in range(num_layers) 1006*da0073e9SAndroid Build Coastguard Worker ] 1007*da0073e9SAndroid Build Coastguard Worker self.submodule = nn.Sequential(*chain(*layers)) 1008*da0073e9SAndroid Build Coastguard Worker 1009*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1010*da0073e9SAndroid Build Coastguard Worker for i in range(self.num_layers): 1011*da0073e9SAndroid Build Coastguard Worker x = self.submodule[i](x) + x 1012*da0073e9SAndroid Build Coastguard Worker return x 1013*da0073e9SAndroid Build Coastguard Worker 1014*da0073e9SAndroid Build Coastguard Worker model = Model(5, 3) 1015*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 5) 1016*da0073e9SAndroid Build Coastguard Worker traced_model = torch.jit.trace(model, x) 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard Worker # We're missing some attributes these modules had initially. 
Make sure we can 1019*da0073e9SAndroid Build Coastguard Worker # still get the __repr__() 1020*da0073e9SAndroid Build Coastguard Worker model.__repr__() 1021*da0073e9SAndroid Build Coastguard Worker 1022*da0073e9SAndroid Build Coastguard Worker # XXX: indexing sequentials is broken 1023*da0073e9SAndroid Build Coastguard Worker linear_submodule = next(iter(traced_model.submodule._modules.values())) 1024*da0073e9SAndroid Build Coastguard Worker 1025*da0073e9SAndroid Build Coastguard Worker # All attributes that aren't parameters should raise 1026*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AttributeError): 1027*da0073e9SAndroid Build Coastguard Worker linear_submodule.in_features 1028*da0073e9SAndroid Build Coastguard Worker linear_submodule.weight 1029*da0073e9SAndroid Build Coastguard Worker linear_submodule.weight = nn.Parameter( 1030*da0073e9SAndroid Build Coastguard Worker torch.randn(linear_submodule.weight.shape) 1031*da0073e9SAndroid Build Coastguard Worker ) 1032*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 1033*da0073e9SAndroid Build Coastguard Worker del linear_submodule.weight 1034*da0073e9SAndroid Build Coastguard Worker 1035*da0073e9SAndroid Build Coastguard Worker # Submodules can't be called 1036*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 1037*da0073e9SAndroid Build Coastguard Worker linear_submodule(x) 1038*da0073e9SAndroid Build Coastguard Worker 1039*da0073e9SAndroid Build Coastguard Worker # Type casts 1040*da0073e9SAndroid Build Coastguard Worker linear_submodule.cuda() 1041*da0073e9SAndroid Build Coastguard Worker traced_model.float().cuda() 1042*da0073e9SAndroid Build Coastguard Worker cuda_out = traced_model(x.float().cuda()) 1043*da0073e9SAndroid Build Coastguard Worker traced_model.cpu() 1044*da0073e9SAndroid Build Coastguard Worker cpu_out = traced_model(x.float()) 1045*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cpu_out, cuda_out) 
1046*da0073e9SAndroid Build Coastguard Worker traced_model.to("cuda") 1047*da0073e9SAndroid Build Coastguard Worker cuda_out = traced_model(x.float().cuda()) 1048*da0073e9SAndroid Build Coastguard Worker traced_model.to("cpu") 1049*da0073e9SAndroid Build Coastguard Worker cpu_out = traced_model(x.float()) 1050*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cpu_out, cuda_out) 1051*da0073e9SAndroid Build Coastguard Worker traced_model.to(torch.get_default_dtype()) 1052*da0073e9SAndroid Build Coastguard Worker 1053*da0073e9SAndroid Build Coastguard Worker # state_dict + load_state_dict 1054*da0073e9SAndroid Build Coastguard Worker state = {k: v.clone() for k, v in traced_model.state_dict().items()} 1055*da0073e9SAndroid Build Coastguard Worker new_state = {k: v.clone().fill_(1) for k, v in state.items()} 1056*da0073e9SAndroid Build Coastguard Worker out = traced_model(x) 1057*da0073e9SAndroid Build Coastguard Worker traced_model.load_state_dict(new_state) 1058*da0073e9SAndroid Build Coastguard Worker out_ones = traced_model(x) 1059*da0073e9SAndroid Build Coastguard Worker traced_model.load_state_dict(state) 1060*da0073e9SAndroid Build Coastguard Worker out_state = traced_model(x) 1061*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, out_state) 1062*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(out, out_ones) 1063*da0073e9SAndroid Build Coastguard Worker 1064*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "uses cuda") 1065*da0073e9SAndroid Build Coastguard Worker def test_type_same_device(self): 1066*da0073e9SAndroid Build Coastguard Worker class Model(torch.nn.Module): 1067*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1068*da0073e9SAndroid Build Coastguard Worker super().__init__() 1069*da0073e9SAndroid Build Coastguard Worker self.dtype = torch.float16 1070*da0073e9SAndroid Build Coastguard Worker 1071*da0073e9SAndroid Build Coastguard Worker def forward(self, x=None): 
1072*da0073e9SAndroid Build Coastguard Worker h = x.type(self.dtype) 1073*da0073e9SAndroid Build Coastguard Worker return h 1074*da0073e9SAndroid Build Coastguard Worker 1075*da0073e9SAndroid Build Coastguard Worker a = Model() 1076*da0073e9SAndroid Build Coastguard Worker b = torch.jit.trace( 1077*da0073e9SAndroid Build Coastguard Worker a, example_inputs=(torch.ones([1], device=torch.device("cuda")),) 1078*da0073e9SAndroid Build Coastguard Worker ) 1079*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("device").run(b.code) 1080*da0073e9SAndroid Build Coastguard Worker 1081*da0073e9SAndroid Build Coastguard Worker def test_export_no_reorder(self): 1082*da0073e9SAndroid Build Coastguard Worker def func(a, b): 1083*da0073e9SAndroid Build Coastguard Worker return a * b / (a - 2 * b) + b 1084*da0073e9SAndroid Build Coastguard Worker 1085*da0073e9SAndroid Build Coastguard Worker recording_inputs = [ 1086*da0073e9SAndroid Build Coastguard Worker torch.tensor( 1087*da0073e9SAndroid Build Coastguard Worker [0.55619788169860839844], dtype=torch.float32, requires_grad=True 1088*da0073e9SAndroid Build Coastguard Worker ), 1089*da0073e9SAndroid Build Coastguard Worker torch.tensor( 1090*da0073e9SAndroid Build Coastguard Worker [0.25947844982147216797], dtype=torch.float32, requires_grad=True 1091*da0073e9SAndroid Build Coastguard Worker ), 1092*da0073e9SAndroid Build Coastguard Worker ] 1093*da0073e9SAndroid Build Coastguard Worker 1094*da0073e9SAndroid Build Coastguard Worker ge1 = torch.jit.trace(func, recording_inputs) 1095*da0073e9SAndroid Build Coastguard Worker ge2 = self.getExportImportCopy(ge1) 1096*da0073e9SAndroid Build Coastguard Worker 1097*da0073e9SAndroid Build Coastguard Worker outputs_ge1 = ge1(*recording_inputs) 1098*da0073e9SAndroid Build Coastguard Worker outputs_ge2 = ge2(*recording_inputs) 1099*da0073e9SAndroid Build Coastguard Worker 1100*da0073e9SAndroid Build Coastguard Worker grad_ge1 = torch.autograd.grad(outputs_ge1, recording_inputs) 
1101*da0073e9SAndroid Build Coastguard Worker grad_ge2 = torch.autograd.grad(outputs_ge2, recording_inputs) 1102*da0073e9SAndroid Build Coastguard Worker self.assertTrue(outputs_ge1 == outputs_ge2) 1103*da0073e9SAndroid Build Coastguard Worker self.assertTrue(grad_ge1 == grad_ge2) 1104*da0073e9SAndroid Build Coastguard Worker 1105*da0073e9SAndroid Build Coastguard Worker def test_python_function(self): 1106*da0073e9SAndroid Build Coastguard Worker class MyFn(Function): 1107*da0073e9SAndroid Build Coastguard Worker @staticmethod 1108*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 1109*da0073e9SAndroid Build Coastguard Worker return x + 1 1110*da0073e9SAndroid Build Coastguard Worker 1111*da0073e9SAndroid Build Coastguard Worker @staticmethod 1112*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad_output): 1113*da0073e9SAndroid Build Coastguard Worker return grad_output 1114*da0073e9SAndroid Build Coastguard Worker 1115*da0073e9SAndroid Build Coastguard Worker @_trace(torch.zeros(2)) 1116*da0073e9SAndroid Build Coastguard Worker def fn(x): 1117*da0073e9SAndroid Build Coastguard Worker return MyFn.apply(x + 2) + 3 1118*da0073e9SAndroid Build Coastguard Worker 1119*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1.0, 2.0, 3.0]) 1120*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 2, requires_grad=True) 1121*da0073e9SAndroid Build Coastguard Worker fn(x) 1122*da0073e9SAndroid Build Coastguard Worker fn(y) 1123*da0073e9SAndroid Build Coastguard Worker 1124*da0073e9SAndroid Build Coastguard Worker def test_python_function_tup(self): 1125*da0073e9SAndroid Build Coastguard Worker class MyFn(Function): 1126*da0073e9SAndroid Build Coastguard Worker @staticmethod 1127*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 1128*da0073e9SAndroid Build Coastguard Worker return x + 1, x - 1 1129*da0073e9SAndroid Build Coastguard Worker 1130*da0073e9SAndroid Build Coastguard Worker @staticmethod 1131*da0073e9SAndroid Build 
Coastguard Worker def backward(ctx, grad_output): 1132*da0073e9SAndroid Build Coastguard Worker return grad_output, grad_output 1133*da0073e9SAndroid Build Coastguard Worker 1134*da0073e9SAndroid Build Coastguard Worker @_trace(torch.zeros(2)) 1135*da0073e9SAndroid Build Coastguard Worker def fn(x): 1136*da0073e9SAndroid Build Coastguard Worker a, b = MyFn.apply(x + 2) 1137*da0073e9SAndroid Build Coastguard Worker return a + b + 3 1138*da0073e9SAndroid Build Coastguard Worker 1139*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1.0, 2.0, 3.0]) 1140*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 2, requires_grad=True) 1141*da0073e9SAndroid Build Coastguard Worker fn(x) 1142*da0073e9SAndroid Build Coastguard Worker fn(y) 1143*da0073e9SAndroid Build Coastguard Worker 1144*da0073e9SAndroid Build Coastguard Worker def test_trace_detach(self): 1145*da0073e9SAndroid Build Coastguard Worker def foo(x, w): 1146*da0073e9SAndroid Build Coastguard Worker return torch.matmul(x, w).detach() 1147*da0073e9SAndroid Build Coastguard Worker 1148*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5))) 1149*da0073e9SAndroid Build Coastguard Worker 1150*da0073e9SAndroid Build Coastguard Worker FileCheck().check("matmul").check("detach").run(str(traced.graph)) 1151*da0073e9SAndroid Build Coastguard Worker x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) 1152*da0073e9SAndroid Build Coastguard Worker traced_result = traced(x, w) 1153*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(x, w), traced_result) 1154*da0073e9SAndroid Build Coastguard Worker self.assertFalse(traced_result.requires_grad) 1155*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(traced_result.grad_fn) 1156*da0073e9SAndroid Build Coastguard Worker 1157*da0073e9SAndroid Build Coastguard Worker def test_trace_detach_redispatch(self): 1158*da0073e9SAndroid Build Coastguard Worker def foo(x, w): 
1159*da0073e9SAndroid Build Coastguard Worker y = torch.matmul(x, w) 1160*da0073e9SAndroid Build Coastguard Worker assert y.requires_grad 1161*da0073e9SAndroid Build Coastguard Worker y = y.detach() 1162*da0073e9SAndroid Build Coastguard Worker # Make sure trace kernel redispatches to the right lower kernel. 1163*da0073e9SAndroid Build Coastguard Worker assert not y.requires_grad 1164*da0073e9SAndroid Build Coastguard Worker return y 1165*da0073e9SAndroid Build Coastguard Worker 1166*da0073e9SAndroid Build Coastguard Worker x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) 1167*da0073e9SAndroid Build Coastguard Worker # With `check_trace=True` it will run with `@torch.no_grad()` and break assert. 1168*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(foo, (x, w), check_trace=False) 1169*da0073e9SAndroid Build Coastguard Worker 1170*da0073e9SAndroid Build Coastguard Worker def test_trace_detach_inplace(self): 1171*da0073e9SAndroid Build Coastguard Worker def foo(x, w): 1172*da0073e9SAndroid Build Coastguard Worker y = torch.matmul(x, w) 1173*da0073e9SAndroid Build Coastguard Worker y.detach_() 1174*da0073e9SAndroid Build Coastguard Worker return y 1175*da0073e9SAndroid Build Coastguard Worker 1176*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5))) 1177*da0073e9SAndroid Build Coastguard Worker 1178*da0073e9SAndroid Build Coastguard Worker FileCheck().check("matmul").check("detach(").run(str(traced.graph)) 1179*da0073e9SAndroid Build Coastguard Worker x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) 1180*da0073e9SAndroid Build Coastguard Worker traced_result = traced(x, w) 1181*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(x, w), traced_result) 1182*da0073e9SAndroid Build Coastguard Worker self.assertFalse(traced_result.requires_grad) 1183*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(traced_result.grad_fn) 1184*da0073e9SAndroid Build Coastguard 
Worker 1185*da0073e9SAndroid Build Coastguard Worker def test_trace_detach_inplace_redispatch(self): 1186*da0073e9SAndroid Build Coastguard Worker def foo(x, w): 1187*da0073e9SAndroid Build Coastguard Worker y = torch.matmul(x, w) 1188*da0073e9SAndroid Build Coastguard Worker assert y.requires_grad 1189*da0073e9SAndroid Build Coastguard Worker y.detach_() 1190*da0073e9SAndroid Build Coastguard Worker # Make sure trace kernel redispatches to the right lower kernel. 1191*da0073e9SAndroid Build Coastguard Worker assert not y.requires_grad 1192*da0073e9SAndroid Build Coastguard Worker return y 1193*da0073e9SAndroid Build Coastguard Worker 1194*da0073e9SAndroid Build Coastguard Worker x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) 1195*da0073e9SAndroid Build Coastguard Worker # With `check_trace=True` it will run with `@torch.no_grad()` and break assert. 1196*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(foo, (x, w), check_trace=False) 1197*da0073e9SAndroid Build Coastguard Worker 1198*da0073e9SAndroid Build Coastguard Worker def test_trace_slice_full_dim(self): 1199*da0073e9SAndroid Build Coastguard Worker def foo(x): 1200*da0073e9SAndroid Build Coastguard Worker return x[0:5, 0] + 1.0 1201*da0073e9SAndroid Build Coastguard Worker 1202*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, (torch.rand(5, 4),)) 1203*da0073e9SAndroid Build Coastguard Worker test_x = torch.rand(6, 3) 1204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(test_x), traced(test_x)) 1205*da0073e9SAndroid Build Coastguard Worker 1206*da0073e9SAndroid Build Coastguard Worker def test_trace_dict_input(self): 1207*da0073e9SAndroid Build Coastguard Worker class Bar(torch.nn.Module): 1208*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1209*da0073e9SAndroid Build Coastguard Worker super().__init__() 1210*da0073e9SAndroid Build Coastguard Worker self.foo = Foo() 1211*da0073e9SAndroid Build Coastguard Worker 
1212*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 1213*da0073e9SAndroid Build Coastguard Worker return self.foo({"a": a, "b": b})["a"] 1214*da0073e9SAndroid Build Coastguard Worker 1215*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 1216*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1217*da0073e9SAndroid Build Coastguard Worker return {"a": x["a"] * x["b"]} 1218*da0073e9SAndroid Build Coastguard Worker 1219*da0073e9SAndroid Build Coastguard Worker x = (torch.rand(3), torch.rand(3)) 1220*da0073e9SAndroid Build Coastguard Worker model = Bar() 1221*da0073e9SAndroid Build Coastguard Worker self.checkTrace(model, x) 1222*da0073e9SAndroid Build Coastguard Worker 1223*da0073e9SAndroid Build Coastguard Worker def test_trace_dict_output(self): 1224*da0073e9SAndroid Build Coastguard Worker class TraceDictStrTensor(torch.nn.Module): 1225*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 1226*da0073e9SAndroid Build Coastguard Worker return {"a": a, "b": b} 1227*da0073e9SAndroid Build Coastguard Worker 1228*da0073e9SAndroid Build Coastguard Worker class TraceDictTensorTensor(torch.nn.Module): 1229*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 1230*da0073e9SAndroid Build Coastguard Worker return {a: b, b: a} 1231*da0073e9SAndroid Build Coastguard Worker 1232*da0073e9SAndroid Build Coastguard Worker x = (torch.rand(3), torch.rand(3)) 1233*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Encountering a dict at the output"): 1234*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(TraceDictStrTensor(), x) 1235*da0073e9SAndroid Build Coastguard Worker 1236*da0073e9SAndroid Build Coastguard Worker traced_dict_str_mod = torch.jit.trace(TraceDictStrTensor(), x, strict=False) 1237*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_dict_str_mod(*x), {"a": x[0], "b": x[1]}) 1238*da0073e9SAndroid Build Coastguard Worker 1239*da0073e9SAndroid 
Build Coastguard Worker traced_dict_tensor_mod = torch.jit.trace( 1240*da0073e9SAndroid Build Coastguard Worker TraceDictTensorTensor(), x, strict=False 1241*da0073e9SAndroid Build Coastguard Worker ) 1242*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_dict_tensor_mod(*x), {x[0]: x[1], x[1]: x[0]}) 1243*da0073e9SAndroid Build Coastguard Worker 1244*da0073e9SAndroid Build Coastguard Worker def test_trace_with_tensor_list_output(self): 1245*da0073e9SAndroid Build Coastguard Worker def f(): 1246*da0073e9SAndroid Build Coastguard Worker return [torch.zeros(1), torch.zeros(5)] 1247*da0073e9SAndroid Build Coastguard Worker 1248*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 1249*da0073e9SAndroid Build Coastguard Worker torch.jit.TracerWarning, "cause the trace to be incorrect" 1250*da0073e9SAndroid Build Coastguard Worker ): 1251*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(f, []) 1252*da0073e9SAndroid Build Coastguard Worker traced_non_strict_f = torch.jit.trace(f, [], strict=False) 1253*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_non_strict_f(), f()) 1254*da0073e9SAndroid Build Coastguard Worker 1255*da0073e9SAndroid Build Coastguard Worker def test_trace_with_number_list_output(self): 1256*da0073e9SAndroid Build Coastguard Worker def f(): 1257*da0073e9SAndroid Build Coastguard Worker return [1, 5] 1258*da0073e9SAndroid Build Coastguard Worker 1259*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1260*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Only tensors.+can be output from traced functions" 1261*da0073e9SAndroid Build Coastguard Worker ): 1262*da0073e9SAndroid Build Coastguard Worker traced_f = torch.jit.trace(f, []) 1263*da0073e9SAndroid Build Coastguard Worker 1264*da0073e9SAndroid Build Coastguard Worker def test_trace_with_nested_tensor_list_output(self): 1265*da0073e9SAndroid Build Coastguard Worker def f(): 1266*da0073e9SAndroid Build Coastguard Worker 
return [[torch.zeros(1)], [torch.zeros(5)]] 1267*da0073e9SAndroid Build Coastguard Worker 1268*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1269*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Only tensors.+can be output from traced functions" 1270*da0073e9SAndroid Build Coastguard Worker ): 1271*da0073e9SAndroid Build Coastguard Worker traced_f = torch.jit.trace(f, []) 1272*da0073e9SAndroid Build Coastguard Worker 1273*da0073e9SAndroid Build Coastguard Worker def test_trace_with_nested_strided_tensor_output(self): 1274*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1275*da0073e9SAndroid Build Coastguard Worker def nt_construct(values, kv_lengths): 1276*da0073e9SAndroid Build Coastguard Worker kv_lengths_list: List[int] = kv_lengths.tolist() 1277*da0073e9SAndroid Build Coastguard Worker return torch._nested_tensor_from_tensor_list( 1278*da0073e9SAndroid Build Coastguard Worker list(values.split(kv_lengths_list, dim=0)), None, None, None, None 1279*da0073e9SAndroid Build Coastguard Worker ) 1280*da0073e9SAndroid Build Coastguard Worker 1281*da0073e9SAndroid Build Coastguard Worker def f(x, offsets): 1282*da0073e9SAndroid Build Coastguard Worker kv_lengths = offsets[1:] - offsets[:-1] 1283*da0073e9SAndroid Build Coastguard Worker return nt_construct(x, kv_lengths).cos() 1284*da0073e9SAndroid Build Coastguard Worker 1285*da0073e9SAndroid Build Coastguard Worker x = torch.rand(5, 4) 1286*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 5]) 1287*da0073e9SAndroid Build Coastguard Worker ref = f(x, offsets) 1288*da0073e9SAndroid Build Coastguard Worker f_t = torch.jit.trace(f, (x, offsets)) 1289*da0073e9SAndroid Build Coastguard Worker res = f_t(x, offsets) 1290*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 1291*da0073e9SAndroid Build Coastguard Worker x2 = torch.rand((8, 4)) 1292*da0073e9SAndroid Build Coastguard Worker offsets2 = torch.tensor([0, 2, 4, 8]) 1293*da0073e9SAndroid 
Build Coastguard Worker self.assertEqual(f(x2, offsets2), f_t(x2, offsets2)) 1294*da0073e9SAndroid Build Coastguard Worker 1295*da0073e9SAndroid Build Coastguard Worker def test_trace_variable_instantiation(self): 1296*da0073e9SAndroid Build Coastguard Worker def random_foo(x): 1297*da0073e9SAndroid Build Coastguard Worker return Variable(Variable(x) + 1.0) 1298*da0073e9SAndroid Build Coastguard Worker 1299*da0073e9SAndroid Build Coastguard Worker random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),)) 1300*da0073e9SAndroid Build Coastguard Worker 1301*da0073e9SAndroid Build Coastguard Worker x = torch.rand(5, 6) 1302*da0073e9SAndroid Build Coastguard Worker self.assertEqual(random_foo(x), random_foo_traced(x)) 1303*da0073e9SAndroid Build Coastguard Worker 1304*da0073e9SAndroid Build Coastguard Worker def test_trace_slice_expr_complete_type(self): 1305*da0073e9SAndroid Build Coastguard Worker def random_foo(x): 1306*da0073e9SAndroid Build Coastguard Worker return x + 1.0 1307*da0073e9SAndroid Build Coastguard Worker 1308*da0073e9SAndroid Build Coastguard Worker random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),)) 1309*da0073e9SAndroid Build Coastguard Worker 1310*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1311*da0073e9SAndroid Build Coastguard Worker def random_bar(x): 1312*da0073e9SAndroid Build Coastguard Worker return random_foo_traced(x)[0:1] 1313*da0073e9SAndroid Build Coastguard Worker 1314*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 1315*da0073e9SAndroid Build Coastguard Worker self.assertEqual(random_bar(x), (x + 1)[0:1]) 1316*da0073e9SAndroid Build Coastguard Worker 1317*da0073e9SAndroid Build Coastguard Worker def test_trace_inline_shape(self): 1318*da0073e9SAndroid Build Coastguard Worker # testing peephole optimization of size is turned into a constant 1319*da0073e9SAndroid Build Coastguard Worker # in script fn 1320*da0073e9SAndroid Build Coastguard Worker 1321*da0073e9SAndroid Build 
Coastguard Worker @torch.jit.script 1322*da0073e9SAndroid Build Coastguard Worker def tensor_size(x: torch.Tensor) -> torch.Tensor: 1323*da0073e9SAndroid Build Coastguard Worker return torch.tensor([x.size()[0]]) 1324*da0073e9SAndroid Build Coastguard Worker 1325*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1326*da0073e9SAndroid Build Coastguard Worker tensor_size( 1327*da0073e9SAndroid Build Coastguard Worker torch.rand( 1328*da0073e9SAndroid Build Coastguard Worker 15, 1329*da0073e9SAndroid Build Coastguard Worker ) 1330*da0073e9SAndroid Build Coastguard Worker ), 1331*da0073e9SAndroid Build Coastguard Worker torch.tensor([15]), 1332*da0073e9SAndroid Build Coastguard Worker ) 1333*da0073e9SAndroid Build Coastguard Worker 1334*da0073e9SAndroid Build Coastguard Worker traced_tensor_size = torch.jit.trace( 1335*da0073e9SAndroid Build Coastguard Worker tensor_size, 1336*da0073e9SAndroid Build Coastguard Worker torch.rand( 1337*da0073e9SAndroid Build Coastguard Worker 7, 1338*da0073e9SAndroid Build Coastguard Worker ), 1339*da0073e9SAndroid Build Coastguard Worker ) 1340*da0073e9SAndroid Build Coastguard Worker 1341*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1342*da0073e9SAndroid Build Coastguard Worker traced_tensor_size( 1343*da0073e9SAndroid Build Coastguard Worker torch.rand( 1344*da0073e9SAndroid Build Coastguard Worker 15, 1345*da0073e9SAndroid Build Coastguard Worker ) 1346*da0073e9SAndroid Build Coastguard Worker ), 1347*da0073e9SAndroid Build Coastguard Worker torch.tensor([15]), 1348*da0073e9SAndroid Build Coastguard Worker ) 1349*da0073e9SAndroid Build Coastguard Worker 1350*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1351*da0073e9SAndroid Build Coastguard Worker def use_device(x): 1352*da0073e9SAndroid Build Coastguard Worker return torch.zeros_like(x, device=x.device) 1353*da0073e9SAndroid Build Coastguard Worker 1354*da0073e9SAndroid Build Coastguard Worker def foo(x): 1355*da0073e9SAndroid Build Coastguard 
Worker return use_device(x) 1356*da0073e9SAndroid Build Coastguard Worker 1357*da0073e9SAndroid Build Coastguard Worker traced_tensor_size = torch.jit.trace( 1358*da0073e9SAndroid Build Coastguard Worker foo, 1359*da0073e9SAndroid Build Coastguard Worker torch.rand( 1360*da0073e9SAndroid Build Coastguard Worker 7, 1361*da0073e9SAndroid Build Coastguard Worker ), 1362*da0073e9SAndroid Build Coastguard Worker ) 1363*da0073e9SAndroid Build Coastguard Worker self.run_pass("inline", traced_tensor_size.graph) 1364*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::device").run(traced_tensor_size.graph) 1365*da0073e9SAndroid Build Coastguard Worker 1366*da0073e9SAndroid Build Coastguard Worker def test_trace_save(self): 1367*da0073e9SAndroid Build Coastguard Worker def fn(x): 1368*da0073e9SAndroid Build Coastguard Worker return x + 2 1369*da0073e9SAndroid Build Coastguard Worker 1370*da0073e9SAndroid Build Coastguard Worker def check(func): 1371*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 1372*da0073e9SAndroid Build Coastguard Worker func.save(fname) 1373*da0073e9SAndroid Build Coastguard Worker loaded = torch.jit.load(fname) 1374*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2) 1375*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func(input), loaded(input)) 1376*da0073e9SAndroid Build Coastguard Worker 1377*da0073e9SAndroid Build Coastguard Worker out = torch.jit.trace(fn, (torch.ones(2, 2),)) 1378*da0073e9SAndroid Build Coastguard Worker check(out) 1379*da0073e9SAndroid Build Coastguard Worker 1380*da0073e9SAndroid Build Coastguard Worker def test_trace_optioanl_dtype(self): 1381*da0073e9SAndroid Build Coastguard Worker class Test(torch.nn.Module): 1382*da0073e9SAndroid Build Coastguard Worker def forward(self): 1383*da0073e9SAndroid Build Coastguard Worker return torch.arange(5) 1384*da0073e9SAndroid Build Coastguard Worker 1385*da0073e9SAndroid Build Coastguard Worker traced = 
torch.jit.trace(Test(), ()) 1386*da0073e9SAndroid Build Coastguard Worker torch.allclose(traced(), Test()()) 1387*da0073e9SAndroid Build Coastguard Worker 1388*da0073e9SAndroid Build Coastguard Worker def test_trace_save_load_copy(self): 1389*da0073e9SAndroid Build Coastguard Worker class Test(torch.nn.Module): 1390*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1391*da0073e9SAndroid Build Coastguard Worker super().__init__() 1392*da0073e9SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d(3, 3, 3) 1393*da0073e9SAndroid Build Coastguard Worker 1394*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1395*da0073e9SAndroid Build Coastguard Worker return self.conv(x) 1396*da0073e9SAndroid Build Coastguard Worker 1397*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(Test(), torch.rand(1, 3, 224, 224)) 1398*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 1399*da0073e9SAndroid Build Coastguard Worker torch.jit.save(traced, buffer) 1400*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 1401*da0073e9SAndroid Build Coastguard Worker loaded = torch.jit.load(buffer) 1402*da0073e9SAndroid Build Coastguard Worker # should work 1403*da0073e9SAndroid Build Coastguard Worker copy.copy(loaded) 1404*da0073e9SAndroid Build Coastguard Worker copy.deepcopy(loaded) 1405*da0073e9SAndroid Build Coastguard Worker 1406*da0073e9SAndroid Build Coastguard Worker def test_trace_export_fns(self): 1407*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 1408*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1409*da0073e9SAndroid Build Coastguard Worker super().__init__() 1410*da0073e9SAndroid Build Coastguard Worker self.a = 3 1411*da0073e9SAndroid Build Coastguard Worker 1412*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1413*da0073e9SAndroid Build Coastguard Worker def __getstate__(self): 1414*da0073e9SAndroid Build Coastguard Worker return (3, self.training) 
1415*da0073e9SAndroid Build Coastguard Worker 1416*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1417*da0073e9SAndroid Build Coastguard Worker def __setstate__(self, state): 1418*da0073e9SAndroid Build Coastguard Worker self.a = state[0] 1419*da0073e9SAndroid Build Coastguard Worker self.training = state[1] 1420*da0073e9SAndroid Build Coastguard Worker 1421*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1422*da0073e9SAndroid Build Coastguard Worker return x + self.a 1423*da0073e9SAndroid Build Coastguard Worker 1424*da0073e9SAndroid Build Coastguard Worker f = Foo() 1425*da0073e9SAndroid Build Coastguard Worker 1426*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(f, (torch.rand(3, 4),)) 1427*da0073e9SAndroid Build Coastguard Worker expected_names = ["__getstate__", "__setstate__"] 1428*da0073e9SAndroid Build Coastguard Worker 1429*da0073e9SAndroid Build Coastguard Worker def check(mod): 1430*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1431*da0073e9SAndroid Build Coastguard Worker all(name in mod._c._method_names() for name in expected_names) 1432*da0073e9SAndroid Build Coastguard Worker ) 1433*da0073e9SAndroid Build Coastguard Worker 1434*da0073e9SAndroid Build Coastguard Worker check(traced) 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker imported = self.getExportImportCopy(traced) 1437*da0073e9SAndroid Build Coastguard Worker check(imported) 1438*da0073e9SAndroid Build Coastguard Worker 1439*da0073e9SAndroid Build Coastguard Worker def test_trace_export_fns_recursive(self): 1440*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 1441*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1442*da0073e9SAndroid Build Coastguard Worker super().__init__() 1443*da0073e9SAndroid Build Coastguard Worker self.a = 3 1444*da0073e9SAndroid Build Coastguard Worker 1445*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 
1446*da0073e9SAndroid Build Coastguard Worker def __getstate__(self): 1447*da0073e9SAndroid Build Coastguard Worker return (3, self.training) 1448*da0073e9SAndroid Build Coastguard Worker 1449*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1450*da0073e9SAndroid Build Coastguard Worker def __setstate__(self, state): 1451*da0073e9SAndroid Build Coastguard Worker self.a = state[0] 1452*da0073e9SAndroid Build Coastguard Worker self.training = state[1] 1453*da0073e9SAndroid Build Coastguard Worker 1454*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1455*da0073e9SAndroid Build Coastguard Worker return x + self.a 1456*da0073e9SAndroid Build Coastguard Worker 1457*da0073e9SAndroid Build Coastguard Worker class Wrapper(torch.nn.Module): 1458*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1459*da0073e9SAndroid Build Coastguard Worker super().__init__() 1460*da0073e9SAndroid Build Coastguard Worker self.foo = Foo() 1461*da0073e9SAndroid Build Coastguard Worker 1462*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1463*da0073e9SAndroid Build Coastguard Worker return self.foo(x) 1464*da0073e9SAndroid Build Coastguard Worker 1465*da0073e9SAndroid Build Coastguard Worker f = Wrapper() 1466*da0073e9SAndroid Build Coastguard Worker 1467*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(f, (torch.rand(3, 4),)) 1468*da0073e9SAndroid Build Coastguard Worker expected_names = ["__getstate__", "__setstate__"] 1469*da0073e9SAndroid Build Coastguard Worker 1470*da0073e9SAndroid Build Coastguard Worker def check(mod): 1471*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1472*da0073e9SAndroid Build Coastguard Worker all(name in mod._c._method_names() for name in expected_names) 1473*da0073e9SAndroid Build Coastguard Worker ) 1474*da0073e9SAndroid Build Coastguard Worker 1475*da0073e9SAndroid Build Coastguard Worker check(traced.foo) 1476*da0073e9SAndroid Build Coastguard Worker 1477*da0073e9SAndroid 
Build Coastguard Worker imported = self.getExportImportCopy(traced) 1478*da0073e9SAndroid Build Coastguard Worker check(imported.foo) 1479*da0073e9SAndroid Build Coastguard Worker 1480*da0073e9SAndroid Build Coastguard Worker # Note that Bar's forward can only be traced, but not scripted 1481*da0073e9SAndroid Build Coastguard Worker class Bar(nn.Module): 1482*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1483*da0073e9SAndroid Build Coastguard Worker def addTwo(self, x): 1484*da0073e9SAndroid Build Coastguard Worker return x + 2 1485*da0073e9SAndroid Build Coastguard Worker 1486*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 1487*da0073e9SAndroid Build Coastguard Worker return (lambda a: a + 1)(input) # noqa: PLC3002 1488*da0073e9SAndroid Build Coastguard Worker 1489*da0073e9SAndroid Build Coastguard Worker # When tracing Bar as a submodule, we only want to script the 1490*da0073e9SAndroid Build Coastguard Worker # exported methods, and we want to keep the forwards still 1491*da0073e9SAndroid Build Coastguard Worker # being traced. 
1492*da0073e9SAndroid Build Coastguard Worker class WrapperExports(torch.nn.Module): 1493*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1494*da0073e9SAndroid Build Coastguard Worker super().__init__() 1495*da0073e9SAndroid Build Coastguard Worker self.bar = Bar() 1496*da0073e9SAndroid Build Coastguard Worker 1497*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1498*da0073e9SAndroid Build Coastguard Worker def addOne(self, x): 1499*da0073e9SAndroid Build Coastguard Worker return x + 1 1500*da0073e9SAndroid Build Coastguard Worker 1501*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1502*da0073e9SAndroid Build Coastguard Worker return self.bar(x) 1503*da0073e9SAndroid Build Coastguard Worker 1504*da0073e9SAndroid Build Coastguard Worker f = WrapperExports() 1505*da0073e9SAndroid Build Coastguard Worker 1506*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(f, (torch.rand(3, 4),)) 1507*da0073e9SAndroid Build Coastguard Worker expected_names = ["addOne"] 1508*da0073e9SAndroid Build Coastguard Worker check(traced) 1509*da0073e9SAndroid Build Coastguard Worker 1510*da0073e9SAndroid Build Coastguard Worker def test_trace_autograd_function(self): 1511*da0073e9SAndroid Build Coastguard Worker class TestFunc(torch.autograd.Function): 1512*da0073e9SAndroid Build Coastguard Worker @staticmethod 1513*da0073e9SAndroid Build Coastguard Worker def forward(ctx, input): 1514*da0073e9SAndroid Build Coastguard Worker return torch.neg(input) 1515*da0073e9SAndroid Build Coastguard Worker 1516*da0073e9SAndroid Build Coastguard Worker @staticmethod 1517*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad_output): 1518*da0073e9SAndroid Build Coastguard Worker return torch.neg(grad_output) 1519*da0073e9SAndroid Build Coastguard Worker 1520*da0073e9SAndroid Build Coastguard Worker class TracedModule(torch.nn.Module): 1521*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1522*da0073e9SAndroid Build 
Coastguard Worker return torch.relu(TestFunc.apply(x)) 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker class Wrapper(torch.nn.Module): 1525*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1526*da0073e9SAndroid Build Coastguard Worker super().__init__() 1527*da0073e9SAndroid Build Coastguard Worker self.tm = TracedModule() 1528*da0073e9SAndroid Build Coastguard Worker 1529*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1530*da0073e9SAndroid Build Coastguard Worker return self.tm(x) 1531*da0073e9SAndroid Build Coastguard Worker 1532*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(Wrapper(), (torch.rand(3, 4),)) 1533*da0073e9SAndroid Build Coastguard Worker 1534*da0073e9SAndroid Build Coastguard Worker def test_trace_multi_output_function(self): 1535*da0073e9SAndroid Build Coastguard Worker # An autograd.Function with two outputs. 1536*da0073e9SAndroid Build Coastguard Worker # It swaps inputs so we can check if shape 1537*da0073e9SAndroid Build Coastguard Worker # handling is correct in TorchScript. 
1538*da0073e9SAndroid Build Coastguard Worker class Foo(torch.autograd.Function): 1539*da0073e9SAndroid Build Coastguard Worker @staticmethod 1540*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x, y): 1541*da0073e9SAndroid Build Coastguard Worker return y, x 1542*da0073e9SAndroid Build Coastguard Worker 1543*da0073e9SAndroid Build Coastguard Worker @staticmethod 1544*da0073e9SAndroid Build Coastguard Worker def backward(ctx, du, dv): 1545*da0073e9SAndroid Build Coastguard Worker return dv, du 1546*da0073e9SAndroid Build Coastguard Worker 1547*da0073e9SAndroid Build Coastguard Worker class Bar(torch.nn.Module): 1548*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 1549*da0073e9SAndroid Build Coastguard Worker x = x.relu() 1550*da0073e9SAndroid Build Coastguard Worker y = y.relu() 1551*da0073e9SAndroid Build Coastguard Worker z = Foo.apply(x, y) 1552*da0073e9SAndroid Build Coastguard Worker return z 1553*da0073e9SAndroid Build Coastguard Worker 1554*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 2, dtype=torch.double) 1555*da0073e9SAndroid Build Coastguard Worker y = torch.rand(1, 2, dtype=torch.double) 1556*da0073e9SAndroid Build Coastguard Worker 1557*da0073e9SAndroid Build Coastguard Worker # Generate JIT IR. 1558*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(Bar(), (x, y)) 1559*da0073e9SAndroid Build Coastguard Worker print(traced.graph) 1560*da0073e9SAndroid Build Coastguard Worker 1561*da0073e9SAndroid Build Coastguard Worker # Expected output schema of the custom autograd.Function. 
1562*da0073e9SAndroid Build Coastguard Worker schema = ( 1563*da0073e9SAndroid Build Coastguard Worker "(Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu), " 1564*da0073e9SAndroid Build Coastguard Worker "Double(3, 2, strides=[2, 1], requires_grad=0, device=cpu)) " 1565*da0073e9SAndroid Build Coastguard Worker "= ^Foo" 1566*da0073e9SAndroid Build Coastguard Worker ) 1567*da0073e9SAndroid Build Coastguard Worker 1568*da0073e9SAndroid Build Coastguard Worker # See if expected schema exists. 1569*da0073e9SAndroid Build Coastguard Worker FileCheck().check(schema).run(traced.graph) 1570*da0073e9SAndroid Build Coastguard Worker 1571*da0073e9SAndroid Build Coastguard Worker # Also examine if the graph is runnable and produces 1572*da0073e9SAndroid Build Coastguard Worker # the right result. 1573*da0073e9SAndroid Build Coastguard Worker u, v = traced(x, y) 1574*da0073e9SAndroid Build Coastguard Worker self.assertEqual(u, y) 1575*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v, x) 1576*da0073e9SAndroid Build Coastguard Worker 1577*da0073e9SAndroid Build Coastguard Worker def test_interpolate_trace(self): 1578*da0073e9SAndroid Build Coastguard Worker class test(nn.Module): 1579*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1580*da0073e9SAndroid Build Coastguard Worker super().__init__() 1581*da0073e9SAndroid Build Coastguard Worker self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1) 1582*da0073e9SAndroid Build Coastguard Worker 1583*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1584*da0073e9SAndroid Build Coastguard Worker y = self.conv(x) 1585*da0073e9SAndroid Build Coastguard Worker w = nn.functional.interpolate( 1586*da0073e9SAndroid Build Coastguard Worker y, mode="bilinear", align_corners=False, scale_factor=3 1587*da0073e9SAndroid Build Coastguard Worker ) 1588*da0073e9SAndroid Build Coastguard Worker return w 1589*da0073e9SAndroid Build Coastguard Worker 1590*da0073e9SAndroid Build Coastguard Worker f = 
test() 1591*da0073e9SAndroid Build Coastguard Worker # no failure 1592*da0073e9SAndroid Build Coastguard Worker g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),)) 1593*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(1, 1, 14, 14) 1594*da0073e9SAndroid Build Coastguard Worker # constants not baked in 1595*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g(x), f(x)) 1596*da0073e9SAndroid Build Coastguard Worker 1597*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 1598*da0073e9SAndroid Build Coastguard Worker def test_trace_optional(self): 1599*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1600*da0073e9SAndroid Build Coastguard Worker def test(x: Optional[Tensor]): 1601*da0073e9SAndroid Build Coastguard Worker if x is None: 1602*da0073e9SAndroid Build Coastguard Worker return torch.zeros(1) 1603*da0073e9SAndroid Build Coastguard Worker else: 1604*da0073e9SAndroid Build Coastguard Worker return x 1605*da0073e9SAndroid Build Coastguard Worker 1606*da0073e9SAndroid Build Coastguard Worker def test_none(): 1607*da0073e9SAndroid Build Coastguard Worker return test(None) 1608*da0073e9SAndroid Build Coastguard Worker 1609*da0073e9SAndroid Build Coastguard Worker def test_tensor(): 1610*da0073e9SAndroid Build Coastguard Worker return test(torch.zeros(2)) 1611*da0073e9SAndroid Build Coastguard Worker 1612*da0073e9SAndroid Build Coastguard Worker f_none = torch.jit.trace(test_none, ()) 1613*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f_none(), torch.zeros(1)) 1614*da0073e9SAndroid Build Coastguard Worker 1615*da0073e9SAndroid Build Coastguard Worker f_tensor = torch.jit.trace(test_tensor, ()) 1616*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f_tensor(), torch.zeros(2)) 1617*da0073e9SAndroid Build Coastguard Worker 1618*da0073e9SAndroid Build Coastguard Worker graph = f_tensor.graph 1619*da0073e9SAndroid Build Coastguard Worker 
FileCheck().check('name="test"').check_next("prim::CallFunction").run(graph) 1620*da0073e9SAndroid Build Coastguard Worker 1621*da0073e9SAndroid Build Coastguard Worker def test_trace_nested_datatypes(self): 1622*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1623*da0073e9SAndroid Build Coastguard Worker def foo(x): 1624*da0073e9SAndroid Build Coastguard Worker return [[x + 1, x - 1], [x + 2, x - 2]] 1625*da0073e9SAndroid Build Coastguard Worker 1626*da0073e9SAndroid Build Coastguard Worker def bar(x): 1627*da0073e9SAndroid Build Coastguard Worker list_stuff = foo(x) 1628*da0073e9SAndroid Build Coastguard Worker return list_stuff[0][0], list_stuff[1][1] 1629*da0073e9SAndroid Build Coastguard Worker 1630*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(bar, torch.rand(3, 4)) 1631*da0073e9SAndroid Build Coastguard Worker x = torch.rand(5, 6) 1632*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bar(x), traced(x)) 1633*da0073e9SAndroid Build Coastguard Worker 1634*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 1635*da0073e9SAndroid Build Coastguard Worker def test_call_traced_fn_from_traced_module(self): 1636*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(3, 4)) 1637*da0073e9SAndroid Build Coastguard Worker def traced_fn(x): 1638*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 1639*da0073e9SAndroid Build Coastguard Worker 1640*da0073e9SAndroid Build Coastguard Worker class TracedModule(torch.nn.Module): 1641*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1642*da0073e9SAndroid Build Coastguard Worker super().__init__() 1643*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 5)) 1644*da0073e9SAndroid Build Coastguard Worker 1645*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1646*da0073e9SAndroid Build Coastguard Worker return traced_fn(torch.mm(x, self.param)) 1647*da0073e9SAndroid Build Coastguard 
Worker 1648*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 1649*da0073e9SAndroid Build Coastguard Worker 1650*da0073e9SAndroid Build Coastguard Worker # Note: neg op from the traced function should be properly inlined 1651*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mm").check('name="traced_fn"').check_next( 1652*da0073e9SAndroid Build Coastguard Worker "prim::CallFunction" 1653*da0073e9SAndroid Build Coastguard Worker ).run(str(tm.graph)) 1654*da0073e9SAndroid Build Coastguard Worker 1655*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 1656*da0073e9SAndroid Build Coastguard Worker def test_call_traced_module_from_traced_module(self): 1657*da0073e9SAndroid Build Coastguard Worker class TracedModule1(torch.nn.Module): 1658*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1659*da0073e9SAndroid Build Coastguard Worker super().__init__() 1660*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(5, 7)) 1661*da0073e9SAndroid Build Coastguard Worker 1662*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1663*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, self.param) 1664*da0073e9SAndroid Build Coastguard Worker 1665*da0073e9SAndroid Build Coastguard Worker class TracedModule(torch.nn.Module): 1666*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1667*da0073e9SAndroid Build Coastguard Worker super().__init__() 1668*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 5)) 1669*da0073e9SAndroid Build Coastguard Worker self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5)) 1670*da0073e9SAndroid Build Coastguard Worker 1671*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1672*da0073e9SAndroid Build Coastguard Worker return self.mod(torch.mm(x, self.param)) + 1.0 1673*da0073e9SAndroid Build Coastguard Worker 
1674*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 1675*da0073e9SAndroid Build Coastguard Worker 1676*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mm").check("prim::CallMethod").check_same( 1677*da0073e9SAndroid Build Coastguard Worker "forward" 1678*da0073e9SAndroid Build Coastguard Worker ).check("aten::add").run(str(tm.graph)) 1679*da0073e9SAndroid Build Coastguard Worker 1680*da0073e9SAndroid Build Coastguard Worker def test_index_put_trace_with_view(self): 1681*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4)) 1682*da0073e9SAndroid Build Coastguard Worker def test_index_put(target, indices, rhs): 1683*da0073e9SAndroid Build Coastguard Worker target[indices] = rhs 1684*da0073e9SAndroid Build Coastguard Worker return target 1685*da0073e9SAndroid Build Coastguard Worker 1686*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::view").check("index_put_").run( 1687*da0073e9SAndroid Build Coastguard Worker str(test_index_put.graph) 1688*da0073e9SAndroid Build Coastguard Worker ) 1689*da0073e9SAndroid Build Coastguard Worker 1690*da0073e9SAndroid Build Coastguard Worker def test_index_put_trace_without_view(self): 1691*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4)) 1692*da0073e9SAndroid Build Coastguard Worker def test_index_put(target, indices, rhs): 1693*da0073e9SAndroid Build Coastguard Worker target[indices] = rhs 1694*da0073e9SAndroid Build Coastguard Worker return target 1695*da0073e9SAndroid Build Coastguard Worker 1696*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::view").check("index_put_").run( 1697*da0073e9SAndroid Build Coastguard Worker str(test_index_put.graph) 1698*da0073e9SAndroid Build Coastguard Worker ) 1699*da0073e9SAndroid Build Coastguard Worker 1700*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 
1701*da0073e9SAndroid Build Coastguard Worker def test_trace_checker_dot_data(self): 1702*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1703*da0073e9SAndroid Build Coastguard Worker torch.jit.TracingCheckError, 1704*da0073e9SAndroid Build Coastguard Worker r"Tensor-valued Constant nodes differed in value " r"across invocations", 1705*da0073e9SAndroid Build Coastguard Worker ): 1706*da0073e9SAndroid Build Coastguard Worker 1707*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)]) 1708*da0073e9SAndroid Build Coastguard Worker def foo(x): 1709*da0073e9SAndroid Build Coastguard Worker y = x.data 1710*da0073e9SAndroid Build Coastguard Worker return x + y 1711*da0073e9SAndroid Build Coastguard Worker 1712*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 1713*da0073e9SAndroid Build Coastguard Worker def test_trace_checker_control_flow(self): 1714*da0073e9SAndroid Build Coastguard Worker def foo(x): 1715*da0073e9SAndroid Build Coastguard Worker for _ in range(x.size(0)): 1716*da0073e9SAndroid Build Coastguard Worker x = torch.neg(x) 1717*da0073e9SAndroid Build Coastguard Worker return x 1718*da0073e9SAndroid Build Coastguard Worker 1719*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1720*da0073e9SAndroid Build Coastguard Worker torch.jit.TracingCheckError, r"Graphs differed across invocations!" 1721*da0073e9SAndroid Build Coastguard Worker ): 1722*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)]) 1723*da0073e9SAndroid Build Coastguard Worker 1724*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 1725*da0073e9SAndroid Build Coastguard Worker def test_trace_checker_memoization(self): 1726*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1727*da0073e9SAndroid Build Coastguard Worker torch.jit.TracingCheckError, r"Graphs differed across invocations!" 
1728*da0073e9SAndroid Build Coastguard Worker ): 1729*da0073e9SAndroid Build Coastguard Worker 1730*da0073e9SAndroid Build Coastguard Worker def foo(x): 1731*da0073e9SAndroid Build Coastguard Worker if not hasattr(foo, "cache"): 1732*da0073e9SAndroid Build Coastguard Worker foo.cache = torch.neg(x) 1733*da0073e9SAndroid Build Coastguard Worker return x + foo.cache 1734*da0073e9SAndroid Build Coastguard Worker 1735*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 1736*da0073e9SAndroid Build Coastguard Worker foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)] 1737*da0073e9SAndroid Build Coastguard Worker ) 1738*da0073e9SAndroid Build Coastguard Worker 1739*da0073e9SAndroid Build Coastguard Worker def test_trace_checker_slice_lhs(self): 1740*da0073e9SAndroid Build Coastguard Worker def foo(x): 1741*da0073e9SAndroid Build Coastguard Worker for i in range(3): 1742*da0073e9SAndroid Build Coastguard Worker x[i, :] = torch.zeros(4) 1743*da0073e9SAndroid Build Coastguard Worker return x 1744*da0073e9SAndroid Build Coastguard Worker 1745*da0073e9SAndroid Build Coastguard Worker self.checkTrace(foo, (torch.rand(3, 4),), inputs_require_grads=False) 1746*da0073e9SAndroid Build Coastguard Worker 1747*da0073e9SAndroid Build Coastguard Worker def test_trace_checker_inplace_on_view(self): 1748*da0073e9SAndroid Build Coastguard Worker def foo(x): 1749*da0073e9SAndroid Build Coastguard Worker x.view(-1).add_(-x.view(-1)) 1750*da0073e9SAndroid Build Coastguard Worker return x 1751*da0073e9SAndroid Build Coastguard Worker 1752*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 1753*da0073e9SAndroid Build Coastguard Worker torch.jit.TracerWarning, 1754*da0073e9SAndroid Build Coastguard Worker "Output nr 1. 
of the traced function does not match the " 1755*da0073e9SAndroid Build Coastguard Worker "corresponding output of the Python function", 1756*da0073e9SAndroid Build Coastguard Worker ): 1757*da0073e9SAndroid Build Coastguard Worker torch.jit.trace( 1758*da0073e9SAndroid Build Coastguard Worker foo, 1759*da0073e9SAndroid Build Coastguard Worker torch.rand(3, 4), 1760*da0073e9SAndroid Build Coastguard Worker check_inputs=[torch.rand(5, 6)], 1761*da0073e9SAndroid Build Coastguard Worker _force_outplace=True, 1762*da0073e9SAndroid Build Coastguard Worker ) 1763*da0073e9SAndroid Build Coastguard Worker 1764*da0073e9SAndroid Build Coastguard Worker def test_lhs_index_fails(self): 1765*da0073e9SAndroid Build Coastguard Worker def foo(x): 1766*da0073e9SAndroid Build Coastguard Worker x[0, 1] = 4 1767*da0073e9SAndroid Build Coastguard Worker return x 1768*da0073e9SAndroid Build Coastguard Worker 1769*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 1770*da0073e9SAndroid Build Coastguard Worker torch.jit.TracerWarning, "cause the trace to be incorrect" 1771*da0073e9SAndroid Build Coastguard Worker ): 1772*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True) 1773*da0073e9SAndroid Build Coastguard Worker 1774*da0073e9SAndroid Build Coastguard Worker def test_lhs_index_trivial(self): 1775*da0073e9SAndroid Build Coastguard Worker def foo(y, x): 1776*da0073e9SAndroid Build Coastguard Worker y[...] 
= x 1777*da0073e9SAndroid Build Coastguard Worker return y 1778*da0073e9SAndroid Build Coastguard Worker 1779*da0073e9SAndroid Build Coastguard Worker self.checkTrace( 1780*da0073e9SAndroid Build Coastguard Worker foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False 1781*da0073e9SAndroid Build Coastguard Worker ) 1782*da0073e9SAndroid Build Coastguard Worker 1783*da0073e9SAndroid Build Coastguard Worker def test_inplace_warn(self): 1784*da0073e9SAndroid Build Coastguard Worker def foo(x): 1785*da0073e9SAndroid Build Coastguard Worker x.view(-1).add_(-x.view(-1)) 1786*da0073e9SAndroid Build Coastguard Worker return x 1787*da0073e9SAndroid Build Coastguard Worker 1788*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 1789*da0073e9SAndroid Build Coastguard Worker torch.jit.TracerWarning, "cause the trace to be incorrect" 1790*da0073e9SAndroid Build Coastguard Worker ): 1791*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True) 1792*da0073e9SAndroid Build Coastguard Worker 1793*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 1794*da0073e9SAndroid Build Coastguard Worker def test_trace_checker_dropout_train(self): 1795*da0073e9SAndroid Build Coastguard Worker def foo(x): 1796*da0073e9SAndroid Build Coastguard Worker return torch.dropout(x, p=0.5, train=True) 1797*da0073e9SAndroid Build Coastguard Worker 1798*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 1799*da0073e9SAndroid Build Coastguard Worker torch.jit.TracerWarning, 1800*da0073e9SAndroid Build Coastguard Worker "Output nr 1. 
of the traced function does not match the " 1801*da0073e9SAndroid Build Coastguard Worker "corresponding output of the Python function", 1802*da0073e9SAndroid Build Coastguard Worker ): 1803*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]) 1804*da0073e9SAndroid Build Coastguard Worker 1805*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 1806*da0073e9SAndroid Build Coastguard Worker torch.jit.TracerWarning, "Trace had nondeterministic nodes" 1807*da0073e9SAndroid Build Coastguard Worker ): 1808*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]) 1809*da0073e9SAndroid Build Coastguard Worker 1810*da0073e9SAndroid Build Coastguard Worker def test_trace_checker_dropout_notrain(self): 1811*da0073e9SAndroid Build Coastguard Worker input = torch.rand(3, 4) 1812*da0073e9SAndroid Build Coastguard Worker 1813*da0073e9SAndroid Build Coastguard Worker @_trace(input) 1814*da0073e9SAndroid Build Coastguard Worker def foo(x): 1815*da0073e9SAndroid Build Coastguard Worker return torch.dropout(x, p=0.5, train=False) 1816*da0073e9SAndroid Build Coastguard Worker 1817*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(input), input) 1818*da0073e9SAndroid Build Coastguard Worker 1819*da0073e9SAndroid Build Coastguard Worker def test_trace_contiguous(self): 1820*da0073e9SAndroid Build Coastguard Worker def foo(x): 1821*da0073e9SAndroid Build Coastguard Worker return x[:, :, ::2].contiguous().view(12) 1822*da0073e9SAndroid Build Coastguard Worker 1823*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, 3, 4) 1824*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, (x,)) 1825*da0073e9SAndroid Build Coastguard Worker y = traced(x) 1826*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr()) 1827*da0073e9SAndroid Build Coastguard Worker 1828*da0073e9SAndroid 
Build Coastguard Worker # This tests the logic in THPVariable_contiguous. There is short-circuiting 1829*da0073e9SAndroid Build Coastguard Worker # code that prevents us from even getting to VariableType::contiguous, since 1830*da0073e9SAndroid Build Coastguard Worker # it is an optimization that prevents us from acquiring the GIL for touching 1831*da0073e9SAndroid Build Coastguard Worker # the device. We needed to add the tracing logic directly into the 1832*da0073e9SAndroid Build Coastguard Worker # THPVariable_contiguous function only for the path where we are skipping 1833*da0073e9SAndroid Build Coastguard Worker # dispatch into contiguous. We should see an aten::contiguous in this trace! 1834*da0073e9SAndroid Build Coastguard Worker def test_trace_contiguous_short_circuit(self): 1835*da0073e9SAndroid Build Coastguard Worker def foo(x): 1836*da0073e9SAndroid Build Coastguard Worker return x.contiguous() 1837*da0073e9SAndroid Build Coastguard Worker 1838*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, 3, 4) 1839*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, (x,)) 1840*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::contiguous").run(str(traced.graph)) 1841*da0073e9SAndroid Build Coastguard Worker 1842*da0073e9SAndroid Build Coastguard Worker def test_trace_inverse(self): 1843*da0073e9SAndroid Build Coastguard Worker def foo(x): 1844*da0073e9SAndroid Build Coastguard Worker return ~x 1845*da0073e9SAndroid Build Coastguard Worker 1846*da0073e9SAndroid Build Coastguard Worker foo_traced = torch.jit.trace(foo, torch.zeros(3, 4, dtype=torch.uint8)) 1847*da0073e9SAndroid Build Coastguard Worker eg = torch.zeros(3, dtype=torch.uint8) 1848*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo_traced(eg), foo(eg)) 1849*da0073e9SAndroid Build Coastguard Worker 1850*da0073e9SAndroid Build Coastguard Worker def test_trace_modulelist(self): 1851*da0073e9SAndroid Build Coastguard Worker class 
MySubmod(torch.nn.Module): 1852*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1853*da0073e9SAndroid Build Coastguard Worker super().__init__() 1854*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 1855*da0073e9SAndroid Build Coastguard Worker 1856*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1857*da0073e9SAndroid Build Coastguard Worker return self.relu(x) 1858*da0073e9SAndroid Build Coastguard Worker 1859*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.nn.Module): 1860*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1861*da0073e9SAndroid Build Coastguard Worker super().__init__() 1862*da0073e9SAndroid Build Coastguard Worker self.ml = torch.nn.ModuleList([MySubmod(), MySubmod()]) 1863*da0073e9SAndroid Build Coastguard Worker 1864*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1865*da0073e9SAndroid Build Coastguard Worker for mod in self.ml: 1866*da0073e9SAndroid Build Coastguard Worker x = mod(x) 1867*da0073e9SAndroid Build Coastguard Worker return x 1868*da0073e9SAndroid Build Coastguard Worker 1869*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(MyMod(), (torch.rand(3, 4),)) 1870*da0073e9SAndroid Build Coastguard Worker 1871*da0073e9SAndroid Build Coastguard Worker def test_trace_fork_join_and_module(self): 1872*da0073e9SAndroid Build Coastguard Worker class MySubmod(torch.nn.Module): 1873*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1874*da0073e9SAndroid Build Coastguard Worker super().__init__() 1875*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 1876*da0073e9SAndroid Build Coastguard Worker 1877*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1878*da0073e9SAndroid Build Coastguard Worker return self.relu(x), torch.neg(x) 1879*da0073e9SAndroid Build Coastguard Worker 1880*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 1881*da0073e9SAndroid Build 
Coastguard Worker def __init__(self) -> None: 1882*da0073e9SAndroid Build Coastguard Worker super().__init__() 1883*da0073e9SAndroid Build Coastguard Worker self.ml = torch.nn.ModuleList([MySubmod() for i in range(2)]) 1884*da0073e9SAndroid Build Coastguard Worker 1885*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1886*da0073e9SAndroid Build Coastguard Worker futs = [] 1887*da0073e9SAndroid Build Coastguard Worker for i in range(2): 1888*da0073e9SAndroid Build Coastguard Worker futs.append(torch.jit._fork(self.ml[i], x)) 1889*da0073e9SAndroid Build Coastguard Worker 1890*da0073e9SAndroid Build Coastguard Worker results = [] 1891*da0073e9SAndroid Build Coastguard Worker for i in range(2): 1892*da0073e9SAndroid Build Coastguard Worker results.append(torch.jit._wait(futs[i])[0]) 1893*da0073e9SAndroid Build Coastguard Worker 1894*da0073e9SAndroid Build Coastguard Worker return torch.stack(results) 1895*da0073e9SAndroid Build Coastguard Worker 1896*da0073e9SAndroid Build Coastguard Worker m = Mod() 1897*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(m, torch.rand(3, 4)) 1898*da0073e9SAndroid Build Coastguard Worker 1899*da0073e9SAndroid Build Coastguard Worker def test_trace_invert_module_hierarchy(self): 1900*da0073e9SAndroid Build Coastguard Worker class MySubmod(torch.nn.Module): 1901*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1902*da0073e9SAndroid Build Coastguard Worker super().__init__() 1903*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 1904*da0073e9SAndroid Build Coastguard Worker 1905*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1906*da0073e9SAndroid Build Coastguard Worker return self.relu(x), torch.neg(x) 1907*da0073e9SAndroid Build Coastguard Worker 1908*da0073e9SAndroid Build Coastguard Worker class MyFunctionalMod(torch.nn.Module): 1909*da0073e9SAndroid Build Coastguard Worker def forward(self, x, submod): 1910*da0073e9SAndroid Build Coastguard Worker 
return submod(x) 1911*da0073e9SAndroid Build Coastguard Worker 1912*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 1913*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1914*da0073e9SAndroid Build Coastguard Worker super().__init__() 1915*da0073e9SAndroid Build Coastguard Worker self.sm = MySubmod() 1916*da0073e9SAndroid Build Coastguard Worker self.fm = MyFunctionalMod() 1917*da0073e9SAndroid Build Coastguard Worker 1918*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1919*da0073e9SAndroid Build Coastguard Worker return self.fm(x, self.sm) 1920*da0073e9SAndroid Build Coastguard Worker 1921*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(Mod(), (torch.rand(3, 4),)) 1922*da0073e9SAndroid Build Coastguard Worker 1923*da0073e9SAndroid Build Coastguard Worker @skipIfCrossRef 1924*da0073e9SAndroid Build Coastguard Worker def test_trace_records_names(self): 1925*da0073e9SAndroid Build Coastguard Worker def foo(bar, baz): 1926*da0073e9SAndroid Build Coastguard Worker baz = bar + 3 1927*da0073e9SAndroid Build Coastguard Worker quick_brown_fox = torch.neg(baz) 1928*da0073e9SAndroid Build Coastguard Worker for _ in range(20): 1929*da0073e9SAndroid Build Coastguard Worker yeet = quick_brown_fox - 3.14 1930*da0073e9SAndroid Build Coastguard Worker return yeet 1931*da0073e9SAndroid Build Coastguard Worker 1932*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3))) 1933*da0073e9SAndroid Build Coastguard Worker graph_str = str(traced.graph) 1934*da0073e9SAndroid Build Coastguard Worker assert "bar" in graph_str 1935*da0073e9SAndroid Build Coastguard Worker assert "baz" in graph_str 1936*da0073e9SAndroid Build Coastguard Worker assert "quick_brown_fox" in graph_str 1937*da0073e9SAndroid Build Coastguard Worker 1938*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 1939*da0073e9SAndroid Build Coastguard Worker def 
test_tracing_hooks(self): 1940*da0073e9SAndroid Build Coastguard Worker class Net(nn.Module): 1941*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1942*da0073e9SAndroid Build Coastguard Worker return x + x 1943*da0073e9SAndroid Build Coastguard Worker 1944*da0073e9SAndroid Build Coastguard Worker def test_hook(is_post_hook, hook, fc): 1945*da0073e9SAndroid Build Coastguard Worker n = Net() 1946*da0073e9SAndroid Build Coastguard Worker if is_post_hook: 1947*da0073e9SAndroid Build Coastguard Worker n.register_forward_hook(hook) 1948*da0073e9SAndroid Build Coastguard Worker else: 1949*da0073e9SAndroid Build Coastguard Worker n.register_forward_pre_hook(hook) 1950*da0073e9SAndroid Build Coastguard Worker 1951*da0073e9SAndroid Build Coastguard Worker module = torch.jit.trace(n, (torch.tensor(1.0),)) 1952*da0073e9SAndroid Build Coastguard Worker 1953*da0073e9SAndroid Build Coastguard Worker eager_input = torch.tensor(1.0) 1954*da0073e9SAndroid Build Coastguard Worker eager_out = n(eager_input) 1955*da0073e9SAndroid Build Coastguard Worker 1956*da0073e9SAndroid Build Coastguard Worker fc.run(module.forward.graph) 1957*da0073e9SAndroid Build Coastguard Worker input = torch.tensor(1.0) 1958*da0073e9SAndroid Build Coastguard Worker output = module(input) 1959*da0073e9SAndroid Build Coastguard Worker 1960*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, eager_input) 1961*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, eager_out) 1962*da0073e9SAndroid Build Coastguard Worker 1963*da0073e9SAndroid Build Coastguard Worker def hook_no_return(mod, input, output): 1964*da0073e9SAndroid Build Coastguard Worker input[0].add_(1) 1965*da0073e9SAndroid Build Coastguard Worker output.sub_(1) 1966*da0073e9SAndroid Build Coastguard Worker 1967*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check("add(").check("add_(").check("sub_(") 1968*da0073e9SAndroid Build Coastguard Worker test_hook(True, hook_no_return, fc) 
1969*da0073e9SAndroid Build Coastguard Worker 1970*da0073e9SAndroid Build Coastguard Worker def hook_return(mod, input, output): 1971*da0073e9SAndroid Build Coastguard Worker input[0].add_(1) 1972*da0073e9SAndroid Build Coastguard Worker return output - 3 1973*da0073e9SAndroid Build Coastguard Worker 1974*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check("add(").check("add_(").check("sub(") 1975*da0073e9SAndroid Build Coastguard Worker test_hook(True, hook_return, fc) 1976*da0073e9SAndroid Build Coastguard Worker 1977*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(3.0) 1978*da0073e9SAndroid Build Coastguard Worker 1979*da0073e9SAndroid Build Coastguard Worker def captured_hook(mod, input, output): 1980*da0073e9SAndroid Build Coastguard Worker return output - b 1981*da0073e9SAndroid Build Coastguard Worker 1982*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check("add(").check("sub(") 1983*da0073e9SAndroid Build Coastguard Worker test_hook(True, captured_hook, fc) 1984*da0073e9SAndroid Build Coastguard Worker 1985*da0073e9SAndroid Build Coastguard Worker def pre_hook_no_ret(mod, input): 1986*da0073e9SAndroid Build Coastguard Worker input[0].add_(3) 1987*da0073e9SAndroid Build Coastguard Worker 1988*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check("add_(").check("add(") 1989*da0073e9SAndroid Build Coastguard Worker test_hook(False, pre_hook_no_ret, fc) 1990*da0073e9SAndroid Build Coastguard Worker 1991*da0073e9SAndroid Build Coastguard Worker def pre_hook_ret(mod, input): 1992*da0073e9SAndroid Build Coastguard Worker return input[0] - 4 1993*da0073e9SAndroid Build Coastguard Worker 1994*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check("sub(").check("add(") 1995*da0073e9SAndroid Build Coastguard Worker test_hook(False, pre_hook_ret, fc) 1996*da0073e9SAndroid Build Coastguard Worker 1997*da0073e9SAndroid Build Coastguard Worker def test_tracing_backward_hook_error(self): 1998*da0073e9SAndroid Build 
Coastguard Worker class Net(nn.Module): 1999*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2000*da0073e9SAndroid Build Coastguard Worker return x + x 2001*da0073e9SAndroid Build Coastguard Worker 2002*da0073e9SAndroid Build Coastguard Worker n = Net() 2003*da0073e9SAndroid Build Coastguard Worker 2004*da0073e9SAndroid Build Coastguard Worker def backward_hook(module, grad_input, grad_output): 2005*da0073e9SAndroid Build Coastguard Worker pass 2006*da0073e9SAndroid Build Coastguard Worker 2007*da0073e9SAndroid Build Coastguard Worker n.register_backward_hook(backward_hook) 2008*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "backward hooks assigned"): 2009*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(n, (torch.tensor(1.0),)) 2010*da0073e9SAndroid Build Coastguard Worker 2011*da0073e9SAndroid Build Coastguard Worker def test_tracing_multiple_methods(self): 2012*da0073e9SAndroid Build Coastguard Worker class Net(nn.Module): 2013*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2014*da0073e9SAndroid Build Coastguard Worker super().__init__() 2015*da0073e9SAndroid Build Coastguard Worker self.conv = nn.Conv2d(1, 1, 3) 2016*da0073e9SAndroid Build Coastguard Worker 2017*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2018*da0073e9SAndroid Build Coastguard Worker return self.conv(x) 2019*da0073e9SAndroid Build Coastguard Worker 2020*da0073e9SAndroid Build Coastguard Worker def weighted_kernel_sum(self, weight): 2021*da0073e9SAndroid Build Coastguard Worker return weight * self.conv.weight 2022*da0073e9SAndroid Build Coastguard Worker 2023*da0073e9SAndroid Build Coastguard Worker example_weight = torch.rand(1, 1, 3, 3) 2024*da0073e9SAndroid Build Coastguard Worker example_forward_input = torch.rand(1, 1, 3, 3) 2025*da0073e9SAndroid Build Coastguard Worker inputs = { 2026*da0073e9SAndroid Build Coastguard Worker "forward": example_forward_input, 2027*da0073e9SAndroid Build 
Coastguard Worker "weighted_kernel_sum": example_weight, 2028*da0073e9SAndroid Build Coastguard Worker } 2029*da0073e9SAndroid Build Coastguard Worker n = Net() 2030*da0073e9SAndroid Build Coastguard Worker module = torch.jit.trace_module(n, inputs) 2031*da0073e9SAndroid Build Coastguard Worker 2032*da0073e9SAndroid Build Coastguard Worker check_inputs = [] 2033*da0073e9SAndroid Build Coastguard Worker for i in range(2): 2034*da0073e9SAndroid Build Coastguard Worker check_weight = torch.rand(1, 1, 3, 3) 2035*da0073e9SAndroid Build Coastguard Worker check_forward_input = torch.rand(1, 1, 3, 3) 2036*da0073e9SAndroid Build Coastguard Worker check_inputs.append( 2037*da0073e9SAndroid Build Coastguard Worker {"forward": check_forward_input, "weighted_kernel_sum": check_weight} 2038*da0073e9SAndroid Build Coastguard Worker ) 2039*da0073e9SAndroid Build Coastguard Worker module = torch.jit.trace_module( 2040*da0073e9SAndroid Build Coastguard Worker n, inputs, check_trace=True, check_inputs=check_inputs 2041*da0073e9SAndroid Build Coastguard Worker ) 2042*da0073e9SAndroid Build Coastguard Worker self.assertTrue(module._c._has_method("forward")) 2043*da0073e9SAndroid Build Coastguard Worker self.assertTrue(module._c._has_method("weighted_kernel_sum")) 2044*da0073e9SAndroid Build Coastguard Worker 2045*da0073e9SAndroid Build Coastguard Worker module = torch.jit.trace(n.forward, example_forward_input) 2046*da0073e9SAndroid Build Coastguard Worker module = torch.jit.trace( 2047*da0073e9SAndroid Build Coastguard Worker n.forward, 2048*da0073e9SAndroid Build Coastguard Worker example_forward_input, 2049*da0073e9SAndroid Build Coastguard Worker check_trace=True, 2050*da0073e9SAndroid Build Coastguard Worker check_inputs=[example_forward_input], 2051*da0073e9SAndroid Build Coastguard Worker ) 2052*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2053*da0073e9SAndroid Build Coastguard Worker AttributeError, 2054*da0073e9SAndroid Build Coastguard Worker "trace 
doesn't support compiling individual module's functions", 2055*da0073e9SAndroid Build Coastguard Worker ): 2056*da0073e9SAndroid Build Coastguard Worker module = torch.jit.trace(n.weighted_kernel_sum, inputs) 2057*da0073e9SAndroid Build Coastguard Worker 2058*da0073e9SAndroid Build Coastguard Worker def test_tensor_with_grad_as_constant(self): 2059*da0073e9SAndroid Build Coastguard Worker param = torch.randn(3).requires_grad_() 2060*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 2061*da0073e9SAndroid Build Coastguard Worker 2062*da0073e9SAndroid Build Coastguard Worker def f(x): 2063*da0073e9SAndroid Build Coastguard Worker return x + param 2064*da0073e9SAndroid Build Coastguard Worker 2065*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2066*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Cannot insert a Tensor that requires grad as a constant" 2067*da0073e9SAndroid Build Coastguard Worker ): 2068*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(f, x) 2069*da0073e9SAndroid Build Coastguard Worker 2070*da0073e9SAndroid Build Coastguard Worker def test_non_tensor_tracing(self): 2071*da0073e9SAndroid Build Coastguard Worker def f(x): 2072*da0073e9SAndroid Build Coastguard Worker return x + param # noqa: F821 2073*da0073e9SAndroid Build Coastguard Worker 2074*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2075*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Type 'Tuple\[int\]' cannot be traced" 2076*da0073e9SAndroid Build Coastguard Worker ): 2077*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(f, (1,)) 2078*da0073e9SAndroid Build Coastguard Worker 2079*da0073e9SAndroid Build Coastguard Worker def test_trace_skip_none_submodule(self): 2080*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 2081*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2082*da0073e9SAndroid Build Coastguard Worker super().__init__() 2083*da0073e9SAndroid Build 
Coastguard Worker self.submod = torch.nn.Linear(3, 4) 2084*da0073e9SAndroid Build Coastguard Worker self.submod = None 2085*da0073e9SAndroid Build Coastguard Worker 2086*da0073e9SAndroid Build Coastguard Worker def forward(self, inputs): 2087*da0073e9SAndroid Build Coastguard Worker return inputs 2088*da0073e9SAndroid Build Coastguard Worker 2089*da0073e9SAndroid Build Coastguard Worker m = TestModule() 2090*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(m, torch.tensor(1.0)) 2091*da0073e9SAndroid Build Coastguard Worker self.assertFalse(hasattr(tm, "submod")) 2092*da0073e9SAndroid Build Coastguard Worker 2093*da0073e9SAndroid Build Coastguard Worker def test_trace_with_conditional_property(self): 2094*da0073e9SAndroid Build Coastguard Worker class Net(nn.Module): 2095*da0073e9SAndroid Build Coastguard Worker def __init__(self, attr=None): 2096*da0073e9SAndroid Build Coastguard Worker super().__init__() 2097*da0073e9SAndroid Build Coastguard Worker if attr is not None: 2098*da0073e9SAndroid Build Coastguard Worker self._attr = attr 2099*da0073e9SAndroid Build Coastguard Worker self.attr_name = "_attr" 2100*da0073e9SAndroid Build Coastguard Worker 2101*da0073e9SAndroid Build Coastguard Worker @property 2102*da0073e9SAndroid Build Coastguard Worker def attr(self): 2103*da0073e9SAndroid Build Coastguard Worker return getattr(self, self.attr_name) 2104*da0073e9SAndroid Build Coastguard Worker 2105*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2106*da0073e9SAndroid Build Coastguard Worker return x 2107*da0073e9SAndroid Build Coastguard Worker 2108*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1) 2109*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(Net(), x) 2110*da0073e9SAndroid Build Coastguard Worker 2111*da0073e9SAndroid Build Coastguard Worker def test_trace_func_argument_names_captured(self): 2112*da0073e9SAndroid Build Coastguard Worker def fn(first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor: 
2113*da0073e9SAndroid Build Coastguard Worker return first_arg + second_arg 2114*da0073e9SAndroid Build Coastguard Worker 2115*da0073e9SAndroid Build Coastguard Worker traced_fn = torch.jit.trace(fn, (torch.ones(1), torch.ones(1))) 2116*da0073e9SAndroid Build Coastguard Worker FileCheck().check("first_arg").check_next("second_arg").run( 2117*da0073e9SAndroid Build Coastguard Worker str(traced_fn.graph) 2118*da0073e9SAndroid Build Coastguard Worker ) 2119*da0073e9SAndroid Build Coastguard Worker 2120*da0073e9SAndroid Build Coastguard Worker def test_trace_partial_func_argument_names_captured(self): 2121*da0073e9SAndroid Build Coastguard Worker def fn(first_arg: torch.Tensor, second_arg=1) -> torch.Tensor: 2122*da0073e9SAndroid Build Coastguard Worker return first_arg + second_arg 2123*da0073e9SAndroid Build Coastguard Worker 2124*da0073e9SAndroid Build Coastguard Worker traced_fn = torch.jit.trace(fn, (torch.ones(1),)) 2125*da0073e9SAndroid Build Coastguard Worker FileCheck().check("first_arg").check_not("second_arg").run(str(traced_fn.graph)) 2126*da0073e9SAndroid Build Coastguard Worker 2127*da0073e9SAndroid Build Coastguard Worker def test_trace_module_argument_names_captured(self): 2128*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 2129*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2130*da0073e9SAndroid Build Coastguard Worker super().__init__() 2131*da0073e9SAndroid Build Coastguard Worker self.conv = nn.Conv2d(1, 1, 3) 2132*da0073e9SAndroid Build Coastguard Worker 2133*da0073e9SAndroid Build Coastguard Worker def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor): 2134*da0073e9SAndroid Build Coastguard Worker return self.conv(first_arg) + second_arg 2135*da0073e9SAndroid Build Coastguard Worker 2136*da0073e9SAndroid Build Coastguard Worker m = TestModule() 2137*da0073e9SAndroid Build Coastguard Worker example_input = (torch.ones(1, 1, 3, 3), torch.ones(1, 1, 3, 3)) 2138*da0073e9SAndroid Build 
Coastguard Worker 2139*da0073e9SAndroid Build Coastguard Worker # Explicitly tracing module's forward method 2140*da0073e9SAndroid Build Coastguard Worker traced_module_forward = torch.jit.trace(m.forward, example_input) 2141*da0073e9SAndroid Build Coastguard Worker FileCheck().check("first_arg").check_next("second_arg").run( 2142*da0073e9SAndroid Build Coastguard Worker str(traced_module_forward.graph) 2143*da0073e9SAndroid Build Coastguard Worker ) 2144*da0073e9SAndroid Build Coastguard Worker 2145*da0073e9SAndroid Build Coastguard Worker # Tracing module's directly 2146*da0073e9SAndroid Build Coastguard Worker traced_module = torch.jit.trace(m, example_input) 2147*da0073e9SAndroid Build Coastguard Worker FileCheck().check("first_arg").check_next("second_arg").run( 2148*da0073e9SAndroid Build Coastguard Worker str(traced_module.graph) 2149*da0073e9SAndroid Build Coastguard Worker ) 2150*da0073e9SAndroid Build Coastguard Worker 2151*da0073e9SAndroid Build Coastguard Worker def test_trace_checking_with_deprecated_name(self): 2152*da0073e9SAndroid Build Coastguard Worker class MyClass(torch.nn.Module): 2153*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2154*da0073e9SAndroid Build Coastguard Worker super(MyClass, self).__init__() 2155*da0073e9SAndroid Build Coastguard Worker 2156*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y, **deprecated_arguments): 2157*da0073e9SAndroid Build Coastguard Worker if len(deprecated_arguments) > 0: 2158*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 2159*da0073e9SAndroid Build Coastguard Worker f"Got unexpected arguments: {deprecated_arguments}" 2160*da0073e9SAndroid Build Coastguard Worker ) 2161*da0073e9SAndroid Build Coastguard Worker return x + y 2162*da0073e9SAndroid Build Coastguard Worker 2163*da0073e9SAndroid Build Coastguard Worker model = MyClass() 2164*da0073e9SAndroid Build Coastguard Worker m2 = torch.jit.trace(model, (torch.ones(1), torch.ones(1))) 
2165*da0073e9SAndroid Build Coastguard Worker m3 = torch.jit.trace( 2166*da0073e9SAndroid Build Coastguard Worker model, 2167*da0073e9SAndroid Build Coastguard Worker example_kwarg_inputs={"x": torch.ones(1), "y": torch.ones(1)}, 2168*da0073e9SAndroid Build Coastguard Worker strict=False, 2169*da0073e9SAndroid Build Coastguard Worker ) 2170*da0073e9SAndroid Build Coastguard Worker 2171*da0073e9SAndroid Build Coastguard Worker def test_trace_with_tuple_tensor(self): 2172*da0073e9SAndroid Build Coastguard Worker class MyClass(torch.nn.Module): 2173*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2174*da0073e9SAndroid Build Coastguard Worker super(MyClass, self).__init__() 2175*da0073e9SAndroid Build Coastguard Worker 2176*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 2177*da0073e9SAndroid Build Coastguard Worker return x + y[0] + y[1] 2178*da0073e9SAndroid Build Coastguard Worker 2179*da0073e9SAndroid Build Coastguard Worker model = MyClass() 2180*da0073e9SAndroid Build Coastguard Worker traced_model = torch.jit.trace( 2181*da0073e9SAndroid Build Coastguard Worker model, (torch.ones(1), (torch.ones(1), torch.ones(1))) 2182*da0073e9SAndroid Build Coastguard Worker ) 2183*da0073e9SAndroid Build Coastguard Worker input_dict = { 2184*da0073e9SAndroid Build Coastguard Worker "x": torch.tensor([2, 3]), 2185*da0073e9SAndroid Build Coastguard Worker "y": (torch.tensor([5, 6]), torch.tensor([7, 8])), 2186*da0073e9SAndroid Build Coastguard Worker } 2187*da0073e9SAndroid Build Coastguard Worker self.assertEqual(model(**input_dict), traced_model(**input_dict)) 2188*da0073e9SAndroid Build Coastguard Worker traced_model = torch.jit.trace( 2189*da0073e9SAndroid Build Coastguard Worker model, 2190*da0073e9SAndroid Build Coastguard Worker example_kwarg_inputs={ 2191*da0073e9SAndroid Build Coastguard Worker "x": torch.ones(1), 2192*da0073e9SAndroid Build Coastguard Worker "y": (torch.ones(1), torch.ones(1)), 2193*da0073e9SAndroid Build 
Coastguard Worker }, 2194*da0073e9SAndroid Build Coastguard Worker ) 2195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(model(**input_dict), traced_model(**input_dict)) 2196*da0073e9SAndroid Build Coastguard Worker 2197*da0073e9SAndroid Build Coastguard Worker def test_trace_no_duplicated_lifted_input_output(self): 2198*da0073e9SAndroid Build Coastguard Worker class Normalize(nn.Module): 2199*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2200*da0073e9SAndroid Build Coastguard Worker super().__init__() 2201*da0073e9SAndroid Build Coastguard Worker self.norm = nn.GroupNorm(num_groups=32, num_channels=32) 2202*da0073e9SAndroid Build Coastguard Worker 2203*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 2204*da0073e9SAndroid Build Coastguard Worker if y is None: 2205*da0073e9SAndroid Build Coastguard Worker y = x 2206*da0073e9SAndroid Build Coastguard Worker else: 2207*da0073e9SAndroid Build Coastguard Worker y = self.norm(y) 2208*da0073e9SAndroid Build Coastguard Worker y = y * 2 2209*da0073e9SAndroid Build Coastguard Worker return y 2210*da0073e9SAndroid Build Coastguard Worker 2211*da0073e9SAndroid Build Coastguard Worker class G(nn.Module): 2212*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2213*da0073e9SAndroid Build Coastguard Worker super().__init__() 2214*da0073e9SAndroid Build Coastguard Worker self.norm = Normalize() 2215*da0073e9SAndroid Build Coastguard Worker 2216*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2217*da0073e9SAndroid Build Coastguard Worker A = self.norm(x, None) 2218*da0073e9SAndroid Build Coastguard Worker B = F.relu(A) 2219*da0073e9SAndroid Build Coastguard Worker return A, B 2220*da0073e9SAndroid Build Coastguard Worker 2221*da0073e9SAndroid Build Coastguard Worker class Net(nn.Module): 2222*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2223*da0073e9SAndroid Build Coastguard Worker super().__init__() 2224*da0073e9SAndroid 
Build Coastguard Worker self.g = G() 2225*da0073e9SAndroid Build Coastguard Worker self.norm_1 = Normalize() 2226*da0073e9SAndroid Build Coastguard Worker 2227*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2228*da0073e9SAndroid Build Coastguard Worker hs = self.g(x) 2229*da0073e9SAndroid Build Coastguard Worker A, B = hs 2230*da0073e9SAndroid Build Coastguard Worker h = self.norm_1(B, A) 2231*da0073e9SAndroid Build Coastguard Worker return h 2232*da0073e9SAndroid Build Coastguard Worker 2233*da0073e9SAndroid Build Coastguard Worker net = Net() 2234*da0073e9SAndroid Build Coastguard Worker net = net.eval() 2235*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 32, 16, 16) 2236*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(net, x) 2237*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::TupleUnpack").run(str(traced.graph)) 2238*da0073e9SAndroid Build Coastguard Worker 2239*da0073e9SAndroid Build Coastguard Worker 2240*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("Not a suitable test for TorchDynamo") 2241*da0073e9SAndroid Build Coastguard Workerclass TestMixTracingScripting(JitTestCase): 2242*da0073e9SAndroid Build Coastguard Worker def test_trace_script(self): 2243*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2244*da0073e9SAndroid Build Coastguard Worker def func1(x: Tuple[Tensor, Tensor]) -> Tensor: 2245*da0073e9SAndroid Build Coastguard Worker return x[0] + x[1] 2246*da0073e9SAndroid Build Coastguard Worker 2247*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2248*da0073e9SAndroid Build Coastguard Worker def func2(x: List[Tensor]) -> Tensor: 2249*da0073e9SAndroid Build Coastguard Worker return x[0] + x[1] 2250*da0073e9SAndroid Build Coastguard Worker 2251*da0073e9SAndroid Build Coastguard Worker a = torch.randn(5) 2252*da0073e9SAndroid Build Coastguard Worker b = torch.randn(5) 2253*da0073e9SAndroid Build Coastguard Worker 2254*da0073e9SAndroid Build Coastguard 
Worker self.checkTrace(func1, ((a, b),)) 2255*da0073e9SAndroid Build Coastguard Worker self.checkTrace(func2, ((a, b),)) 2256*da0073e9SAndroid Build Coastguard Worker 2257*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2258*da0073e9SAndroid Build Coastguard Worker def func3( 2259*da0073e9SAndroid Build Coastguard Worker x: Tensor, method: str = "bilinear", align_corners: bool = True 2260*da0073e9SAndroid Build Coastguard Worker ) -> Tensor: 2261*da0073e9SAndroid Build Coastguard Worker hw = x.shape[2:4] 2262*da0073e9SAndroid Build Coastguard Worker return F.interpolate(x, hw, mode=method, align_corners=align_corners) 2263*da0073e9SAndroid Build Coastguard Worker 2264*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(1, 3, 6, 6) 2265*da0073e9SAndroid Build Coastguard Worker self.checkTrace(func3, (inp,)) 2266*da0073e9SAndroid Build Coastguard Worker 2267*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2268*da0073e9SAndroid Build Coastguard Worker def func4(x: Tensor, a: List[Optional[str]]) -> Tensor: 2269*da0073e9SAndroid Build Coastguard Worker if len(a) == 2: 2270*da0073e9SAndroid Build Coastguard Worker return x + 2 2271*da0073e9SAndroid Build Coastguard Worker else: 2272*da0073e9SAndroid Build Coastguard Worker return x 2273*da0073e9SAndroid Build Coastguard Worker 2274*da0073e9SAndroid Build Coastguard Worker def test_trace_mixed_by_script_with_dict_output(self): 2275*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2276*da0073e9SAndroid Build Coastguard Worker def return_dict(input: torch.Tensor) -> Dict[str, torch.Tensor]: 2277*da0073e9SAndroid Build Coastguard Worker return {"foo": input + 1} 2278*da0073e9SAndroid Build Coastguard Worker 2279*da0073e9SAndroid Build Coastguard Worker class TraceModule(torch.nn.Module): 2280*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 2281*da0073e9SAndroid Build Coastguard Worker dict = return_dict(input) 2282*da0073e9SAndroid Build Coastguard Worker return 
dict["foo"] + dict["foo"] 2283*da0073e9SAndroid Build Coastguard Worker 2284*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1) 2285*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(TraceModule(), x) 2286*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tm(x), x + 1 + x + 1) 2287*da0073e9SAndroid Build Coastguard Worker 2288*da0073e9SAndroid Build Coastguard Worker def test_trace_of_script(self): 2289*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2290*da0073e9SAndroid Build Coastguard Worker def foo(a, c): 2291*da0073e9SAndroid Build Coastguard Worker b = 0.0 2292*da0073e9SAndroid Build Coastguard Worker if bool(a == 0.0): 2293*da0073e9SAndroid Build Coastguard Worker b = 1.0 2294*da0073e9SAndroid Build Coastguard Worker return b + c 2295*da0073e9SAndroid Build Coastguard Worker 2296*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, dtype=torch.float) 2297*da0073e9SAndroid Build Coastguard Worker 2298*da0073e9SAndroid Build Coastguard Worker @_trace(torch.zeros(1, dtype=torch.float)) 2299*da0073e9SAndroid Build Coastguard Worker def use(b): 2300*da0073e9SAndroid Build Coastguard Worker return foo(b - 1.0, a) + 1.0 2301*da0073e9SAndroid Build Coastguard Worker 2302*da0073e9SAndroid Build Coastguard Worker # test we propagated shapes through the function 2303*da0073e9SAndroid Build Coastguard Worker self.assertTrue("Dynamic" not in str(use.graph)) 2304*da0073e9SAndroid Build Coastguard Worker 2305*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, use(torch.ones(1, dtype=torch.float))) 2306*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, use(torch.zeros(1, dtype=torch.float))) 2307*da0073e9SAndroid Build Coastguard Worker 2308*da0073e9SAndroid Build Coastguard Worker def test_trace_with_size(self): 2309*da0073e9SAndroid Build Coastguard Worker @_trace(torch.zeros(1, 1)) 2310*da0073e9SAndroid Build Coastguard Worker def foo(x): 2311*da0073e9SAndroid Build Coastguard Worker return x + 1 
2312*da0073e9SAndroid Build Coastguard Worker 2313*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2314*da0073e9SAndroid Build Coastguard Worker def bar(x): 2315*da0073e9SAndroid Build Coastguard Worker y = int(foo(x)) 2316*da0073e9SAndroid Build Coastguard Worker if 1 == 1: 2317*da0073e9SAndroid Build Coastguard Worker y = 7 2318*da0073e9SAndroid Build Coastguard Worker return y + 1 2319*da0073e9SAndroid Build Coastguard Worker 2320*da0073e9SAndroid Build Coastguard Worker self.assertEqual(8, bar(torch.ones(1, 1))) 2321*da0073e9SAndroid Build Coastguard Worker 2322*da0073e9SAndroid Build Coastguard Worker def test_tracing_slicing(self): 2323*da0073e9SAndroid Build Coastguard Worker @_trace(torch.zeros(10)) 2324*da0073e9SAndroid Build Coastguard Worker def foo_trace(x): 2325*da0073e9SAndroid Build Coastguard Worker return x[-5:-3] 2326*da0073e9SAndroid Build Coastguard Worker 2327*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2328*da0073e9SAndroid Build Coastguard Worker def foo_script(x): 2329*da0073e9SAndroid Build Coastguard Worker return x[-5:-3] 2330*da0073e9SAndroid Build Coastguard Worker 2331*da0073e9SAndroid Build Coastguard Worker def foo(x): 2332*da0073e9SAndroid Build Coastguard Worker return x[-5:-3] 2333*da0073e9SAndroid Build Coastguard Worker 2334*da0073e9SAndroid Build Coastguard Worker a = torch.arange(0, 8) 2335*da0073e9SAndroid Build Coastguard Worker b = torch.arange(0, 20) 2336*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo_trace(a), foo_script(a)) 2337*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo_trace(a), foo(a)) 2338*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(foo_trace(a), foo_trace(b)) 2339*da0073e9SAndroid Build Coastguard Worker 2340*da0073e9SAndroid Build Coastguard Worker def test_tracing_indexing(self): 2341*da0073e9SAndroid Build Coastguard Worker @_trace(torch.zeros(10)) 2342*da0073e9SAndroid Build Coastguard Worker def foo_trace(x): 2343*da0073e9SAndroid 
Build Coastguard Worker return x[-2] 2344*da0073e9SAndroid Build Coastguard Worker 2345*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2346*da0073e9SAndroid Build Coastguard Worker def foo_script(x): 2347*da0073e9SAndroid Build Coastguard Worker return x[-2] 2348*da0073e9SAndroid Build Coastguard Worker 2349*da0073e9SAndroid Build Coastguard Worker def foo(x): 2350*da0073e9SAndroid Build Coastguard Worker return x[-2] 2351*da0073e9SAndroid Build Coastguard Worker 2352*da0073e9SAndroid Build Coastguard Worker a = torch.arange(0, 8) 2353*da0073e9SAndroid Build Coastguard Worker b = torch.arange(0, 20) 2354*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo_script(a), foo_trace(a)) 2355*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo_trace(a), foo(a)) 2356*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(foo_trace(a), foo_trace(b)) 2357*da0073e9SAndroid Build Coastguard Worker 2358*da0073e9SAndroid Build Coastguard Worker def test_trace_hierarchy(self): 2359*da0073e9SAndroid Build Coastguard Worker # Test that we preserve the module hierarchy for a ScriptModule 2360*da0073e9SAndroid Build Coastguard Worker # submodule during tracing 2361*da0073e9SAndroid Build Coastguard Worker 2362*da0073e9SAndroid Build Coastguard Worker class AnotherScriptMod(torch.jit.ScriptModule): 2363*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2364*da0073e9SAndroid Build Coastguard Worker super().__init__() 2365*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(1, 2, 3)) 2366*da0073e9SAndroid Build Coastguard Worker 2367*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 2368*da0073e9SAndroid Build Coastguard Worker def bar(self): 2369*da0073e9SAndroid Build Coastguard Worker return torch.zeros(4, 5) 2370*da0073e9SAndroid Build Coastguard Worker 2371*da0073e9SAndroid Build Coastguard Worker class SomeScriptMod(torch.jit.ScriptModule): 2372*da0073e9SAndroid Build Coastguard 
Worker def __init__(self) -> None: 2373*da0073e9SAndroid Build Coastguard Worker super().__init__() 2374*da0073e9SAndroid Build Coastguard Worker self.asm = AnotherScriptMod() 2375*da0073e9SAndroid Build Coastguard Worker 2376*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 2377*da0073e9SAndroid Build Coastguard Worker def foo(self): 2378*da0073e9SAndroid Build Coastguard Worker return torch.zeros(3, 4) 2379*da0073e9SAndroid Build Coastguard Worker 2380*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 2381*da0073e9SAndroid Build Coastguard Worker def bar(self): 2382*da0073e9SAndroid Build Coastguard Worker return torch.zeros(4, 3) 2383*da0073e9SAndroid Build Coastguard Worker 2384*da0073e9SAndroid Build Coastguard Worker class TraceMe(torch.nn.Module): 2385*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2386*da0073e9SAndroid Build Coastguard Worker super().__init__() 2387*da0073e9SAndroid Build Coastguard Worker self.ssm = SomeScriptMod() 2388*da0073e9SAndroid Build Coastguard Worker 2389*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2390*da0073e9SAndroid Build Coastguard Worker return self.ssm.bar() + x 2391*da0073e9SAndroid Build Coastguard Worker 2392*da0073e9SAndroid Build Coastguard Worker orig = TraceMe() 2393*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(orig, (torch.rand(4, 3),)) 2394*da0073e9SAndroid Build Coastguard Worker # for each of these checks, check that *BOTH* the underlying 2395*da0073e9SAndroid Build Coastguard Worker # _C.ScriptModule object has the expected method/param, as well as the 2396*da0073e9SAndroid Build Coastguard Worker # Python object that wraps it. 
2397*da0073e9SAndroid Build Coastguard Worker self.assertTrue(traced.ssm._c._has_method("foo")) 2398*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(traced.ssm, "foo")) 2399*da0073e9SAndroid Build Coastguard Worker 2400*da0073e9SAndroid Build Coastguard Worker imported = self.getExportImportCopy(traced) 2401*da0073e9SAndroid Build Coastguard Worker 2402*da0073e9SAndroid Build Coastguard Worker self.assertTrue(imported.ssm._c._has_method("foo")) 2403*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(imported.ssm, "foo")) 2404*da0073e9SAndroid Build Coastguard Worker 2405*da0073e9SAndroid Build Coastguard Worker self.assertTrue(imported.ssm.asm._c._has_method("bar")) 2406*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(imported.ssm.asm, "bar")) 2407*da0073e9SAndroid Build Coastguard Worker 2408*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(imported.ssm.asm, "param")) 2409*da0073e9SAndroid Build Coastguard Worker 2410*da0073e9SAndroid Build Coastguard Worker def test_trace_parameter(self): 2411*da0073e9SAndroid Build Coastguard Worker class Param(nn.Module): 2412*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2413*da0073e9SAndroid Build Coastguard Worker super().__init__() 2414*da0073e9SAndroid Build Coastguard Worker self.register_parameter("bias", nn.Parameter(torch.empty(4, 4))) 2415*da0073e9SAndroid Build Coastguard Worker 2416*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2417*da0073e9SAndroid Build Coastguard Worker return x 2418*da0073e9SAndroid Build Coastguard Worker 2419*da0073e9SAndroid Build Coastguard Worker class M3(torch.jit.ScriptModule): 2420*da0073e9SAndroid Build Coastguard Worker def __init__(self, model): 2421*da0073e9SAndroid Build Coastguard Worker super().__init__() 2422*da0073e9SAndroid Build Coastguard Worker self.traced = torch.jit.trace(model, (torch.rand(3, 3))) 2423*da0073e9SAndroid Build Coastguard Worker 2424*da0073e9SAndroid Build 
Coastguard Worker @torch.jit.script_method 2425*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2426*da0073e9SAndroid Build Coastguard Worker return self.traced(x) 2427*da0073e9SAndroid Build Coastguard Worker 2428*da0073e9SAndroid Build Coastguard Worker class M2(nn.Module): 2429*da0073e9SAndroid Build Coastguard Worker def __init__(self, model): 2430*da0073e9SAndroid Build Coastguard Worker super().__init__() 2431*da0073e9SAndroid Build Coastguard Worker self.module = M3(model) 2432*da0073e9SAndroid Build Coastguard Worker 2433*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2434*da0073e9SAndroid Build Coastguard Worker return self.module(x) 2435*da0073e9SAndroid Build Coastguard Worker 2436*da0073e9SAndroid Build Coastguard Worker class M1(torch.jit.ScriptModule): 2437*da0073e9SAndroid Build Coastguard Worker def __init__(self, model): 2438*da0073e9SAndroid Build Coastguard Worker super().__init__() 2439*da0073e9SAndroid Build Coastguard Worker self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3))) 2440*da0073e9SAndroid Build Coastguard Worker 2441*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 2442*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2443*da0073e9SAndroid Build Coastguard Worker return self.traced(x) 2444*da0073e9SAndroid Build Coastguard Worker 2445*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 2446*da0073e9SAndroid Build Coastguard Worker module = M1(Param()) 2447*da0073e9SAndroid Build Coastguard Worker f = io.BytesIO() 2448*da0073e9SAndroid Build Coastguard Worker torch.jit.save(module, f) 2449*da0073e9SAndroid Build Coastguard Worker 2450*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 2451*da0073e9SAndroid Build Coastguard Worker def test_call_script_fn_from_traced_module(self): 2452*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2453*da0073e9SAndroid Build Coastguard Worker def scripted_fn(x): 
2454*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 2455*da0073e9SAndroid Build Coastguard Worker 2456*da0073e9SAndroid Build Coastguard Worker class TracedModule(torch.nn.Module): 2457*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2458*da0073e9SAndroid Build Coastguard Worker super().__init__() 2459*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 5)) 2460*da0073e9SAndroid Build Coastguard Worker 2461*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2462*da0073e9SAndroid Build Coastguard Worker return scripted_fn(torch.mm(x, self.param)) 2463*da0073e9SAndroid Build Coastguard Worker 2464*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 2465*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mm").check('name="scripted_fn"').check( 2466*da0073e9SAndroid Build Coastguard Worker "prim::CallFunction" 2467*da0073e9SAndroid Build Coastguard Worker ).run(str(tm.graph)) 2468*da0073e9SAndroid Build Coastguard Worker 2469*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 2470*da0073e9SAndroid Build Coastguard Worker def test_call_script_module_from_traced_module(self): 2471*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 2472*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2473*da0073e9SAndroid Build Coastguard Worker super().__init__() 2474*da0073e9SAndroid Build Coastguard Worker self.param_foo = torch.nn.Parameter(torch.rand(5, 7)) 2475*da0073e9SAndroid Build Coastguard Worker 2476*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 2477*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2478*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, self.param_foo) 2479*da0073e9SAndroid Build Coastguard Worker 2480*da0073e9SAndroid Build Coastguard Worker class TracedModule(torch.nn.Module): 2481*da0073e9SAndroid 
Build Coastguard Worker def __init__(self) -> None: 2482*da0073e9SAndroid Build Coastguard Worker super().__init__() 2483*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 5)) 2484*da0073e9SAndroid Build Coastguard Worker self.mod = ScriptMod() 2485*da0073e9SAndroid Build Coastguard Worker 2486*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2487*da0073e9SAndroid Build Coastguard Worker return self.mod(torch.mm(x, self.param)) + 1.0 2488*da0073e9SAndroid Build Coastguard Worker 2489*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 2490*da0073e9SAndroid Build Coastguard Worker 2491*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mm").check("prim::CallMethod").check_same( 2492*da0073e9SAndroid Build Coastguard Worker "forward" 2493*da0073e9SAndroid Build Coastguard Worker ).check("aten::add").run(str(tm.graph)) 2494*da0073e9SAndroid Build Coastguard Worker 2495*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 2496*da0073e9SAndroid Build Coastguard Worker def test_call_traced_fn_from_script_fn(self): 2497*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(3, 4)) 2498*da0073e9SAndroid Build Coastguard Worker def traced_fn(x): 2499*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 2500*da0073e9SAndroid Build Coastguard Worker 2501*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2502*da0073e9SAndroid Build Coastguard Worker def script_fn(x): 2503*da0073e9SAndroid Build Coastguard Worker return traced_fn(x) + 1 2504*da0073e9SAndroid Build Coastguard Worker 2505*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::CallFunction").check("aten::add").run( 2506*da0073e9SAndroid Build Coastguard Worker str(script_fn.graph) 2507*da0073e9SAndroid Build Coastguard Worker ) 2508*da0073e9SAndroid Build Coastguard Worker 2509*da0073e9SAndroid Build Coastguard Worker def 
test_call_traced_mod_from_script_fn(self): 2510*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2511*da0073e9SAndroid Build Coastguard Worker RuntimeError, 2512*da0073e9SAndroid Build Coastguard Worker "Cannot call a ScriptModule that is not a submodule of the caller", 2513*da0073e9SAndroid Build Coastguard Worker ): 2514*da0073e9SAndroid Build Coastguard Worker 2515*da0073e9SAndroid Build Coastguard Worker class TracedModule(torch.nn.Module): 2516*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2517*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, torch.zeros(4, 3)) 2518*da0073e9SAndroid Build Coastguard Worker 2519*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 2520*da0073e9SAndroid Build Coastguard Worker 2521*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2522*da0073e9SAndroid Build Coastguard Worker def script_fn(x): 2523*da0073e9SAndroid Build Coastguard Worker return tm(x) + 1 2524*da0073e9SAndroid Build Coastguard Worker 2525*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 2526*da0073e9SAndroid Build Coastguard Worker def test_call_tracing_fn_from_script_module(self): 2527*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(3, 3)) 2528*da0073e9SAndroid Build Coastguard Worker def traced_fn(x): 2529*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 2530*da0073e9SAndroid Build Coastguard Worker 2531*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 2532*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2533*da0073e9SAndroid Build Coastguard Worker super().__init__() 2534*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 3)) 2535*da0073e9SAndroid Build Coastguard Worker 2536*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 2537*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 
2538*da0073e9SAndroid Build Coastguard Worker return traced_fn(torch.mm(x, self.param)) 2539*da0073e9SAndroid Build Coastguard Worker 2540*da0073e9SAndroid Build Coastguard Worker sm = ScriptMod() 2541*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mm").check("prim::CallFunction").run( 2542*da0073e9SAndroid Build Coastguard Worker str(sm.forward.graph) 2543*da0073e9SAndroid Build Coastguard Worker ) 2544*da0073e9SAndroid Build Coastguard Worker 2545*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 2546*da0073e9SAndroid Build Coastguard Worker def test_call_tracing_mod_from_script_module(self): 2547*da0073e9SAndroid Build Coastguard Worker class TracedMod(torch.nn.Module): 2548*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2549*da0073e9SAndroid Build Coastguard Worker super().__init__() 2550*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 5)) 2551*da0073e9SAndroid Build Coastguard Worker 2552*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2553*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, self.param) 2554*da0073e9SAndroid Build Coastguard Worker 2555*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 2556*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2557*da0073e9SAndroid Build Coastguard Worker super().__init__() 2558*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 3)) 2559*da0073e9SAndroid Build Coastguard Worker self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3)) 2560*da0073e9SAndroid Build Coastguard Worker 2561*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 2562*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2563*da0073e9SAndroid Build Coastguard Worker return self.tm(torch.mm(x, self.param)) 2564*da0073e9SAndroid Build Coastguard Worker 2565*da0073e9SAndroid Build Coastguard Worker sm = 
ScriptMod() 2566*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mm").check("prim::CallMethod").run(str(sm.graph)) 2567*da0073e9SAndroid Build Coastguard Worker 2568*da0073e9SAndroid Build Coastguard Worker def test_script_inline_trace_multiple_args(self): 2569*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2570*da0073e9SAndroid Build Coastguard Worker def forward(self, input, input2): 2571*da0073e9SAndroid Build Coastguard Worker return input + input2 2572*da0073e9SAndroid Build Coastguard Worker 2573*da0073e9SAndroid Build Coastguard Worker class M2(torch.jit.ScriptModule): 2574*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2575*da0073e9SAndroid Build Coastguard Worker super().__init__() 2576*da0073e9SAndroid Build Coastguard Worker self.m = torch.jit.trace(M(), (torch.zeros(4, 3), torch.zeros(4, 3))) 2577*da0073e9SAndroid Build Coastguard Worker 2578*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 2579*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 2580*da0073e9SAndroid Build Coastguard Worker return self.m(inp, inp) 2581*da0073e9SAndroid Build Coastguard Worker 2582*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 2583*da0073e9SAndroid Build Coastguard Worker m2 = M2() 2584*da0073e9SAndroid Build Coastguard Worker m2(torch.zeros(4, 3)) 2585*da0073e9SAndroid Build Coastguard Worker 2586*da0073e9SAndroid Build Coastguard Worker def test_trace_dict_mix_script(self): 2587*da0073e9SAndroid Build Coastguard Worker class testB(torch.nn.Module): 2588*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2589*da0073e9SAndroid Build Coastguard Worker super().__init__() 2590*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(2, 2) 2591*da0073e9SAndroid Build Coastguard Worker 2592*da0073e9SAndroid Build Coastguard Worker def forward(self, feature_map: Dict[str, List[Tensor]]) -> Tensor: 2593*da0073e9SAndroid 
Build Coastguard Worker output = [] 2594*da0073e9SAndroid Build Coastguard Worker for j in feature_map.values(): 2595*da0073e9SAndroid Build Coastguard Worker output.append(self.linear(j[0])) 2596*da0073e9SAndroid Build Coastguard Worker 2597*da0073e9SAndroid Build Coastguard Worker return torch.stack(output) 2598*da0073e9SAndroid Build Coastguard Worker 2599*da0073e9SAndroid Build Coastguard Worker class testA(torch.nn.Module): 2600*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2601*da0073e9SAndroid Build Coastguard Worker super().__init__() 2602*da0073e9SAndroid Build Coastguard Worker self.b = torch.jit.script(testB()) 2603*da0073e9SAndroid Build Coastguard Worker 2604*da0073e9SAndroid Build Coastguard Worker def forward(self, input_map: Dict[str, List[Tensor]]) -> Tensor: 2605*da0073e9SAndroid Build Coastguard Worker feature_map = {} 2606*da0073e9SAndroid Build Coastguard Worker for i, j in input_map.items(): 2607*da0073e9SAndroid Build Coastguard Worker feature_map[i] = [j[0]] 2608*da0073e9SAndroid Build Coastguard Worker 2609*da0073e9SAndroid Build Coastguard Worker return self.b(feature_map) 2610*da0073e9SAndroid Build Coastguard Worker 2611*da0073e9SAndroid Build Coastguard Worker input_map = { 2612*da0073e9SAndroid Build Coastguard Worker "1": [torch.rand(2, 2), torch.rand(2, 2)], 2613*da0073e9SAndroid Build Coastguard Worker "3": [torch.rand(2, 2), torch.rand(2, 2)], 2614*da0073e9SAndroid Build Coastguard Worker } 2615*da0073e9SAndroid Build Coastguard Worker model = testA() 2616*da0073e9SAndroid Build Coastguard Worker traced_model = torch.jit.trace(model, input_map) 2617*da0073e9SAndroid Build Coastguard Worker new_input_map = { 2618*da0073e9SAndroid Build Coastguard Worker "1": [torch.rand(2, 2), torch.randn(2, 2)], 2619*da0073e9SAndroid Build Coastguard Worker "3": [torch.rand(2, 2), torch.rand(2, 2)], 2620*da0073e9SAndroid Build Coastguard Worker } 2621*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(model(new_input_map), traced_model(new_input_map)) 2622*da0073e9SAndroid Build Coastguard Worker 2623*da0073e9SAndroid Build Coastguard Worker def test_trace_script_returning_complex_dict(self): 2624*da0073e9SAndroid Build Coastguard Worker """Tracing over a script function returning a dictionary should work. 2625*da0073e9SAndroid Build Coastguard Worker The dictionary can should be able to contain other containers (like a tuple) recursively. 2626*da0073e9SAndroid Build Coastguard Worker """ 2627*da0073e9SAndroid Build Coastguard Worker 2628*da0073e9SAndroid Build Coastguard Worker class ReturnsDict(torch.nn.Module): 2629*da0073e9SAndroid Build Coastguard Worker def forward( 2630*da0073e9SAndroid Build Coastguard Worker self, 2631*da0073e9SAndroid Build Coastguard Worker id_score_list: Dict[ 2632*da0073e9SAndroid Build Coastguard Worker str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 2633*da0073e9SAndroid Build Coastguard Worker ], 2634*da0073e9SAndroid Build Coastguard Worker ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: 2635*da0073e9SAndroid Build Coastguard Worker # do some random operations and then return a dict of the same structure 2636*da0073e9SAndroid Build Coastguard Worker v = id_score_list["1000"] 2637*da0073e9SAndroid Build Coastguard Worker idx_keys = v[1] - 1500000 2638*da0073e9SAndroid Build Coastguard Worker weights = v[2] 2639*da0073e9SAndroid Build Coastguard Worker result = {"1000": (v[0], idx_keys, weights)} 2640*da0073e9SAndroid Build Coastguard Worker return result 2641*da0073e9SAndroid Build Coastguard Worker 2642*da0073e9SAndroid Build Coastguard Worker class ChecksDict(torch.nn.Module): 2643*da0073e9SAndroid Build Coastguard Worker def forward( 2644*da0073e9SAndroid Build Coastguard Worker self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] 2645*da0073e9SAndroid Build Coastguard Worker ): 2646*da0073e9SAndroid Build Coastguard Worker v = input["1000"] 2647*da0073e9SAndroid Build 
Coastguard Worker return v[1] + 1 2648*da0073e9SAndroid Build Coastguard Worker 2649*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 2650*da0073e9SAndroid Build Coastguard Worker def __init__(self, checks_dict, returns_dict): 2651*da0073e9SAndroid Build Coastguard Worker super().__init__() 2652*da0073e9SAndroid Build Coastguard Worker self.checks_dict = checks_dict 2653*da0073e9SAndroid Build Coastguard Worker self.returns_dict = returns_dict 2654*da0073e9SAndroid Build Coastguard Worker 2655*da0073e9SAndroid Build Coastguard Worker def forward( 2656*da0073e9SAndroid Build Coastguard Worker self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] 2657*da0073e9SAndroid Build Coastguard Worker ): 2658*da0073e9SAndroid Build Coastguard Worker foo = self.returns_dict(input) 2659*da0073e9SAndroid Build Coastguard Worker return self.checks_dict(foo) 2660*da0073e9SAndroid Build Coastguard Worker 2661*da0073e9SAndroid Build Coastguard Worker input1 = { 2662*da0073e9SAndroid Build Coastguard Worker "1000": ( 2663*da0073e9SAndroid Build Coastguard Worker torch.tensor([0]), 2664*da0073e9SAndroid Build Coastguard Worker torch.tensor([], dtype=torch.int64), 2665*da0073e9SAndroid Build Coastguard Worker torch.tensor([]), 2666*da0073e9SAndroid Build Coastguard Worker ) 2667*da0073e9SAndroid Build Coastguard Worker } 2668*da0073e9SAndroid Build Coastguard Worker 2669*da0073e9SAndroid Build Coastguard Worker input2 = { 2670*da0073e9SAndroid Build Coastguard Worker "1000": ( 2671*da0073e9SAndroid Build Coastguard Worker torch.tensor([0]), 2672*da0073e9SAndroid Build Coastguard Worker torch.tensor([1500000, 1500004], dtype=torch.int64), 2673*da0073e9SAndroid Build Coastguard Worker torch.tensor([2.0, 3.0]), 2674*da0073e9SAndroid Build Coastguard Worker ) 2675*da0073e9SAndroid Build Coastguard Worker } 2676*da0073e9SAndroid Build Coastguard Worker 2677*da0073e9SAndroid Build Coastguard Worker checks_dict = torch.jit.script(ChecksDict()) 
2678*da0073e9SAndroid Build Coastguard Worker returns_dict = torch.jit.script(ReturnsDict()) 2679*da0073e9SAndroid Build Coastguard Worker eager_module = TestModule(checks_dict, returns_dict) 2680*da0073e9SAndroid Build Coastguard Worker traced_module = torch.jit.trace(eager_module, input1) 2681*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_module(input1), eager_module(input1)) 2682*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_module(input2), eager_module(input2)) 2683*da0073e9SAndroid Build Coastguard Worker 2684*da0073e9SAndroid Build Coastguard Worker def test_trace_returning_dict_with_tensor_tuples(self): 2685*da0073e9SAndroid Build Coastguard Worker """Tracing over a module returning a dictionary whose values are tuples of tensors 2686*da0073e9SAndroid Build Coastguard Worker should work. 2687*da0073e9SAndroid Build Coastguard Worker """ 2688*da0073e9SAndroid Build Coastguard Worker 2689*da0073e9SAndroid Build Coastguard Worker class ReturnsDict(torch.nn.Module): 2690*da0073e9SAndroid Build Coastguard Worker def forward( 2691*da0073e9SAndroid Build Coastguard Worker self, k: torch.Tensor, v: torch.Tensor 2692*da0073e9SAndroid Build Coastguard Worker ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]: 2693*da0073e9SAndroid Build Coastguard Worker x = 2 * k 2694*da0073e9SAndroid Build Coastguard Worker y = 3 * v 2695*da0073e9SAndroid Build Coastguard Worker result = {"imakey": (x, y)} 2696*da0073e9SAndroid Build Coastguard Worker return result 2697*da0073e9SAndroid Build Coastguard Worker 2698*da0073e9SAndroid Build Coastguard Worker class ReturnsBadDict(torch.nn.Module): 2699*da0073e9SAndroid Build Coastguard Worker def forward( 2700*da0073e9SAndroid Build Coastguard Worker self, k: torch.Tensor, v: torch.Tensor 2701*da0073e9SAndroid Build Coastguard Worker ) -> Dict[str, Tuple[torch.Tensor, float]]: 2702*da0073e9SAndroid Build Coastguard Worker x = 2 * k 2703*da0073e9SAndroid Build Coastguard Worker result = {"imakey": (x, 
1)} 2704*da0073e9SAndroid Build Coastguard Worker return result 2705*da0073e9SAndroid Build Coastguard Worker 2706*da0073e9SAndroid Build Coastguard Worker mod = ReturnsDict() 2707*da0073e9SAndroid Build Coastguard Worker traced_module = torch.jit.trace( 2708*da0073e9SAndroid Build Coastguard Worker mod, [torch.ones(1), torch.ones(1)], strict=False 2709*da0073e9SAndroid Build Coastguard Worker ) 2710*da0073e9SAndroid Build Coastguard Worker out = traced_module(torch.ones(1), torch.ones(1)) 2711*da0073e9SAndroid Build Coastguard Worker expected = {"imakey": (torch.tensor([2.0]), torch.tensor([3.0]))} 2712*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, expected) 2713*da0073e9SAndroid Build Coastguard Worker 2714*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2715*da0073e9SAndroid Build Coastguard Worker RuntimeError, "cannot be understood by the tracer, only outputs matching" 2716*da0073e9SAndroid Build Coastguard Worker ): 2717*da0073e9SAndroid Build Coastguard Worker mod = ReturnsBadDict() 2718*da0073e9SAndroid Build Coastguard Worker traced_module = torch.jit.trace( 2719*da0073e9SAndroid Build Coastguard Worker mod, [torch.ones(1), torch.ones(1)], strict=False 2720*da0073e9SAndroid Build Coastguard Worker ) 2721*da0073e9SAndroid Build Coastguard Worker 2722*da0073e9SAndroid Build Coastguard Worker def test_trace_linear(self): 2723*da0073e9SAndroid Build Coastguard Worker m = torch.nn.Linear(20, 20) 2724*da0073e9SAndroid Build Coastguard Worker inp = torch.rand([20, 20]) 2725*da0073e9SAndroid Build Coastguard Worker self.checkTrace(m, (inp,)) 2726*da0073e9SAndroid Build Coastguard Worker g = torch.jit.trace(m, (inp,)).graph 2727*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::linear").run(g) 2728*da0073e9SAndroid Build Coastguard Worker 2729*da0073e9SAndroid Build Coastguard Worker def test_traced_module_implements_interface(self): 2730*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 
2731*da0073e9SAndroid Build Coastguard Worker class TestModuleInterface(nn.Module): 2732*da0073e9SAndroid Build Coastguard Worker def forward( 2733*da0073e9SAndroid Build Coastguard Worker self, first_arg: torch.Tensor, second_arg: torch.Tensor 2734*da0073e9SAndroid Build Coastguard Worker ) -> torch.Tensor: 2735*da0073e9SAndroid Build Coastguard Worker pass 2736*da0073e9SAndroid Build Coastguard Worker 2737*da0073e9SAndroid Build Coastguard Worker make_global(TestModuleInterface) 2738*da0073e9SAndroid Build Coastguard Worker 2739*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 2740*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2741*da0073e9SAndroid Build Coastguard Worker super().__init__() 2742*da0073e9SAndroid Build Coastguard Worker self.conv = nn.Conv2d(1, 1, 3) 2743*da0073e9SAndroid Build Coastguard Worker 2744*da0073e9SAndroid Build Coastguard Worker def forward( 2745*da0073e9SAndroid Build Coastguard Worker self, first_arg: torch.Tensor, second_arg: torch.Tensor 2746*da0073e9SAndroid Build Coastguard Worker ) -> torch.Tensor: 2747*da0073e9SAndroid Build Coastguard Worker return self.conv(first_arg) + second_arg 2748*da0073e9SAndroid Build Coastguard Worker 2749*da0073e9SAndroid Build Coastguard Worker def fn_takes_interface(x: TestModuleInterface): 2750*da0073e9SAndroid Build Coastguard Worker ones = torch.ones(1, 1, 3, 3) 2751*da0073e9SAndroid Build Coastguard Worker return x.forward(ones, ones) 2752*da0073e9SAndroid Build Coastguard Worker 2753*da0073e9SAndroid Build Coastguard Worker scripted_test_module = torch.jit.script(TestModule()) 2754*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_takes_interface, (scripted_test_module,)) 2755*da0073e9SAndroid Build Coastguard Worker 2756*da0073e9SAndroid Build Coastguard Worker def test_traced_module_contains_scripted_interface_types(self): 2757*da0073e9SAndroid Build Coastguard Worker class LeafModule(torch.nn.Module): 2758*da0073e9SAndroid Build 
Coastguard Worker def __init__(self) -> None: 2759*da0073e9SAndroid Build Coastguard Worker super().__init__() 2760*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter(torch.rand(19)) 2761*da0073e9SAndroid Build Coastguard Worker 2762*da0073e9SAndroid Build Coastguard Worker def forward(self, input: torch.Tensor): 2763*da0073e9SAndroid Build Coastguard Worker return input + self.weight 2764*da0073e9SAndroid Build Coastguard Worker 2765*da0073e9SAndroid Build Coastguard Worker class LowerModuleImpl(torch.nn.Module): 2766*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2767*da0073e9SAndroid Build Coastguard Worker super().__init__() 2768*da0073e9SAndroid Build Coastguard Worker self.leaf = LeafModule() 2769*da0073e9SAndroid Build Coastguard Worker 2770*da0073e9SAndroid Build Coastguard Worker def forward(self, input: torch.Tensor) -> torch.Tensor: 2771*da0073e9SAndroid Build Coastguard Worker return self.leaf(input) 2772*da0073e9SAndroid Build Coastguard Worker 2773*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 2774*da0073e9SAndroid Build Coastguard Worker class LowerModuleInterface(torch.nn.Module): 2775*da0073e9SAndroid Build Coastguard Worker def forward(self, input: torch.Tensor) -> torch.Tensor: 2776*da0073e9SAndroid Build Coastguard Worker pass 2777*da0073e9SAndroid Build Coastguard Worker 2778*da0073e9SAndroid Build Coastguard Worker class MiddleModule(torch.nn.Module): 2779*da0073e9SAndroid Build Coastguard Worker lower: LowerModuleInterface 2780*da0073e9SAndroid Build Coastguard Worker 2781*da0073e9SAndroid Build Coastguard Worker def __init__(self, feature_processor_modules=None): 2782*da0073e9SAndroid Build Coastguard Worker super().__init__() 2783*da0073e9SAndroid Build Coastguard Worker self.lower = LowerModuleImpl() 2784*da0073e9SAndroid Build Coastguard Worker 2785*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 2786*da0073e9SAndroid Build Coastguard Worker return 
self.lower(input) 2787*da0073e9SAndroid Build Coastguard Worker 2788*da0073e9SAndroid Build Coastguard Worker class WrapperModule(torch.nn.Module): 2789*da0073e9SAndroid Build Coastguard Worker def __init__(self, m): 2790*da0073e9SAndroid Build Coastguard Worker super().__init__() 2791*da0073e9SAndroid Build Coastguard Worker self.middle = m 2792*da0073e9SAndroid Build Coastguard Worker 2793*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 2794*da0073e9SAndroid Build Coastguard Worker return self.middle(input) 2795*da0073e9SAndroid Build Coastguard Worker 2796*da0073e9SAndroid Build Coastguard Worker class TopModule(torch.nn.Module): 2797*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2798*da0073e9SAndroid Build Coastguard Worker super().__init__() 2799*da0073e9SAndroid Build Coastguard Worker m = MiddleModule() 2800*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(m) 2801*da0073e9SAndroid Build Coastguard Worker self.sub1 = m 2802*da0073e9SAndroid Build Coastguard Worker self.sub2 = WrapperModule(m) 2803*da0073e9SAndroid Build Coastguard Worker 2804*da0073e9SAndroid Build Coastguard Worker def forward(self, input: torch.Tensor): 2805*da0073e9SAndroid Build Coastguard Worker return self.sub1(input) + self.sub2(input) 2806*da0073e9SAndroid Build Coastguard Worker 2807*da0073e9SAndroid Build Coastguard Worker top = TopModule() 2808*da0073e9SAndroid Build Coastguard Worker top_example_input = torch.ones(1) 2809*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(top, top_example_input) 2810*da0073e9SAndroid Build Coastguard Worker 2811*da0073e9SAndroid Build Coastguard Worker def test_jit_trace_callfunction_return_shapes(self): 2812*da0073e9SAndroid Build Coastguard Worker # a torch.jit.script function gets inserted as a CallFunction node 2813*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2814*da0073e9SAndroid Build Coastguard Worker def inner_fn(x): 2815*da0073e9SAndroid Build Coastguard Worker 
return torch.cat((x, x)) 2816*da0073e9SAndroid Build Coastguard Worker 2817*da0073e9SAndroid Build Coastguard Worker def outer_fn(x, y): 2818*da0073e9SAndroid Build Coastguard Worker return inner_fn(x + y).relu() 2819*da0073e9SAndroid Build Coastguard Worker 2820*da0073e9SAndroid Build Coastguard Worker x, y = [torch.rand((2, 2), dtype=torch.float) for _ in range(2)] 2821*da0073e9SAndroid Build Coastguard Worker fn_t = torch.jit.trace(outer_fn, (x, y)) 2822*da0073e9SAndroid Build Coastguard Worker 2823*da0073e9SAndroid Build Coastguard Worker # expect that the CallFunction node return type has shape information on it. 2824*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Float").check("4, 2").check("CallFunction").run(fn_t.graph) 2825*da0073e9SAndroid Build Coastguard Worker for n in fn_t.graph.nodes(): 2826*da0073e9SAndroid Build Coastguard Worker if n.kind() == "prim::CallFunction": 2827*da0073e9SAndroid Build Coastguard Worker self.assertTrue(n.output().isCompleteTensor()) 2828