# Owner(s): ["oncall: jit"] import os import sys from typing import Any, Dict, List import torch from torch.testing._internal.jit_utils import JitTestCase # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestModuleAPIs(JitTestCase): def test_default_state_dict_methods(self): """Tests that default state dict methods are automatically available""" class DefaultStateDictModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(6, 16, 5) self.fc = torch.nn.Linear(16 * 5 * 5, 120) def forward(self, x): x = self.conv(x) x = self.fc(x) return x m1 = torch.jit.script(DefaultStateDictModule()) m2 = torch.jit.script(DefaultStateDictModule()) state_dict = m1.state_dict() m2.load_state_dict(state_dict) def test_customized_state_dict_methods(self): """Tests that customized state dict methods are in effect""" class CustomStateDictModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(6, 16, 5) self.fc = torch.nn.Linear(16 * 5 * 5, 120) self.customized_save_state_dict_called: bool = False self.customized_load_state_dict_called: bool = False def forward(self, x): x = self.conv(x) x = self.fc(x) return x @torch.jit.export def _save_to_state_dict( self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool ): self.customized_save_state_dict_called = True return {"dummy": torch.ones(1)} @torch.jit.export def _load_from_state_dict( self, state_dict: Dict[str, torch.Tensor], prefix: str, local_metadata: Any, strict: bool, missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str], ): self.customized_load_state_dict_called = True return m1 = torch.jit.script(CustomStateDictModule()) self.assertFalse(m1.customized_save_state_dict_called) state_dict = m1.state_dict() self.assertTrue(m1.customized_save_state_dict_called) m2 = torch.jit.script(CustomStateDictModule()) self.assertFalse(m2.customized_load_state_dict_called) m2.load_state_dict(state_dict) self.assertTrue(m2.customized_load_state_dict_called) def test_submodule_customized_state_dict_methods(self): """Tests that customized state dict methods on submodules are in effect""" class CustomStateDictModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(6, 16, 5) self.fc = torch.nn.Linear(16 * 5 * 5, 120) self.customized_save_state_dict_called: bool = False self.customized_load_state_dict_called: bool = False def forward(self, x): x = self.conv(x) x = self.fc(x) return x @torch.jit.export def _save_to_state_dict( self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool ): self.customized_save_state_dict_called = True return {"dummy": torch.ones(1)} @torch.jit.export def _load_from_state_dict( self, state_dict: Dict[str, torch.Tensor], prefix: str, local_metadata: Any, strict: bool, missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str], ): self.customized_load_state_dict_called = True return class ParentModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sub = CustomStateDictModule() def forward(self, x): return self.sub(x) m1 = torch.jit.script(ParentModule()) self.assertFalse(m1.sub.customized_save_state_dict_called) state_dict = m1.state_dict() self.assertTrue(m1.sub.customized_save_state_dict_called) m2 = torch.jit.script(ParentModule()) self.assertFalse(m2.sub.customized_load_state_dict_called) m2.load_state_dict(state_dict) self.assertTrue(m2.sub.customized_load_state_dict_called)