# Owner(s): ["oncall: jit"] import os import re import sys import types import typing import typing_extensions from collections import OrderedDict from typing import Dict, List, Optional, Tuple import torch import torch.jit.frontend import torch.nn as nn from torch import Tensor from torch.testing import FileCheck # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.testing._internal.jit_utils import ( _tmp_donotuse_dont_inline_everything, JitTestCase, ) if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestRecursiveScript(JitTestCase): def test_inferred_nonetype(self): class M(nn.Module): def __init__(self) -> None: super().__init__() self.x = None def forward(self): assert self.x is None m = torch.jit.script(M()) self.checkModule(M(), ()) def test_script_function_attribute(self): @torch.jit.script def fn1(x): return x + x @torch.jit.script def fn2(x): return x - x class M(torch.nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x): return self.fn(x) fn1_mod = M(fn1) fn2_mod = M(fn2) self.checkModule(fn1_mod, (torch.randn(2, 2),)) self.checkModule(fn2_mod, (torch.randn(2, 2),)) def test_python_function_attribute(self): class M(torch.nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x): return self.fn(x) mod = M(torch.sigmoid) self.checkModule(mod, (torch.randn(2, 2),)) def test_failed_function_compilation(self): def fn(x): return i_dont_exist # noqa: F821 class M(torch.nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x): return self.fn(x) m = M(fn) with self.assertRaisesRegexWithHighlight( RuntimeError, "failed to compile", "i_dont_exist" ): torch.jit.script(m) def test_init_error(self): class M(nn.Module): def __init__(self) -> None: self.x = 2 def forward(self): pass with self.assertRaisesRegex(RuntimeError, "has not been initialized"): torch.jit.script(M()) def test_script_after_eval(self): class M(nn.Module): def forward(self): if self.training: return 2 else: return 0 m = M() sm1 = torch.jit.script(m) m.eval() sm2 = torch.jit.script(m) # m is in eval mode, training should be False self.assertFalse(m.training) # sm1 was created while m had training = True self.assertTrue(sm1.training) self.assertEqual(sm1.training, sm1._c.getattr("training")) self.assertEqual(sm1(), 2) # sm2 was created after m was eval'ed self.assertFalse(sm2.training) self.assertEqual(sm2.training, sm2._c.getattr("training")) self.assertEqual(sm2(), 0) def test_module_name(self): class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.x = 2 def forward(self, t): return t + self.x m = torch.jit.script(MyModule()) FileCheck().check("MyModule").run(m.graph) def test_repeated_error_stack(self): def d(x): return "a" - 2 def c(x): return d(x) def b(x): return c(x) def a(x): return b(x) try: torch.jit.script(a) except Exception as e: FileCheck().check_count("is being compiled", 2).run(str(e)) try: torch.jit.script(a) except Exception as e: # Make sure that no entries are left over from the previous failure FileCheck().check_count("is being compiled", 2).run(str(e)) def test_constants_with_final(self): class M1(torch.nn.Module): x: torch.jit.Final[int] def __init__(self) -> None: super().__init__() self.x = 2 def forward(self, t): return t + self.x self.checkModule(M1(), (torch.randn(2, 2),)) class M2(torch.nn.Module): x: typing_extensions.Final[int] def __init__(self) -> None: super().__init__() self.x = 2 def forward(self, t): return t + self.x self.checkModule(M2(), (torch.randn(2, 2),)) class M3(torch.nn.Module): x: typing.Final[int] def __init__(self) -> None: super().__init__() self.x = 2 def forward(self, t): return t + self.x self.checkModule(M3(), (torch.randn(2, 2),)) def test_ignore_class(self): @torch.jit.ignore class MyScriptClass: def unscriptable(self): return "a" + 200 class TestModule(torch.nn.Module): def forward(self, x): return MyScriptClass() with self.assertRaisesRegexWithHighlight( torch.jit.frontend.FrontendError, "Cannot instantiate class", "MyScriptClass", ): t = torch.jit.script(TestModule()) def test_method_call(self): class M(nn.Module): def test(self, x): return x def forward(self, z): y = self.test(z) return z + 20 + y self.checkModule(M(), (torch.randn(2, 2),)) def test_module_repr(self): class Submodule(nn.Module): def forward(self, x): return x class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(10, 10, 3) self.lin = nn.Linear(10, 10) self.sub = Submodule() def forward(self, x): return self.lin(x) + self.sub(x) + self.conv(x) m = torch.jit.script(MyModule()) with self.capture_stdout() as out: print(m) f = FileCheck() f.check("MyModule") f.check("Conv2d") f.check("Linear") f.check("Submodule") f.run(out[0]) self.assertEqual(m.original_name, "MyModule") def test_dir(self): def test_module_dir(mod): dir_set = dir(mod) scripted_mod = torch.jit.script(mod) dir_scripted = set(dir(scripted_mod)) # set not currently copied over ignore_set = [ "training", "__delitem__", "__setitem__", "clear", "items", "keys", "pop", "update", "values", ] for attr in dir_set: if attr in ignore_set: continue self.assertTrue(attr in dir_scripted, attr) class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(10, 10, 3) self.lin = nn.Linear(10, 10) def forward(self, x): return self.lin(x) + self.conv(x) test_module_dir(MyModule()) # test custom __dir__ for containers conv = nn.Conv2d(10, 10, 3) linear = nn.Linear(10, 10) test_module_dir(nn.Sequential(conv, linear)) test_module_dir( nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)])) ) def test_class_compile(self): def other_fn(a: int, b: Tensor) -> Tensor: return a * b class B: def __init__(self, x): self.x = 2 def helper(self, a): return self.x + a + other_fn(self.x, a) class N(torch.nn.Module): def forward(self, x): b = B(x) return b.helper(x) self.checkModule(N(), (torch.randn(2, 2),)) def test_error_stack(self): def d(x: int) -> int: return x + 10 def c(x): return d("hello") + d(x) def b(x): return c(x) def a(x): return b(x) try: scripted = torch.jit.script(a) except RuntimeError as e: checker = FileCheck() checker.check("Expected a value of type 'int'") checker.check("def c(x)") checker.check("def b(x)") checker.check("def a(x)") checker.run(str(e)) def test_error_stack_module(self): def d(x: int) -> int: return x + 10 def c(x): return d("hello") + d(x) def b(x): return c(x) class Submodule(torch.nn.Module): def forward(self, x): return b(x) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.submodule = Submodule() def some_method(self, y): return y + self.submodule(y) def forward(self, x): return self.some_method(x) try: scripted = torch.jit.script(M()) except RuntimeError as e: checker = FileCheck() checker.check("Expected a value of type 'int'") checker.check("'c' is being compiled since it was called from 'b'") checker.check("'b' is being compiled since it was called from") checker.run(str(e)) @_tmp_donotuse_dont_inline_everything def test_script_basic(self): def a_python_fn(a, b, c): return a + b + c @torch.jit.script def a_script_fn(d, e, f): return a_python_fn(d, e, f) graph = str(a_script_fn.graph) FileCheck().check("prim::CallFunction").run(graph) FileCheck().check_not("^a_python_fn").run(graph) t = torch.ones(2, 2) self.assertEqual(a_script_fn(t, t, t), t + t + t) def test_error_stack_class(self): class X: def bad_fn(self): import pdb # noqa: F401 def fn(x) -> X: return X(10) try: torch.jit.script(fn) except Exception as e: checker = FileCheck() checker.check("import statements") checker.check("is being compiled since it was called from") checker.run(str(e)) def test_error_stack_annotation(self): class X: def bad_fn(self): import pdb # noqa: F401 def fn(x) -> X: return X(10) try: torch.jit.script(fn) except Exception as e: checker = FileCheck() checker.check("import statements") checker.check("is being compiled since it was called from") checker.check("-> X") checker.run(str(e)) def test_module_basic(self): class Other(torch.nn.Module): __constants__ = ["x"] def __init__(self, x): super().__init__() self.x = x self.param = torch.nn.Parameter(torch.ones(2, 2)) def some_unscriptable_method(self): a = 2 a = [2] return a def forward(self, t): return t + self.x + self.param class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.other = Other(200) def forward(self, t): return self.other(t) * 2 self.checkModule(M(), (torch.ones(2, 2),)) def test_module_function_export(self): class Other(torch.nn.Module): __constants__ = ["x"] def __init__(self, x): super().__init__() self.x = x self.param = torch.nn.Parameter(torch.ones(2, 2)) @torch.jit.export def some_entry_point(self, y): return y + 20 def forward(self, t): return t + self.x + self.param class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.other = Other(200) def forward(self, t): return self.other(t) * 2 self.checkModule(M(), (torch.ones(2, 2),)) def test_iterable_modules(self): class Inner(torch.nn.Module): def forward(self, x): return x + 10 class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sequential = nn.Sequential( Inner(), Inner(), nn.Sequential(Inner(), Inner()) ) self.module_list = nn.ModuleList([Inner(), Inner()]) def forward(self, x): for mod in self.module_list: x += mod(x) x += self.sequential(x) return x self.checkModule(M(), (torch.randn(5, 5),)) def test_prepare_scriptable_basic(self): class SeluButReluWhenScripted(torch.nn.SELU): def __prepare_scriptable__(self): return nn.ReLU() t = torch.randn(5, 5) m = SeluButReluWhenScripted() sm = torch.jit.script(m) eager_out = m(t) script_out = sm(t) self.assertNotEqual(eager_out, script_out) def test_prepare_scriptable_iterable_modules(self): class SeluButReluWhenScripted(torch.nn.SELU): def __prepare_scriptable__(self): return nn.ReLU() class M(torch.nn.Module): def __init__(self) -> None: super().__init__() shared = SeluButReluWhenScripted() self.sequential = nn.Sequential( SeluButReluWhenScripted(), SeluButReluWhenScripted(), nn.Sequential( SeluButReluWhenScripted(), shared, SeluButReluWhenScripted() ), shared, ) self.module_list = nn.ModuleList( [SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()] ) def forward(self, x): for mod in self.module_list: x += mod(x) x += self.sequential(x) return x t = torch.randn(5, 5) m = M() eager_out = m(t.clone()) sm = torch.jit.script(m) script_out = sm(t.clone()) self.assertNotEqual(eager_out, script_out) def test_prepare_scriptable_cycle(self): t = torch.randn(5, 5) c = torch.nn.Module() p = torch.nn.Module() c.__dict__["_p"] = p p.__dict__["_c"] = c sm = torch.jit.script(p) def test_prepare_scriptable_escape_hatch(self): class NonJitableClass: def __call__(self, int1, int2, *args): total = int1 + int2 for arg in args: total += arg return total obj = NonJitableClass() self.assertEqual(obj(1, 2), 3) self.assertEqual(obj(1, 2, 3, 4), 10) with self.assertRaisesRegex( torch.jit.frontend.NotSupportedError, expected_regex="can't take variable number of arguments", ): torch.jit.script(obj) def escape_hatch(int1: int, int2: int) -> int: return int1 + int2 class NonJitableClassWithEscapeHatch(NonJitableClass): def __prepare_scriptable__(self): return escape_hatch jit_obj = torch.jit.script(NonJitableClassWithEscapeHatch()) self.assertEqual(jit_obj(1, 2), 3) with self.assertRaisesRegex( RuntimeError, expected_regex=re.escape( "expected at most 2 argument(s) but received 4 argument(s)" ), ): jit_obj(1, 2, 3, 4) def test_attributes(self): @torch.jit.script class Inner2: def __init__(self) -> None: self.b = "a string" @torch.jit.script class Foo: def __init__(self) -> None: self.a = 4 self.inner = Inner2() @torch.jit.script class SFoo: def __init__(self) -> None: self.a = 4 self.inner = Inner2() def __setstate__(self, obj: Tuple[int, Inner2]) -> None: a, inner = obj self.a = a self.inner = inner def __getstate__(self): return (self.a, self.inner) untyped_values = ( ("my_dict", {"I": "am", "a test": "test"}), ("my_float", 2.3), ("my_int", 99), ("my_bool", False), ("my_tuple", (1, 2, 3, 4)), ("my_list", [(1, 2), (3, 4)]), # ('my_tensor', torch.randn(2, 2)), ("my_int_list", [1, 2, 3, 4]), # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]), ("my_bool_list", [True, True, False, True]), ("my_float_list", [1.0, 2.0, 3.0, 4.0]), ("my_str_list", ["hello", "bye"]), ) typed_values = ( ("my_empty_list", []), ("my_empty_dict", {}), ("my_none", None), ("my_object", Foo()), ("my_object2", SFoo()), ) class M(torch.nn.Module): # TODO: re-enable this once this test is in a Python 3-only syntax # file # my_empty_list : List[int] # my_empty_dict : Dict[str, int] # my_none : Optional[int] def forward(self, x): return ( self.my_dict, self.my_float, self.my_int, self.my_bool, # self.my_tensor, self.my_int_list, # self.my_tensor_list, self.my_bool_list, self.my_float_list, self.my_str_list, self.my_empty_list, self.my_empty_dict, self.my_none, self.my_object.a, self.my_object.inner.b, self.my_object.a, self.my_object2.inner.b, ) # TODO: as a followup, fix this test # We can't define class attributes like we should be doing: # class M(torch.nn.Module): # my_empty_list : List[int] # my_empty_dict : Dict[str, int] # my_none : Optional[int] # my_out_of_line_attribute: List[int] = [1, 2, 3] # since there's no string frontend for Python classes (so the `define`) # trick doesn't work. M.__annotations__ = { "my_empty_list": List[int], "my_empty_dict": Dict[str, int], "my_none": Optional[int], "my_object": Foo, "my_object2": SFoo, } m = M() for name, value in untyped_values + typed_values: setattr(m, name, value) self.checkModule(m, (torch.randn(5, 5),)) def test_function_attribute_in_submodule(self): class N(nn.Module): def __init__(self, norm): super().__init__() self.activation = torch.nn.functional.relu self.norm = norm def forward(self, src): output = src output = self.norm(output) return output class M(nn.Module): def __init__(self) -> None: super().__init__() encoder_norm = nn.ReLU() self.encoder = N(encoder_norm) def forward(self, x): return self.encoder(x) m = M() self.checkModule(m, (torch.randn(5, 5),)) def test_inner_traced_module(self): class Dummy(nn.Module): def forward(self, x): return x class Model(nn.Module): def __init__(self, dummies): super().__init__() self._dummies = dummies def forward(self, x): out = [] for dummy in self._dummies: out.append(dummy(x)) return out dummy = torch.jit.trace(Dummy(), torch.randn(1, 2)) dummies = nn.ModuleList([dummy]) model = Model(dummies) self.checkModule(model, (torch.rand(5, 5),)) def test_script_loaded_module(self): """ Test that we can hold a loaded ScriptModule as a submodule. """ class Dummy(nn.Module): def forward(self, x): return x dummy = torch.jit.script(Dummy()) dummy = self.getExportImportCopy(dummy) class ContainsLoaded(torch.nn.Module): def __init__(self) -> None: super().__init__() self.encoder = dummy def forward(self, input): return self.encoder(input) self.checkModule(ContainsLoaded(), (torch.rand(2, 3),)) def test_optional_module(self): class Dummy(nn.Module): def __init__(self) -> None: super().__init__() self.foo = nn.Linear(2, 2) def forward(self, x): if self.foo is not None: return self.foo(x) return x mod = Dummy() self.checkModule(mod, (torch.rand(2, 2),)) mod.foo = None self.checkModule(mod, (torch.rand(2, 2),)) def test_override_instance_method_ignore(self): class M(torch.nn.Module): @torch.jit.ignore def i_am_ignored(self): return "old" m = M() # Override the ignored method by binding a new method to this instance. @torch.jit.ignore def i_am_ignored(self): return "new" m.i_am_ignored = types.MethodType(i_am_ignored, m) self.assertEqual(m.i_am_ignored(), "new") # ScriptModule should correctly reflect the override. s = torch.jit.script(m) self.assertEqual(s.i_am_ignored(), "new")