# Owner(s): ["module: nn"] import gc import math import pickle import unittest import warnings import weakref from collections import namedtuple, OrderedDict from copy import deepcopy from functools import partial from tempfile import NamedTemporaryFile from typing import Any, Dict, List, Tuple import torch import torch.nn as nn from torch.testing._internal.common_nn import _create_basic_net, NNTestCase from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, IS_WINDOWS, parametrize as parametrize_test, run_tests, skipIfTorchDynamo, swap, TestCase, ) class Net(nn.Module): def __init__(self) -> None: super().__init__() self.seq1 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)]) self.seq2 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)]) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.seq2(self.seq1(x)) ToyNamedTuple = namedtuple("ToyNamedTuple", "content") class ToyModel(nn.Module): def __init__(self, with_named_tuple=False) -> None: super().__init__() self.net1 = Net() self.net2 = Net() self.with_named_tuple = with_named_tuple def forward(self, x: torch.Tensor) -> torch.Tensor: res = self.net2(self.net1(x)) if self.with_named_tuple: return ToyNamedTuple(res) else: return (res,) def forward_hook( self: TestCase, fired_hooks: List[int], expected_module: nn.Module, hook_id: int, module: nn.Module, inp: Tuple[torch.Tensor], out: torch.Tensor, ) -> None: fired_hooks.append(hook_id) self.assertEqual(id(module), id(expected_module)) self.assertEqual(len(inp), 1) def forward_pre_hook( self: TestCase, fired_hooks: List[int], expected_module: nn.Module, hook_id: int, module: nn.Module, inp: Tuple[torch.Tensor], ) -> None: fired_hooks.append(hook_id) self.assertEqual(id(module), id(expected_module)) self.assertEqual(len(inp), 1) def full_backward_hook( self: TestCase, fired_hooks: List[int], expected_module: nn.Module, hook_id: int, module: nn.Module, grad_input: Tuple[torch.Tensor], grad_output: Tuple[torch.Tensor], ) -> 
None: fired_hooks.append(hook_id) self.assertEqual(id(module), id(expected_module)) self.assertEqual(len(grad_input), 1) self.assertEqual(len(grad_output), 1) def full_backward_pre_hook( self: TestCase, fired_hooks: List[int], expected_module: nn.Module, hook_id: int, module: nn.Module, grad_input: Tuple[torch.Tensor], ) -> None: fired_hooks.append(hook_id) self.assertEqual(id(module), id(expected_module)) self.assertEqual(len(grad_input), 1) class KwargModel(nn.Module): def __init__(self) -> None: super().__init__() self.net1 = Net() self.net2 = Net() def forward(self, x: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor: if bias is not None: x = x + bias return x def internal_forward_hook( self, module: nn.Module, args: Tuple[torch.Tensor], kwargs: Dict[str, Any], out: torch.Tensor, ): return out + kwargs["bias"] class FailsInForwardModel(nn.Module): def __init__(self) -> None: super().__init__() self.net1 = Net() def forward(self, x: torch.Tensor, fail: bool = True) -> torch.Tensor: if fail: raise RuntimeError("failing in forward") return self.net1(x) def kwarg_forward_pre_hook( self: TestCase, fired_hooks: List[int], expected_module: nn.Module, hook_id: int, module: nn.Module, args: Tuple[torch.Tensor], kwargs: Dict[str, Any], ) -> Tuple[Any, Any]: fired_hooks.append(hook_id) self.assertEqual(id(module), id(expected_module)) self.assertEqual(len(args), 1) kwargs["bias"] = 2 * kwargs["bias"] return args, kwargs def kwarg_forward_hook( self: TestCase, fired_hooks: List[int], expected_module: nn.Module, hook_id: int, module: nn.Module, args: Tuple[torch.Tensor], kwargs: Dict[str, Any], out: torch.Tensor, ) -> Any: fired_hooks.append(hook_id) self.assertEqual(id(module), id(expected_module)) self.assertEqual(len(args), 1) out = out + kwargs["bias"] return out class DummyContextManager: def __init__(self, inp): self.input = inp def __enter__(self, *args, **kwargs): self.input.append(2) def __exit__(self, *args, **kwargs): self.input.append(-1) class 
TestModuleHooks(TestCase): @parametrize_test("named_tuple", (True, False)) def test_forward_hooks(self, named_tuple): fired_hooks: List[int] = [] model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = partial(forward_hook, self, fired_hooks, model.net1.seq2) model.net1.seq2.register_forward_hook(partial(hook, 0)) model.net1.seq2.register_forward_hook(partial(hook, 1), prepend=True) model.net1.seq2.register_forward_hook(partial(hook, 2)) model.net1.seq2.register_forward_hook(partial(hook, 3)) model.net1.seq2.register_forward_hook(partial(hook, 4), prepend=True) expected = [4, 1, 0, 2, 3] self.assertEqual(fired_hooks, []) out = model(x) self.assertEqual(fired_hooks, expected) self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple) out[0].sum().backward() self.assertEqual(fired_hooks, expected) model(x)[0].sum().backward() self.assertEqual(fired_hooks, expected + expected) @parametrize_test("named_tuple", (True, False)) def test_forward_pre_hooks(self, named_tuple): fired_hooks: List[int] = [] model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1) model.net2.seq1.register_forward_pre_hook(partial(hook, 0), prepend=True) model.net2.seq1.register_forward_pre_hook(partial(hook, 1)) model.net2.seq1.register_forward_pre_hook(partial(hook, 2)) model.net2.seq1.register_forward_pre_hook(partial(hook, 3)) model.net2.seq1.register_forward_pre_hook(partial(hook, 4), prepend=True) expected = [4, 0, 1, 2, 3] self.assertEqual(fired_hooks, []) out = model(x) self.assertEqual(fired_hooks, expected) self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple) out[0].sum().backward() self.assertEqual(fired_hooks, expected) model(x)[0].sum().backward() self.assertEqual(fired_hooks, expected + expected) @parametrize_test("named_tuple", (True, False)) def test_full_backward_hooks(self, named_tuple): fired_hooks: List[int] = [] model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = 
partial(full_backward_hook, self, fired_hooks, model.net1) model.net1.register_full_backward_hook(partial(hook, 0)) model.net1.register_full_backward_hook(partial(hook, 1)) model.net1.register_full_backward_hook(partial(hook, 2)) model.net1.register_full_backward_hook(partial(hook, 3), prepend=True) model.net1.register_full_backward_hook(partial(hook, 4), prepend=True) expected = [4, 3, 0, 1, 2] self.assertEqual(fired_hooks, []) out = model(x) self.assertEqual(fired_hooks, []) self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple) out[0].sum().backward() self.assertEqual(fired_hooks, expected) model(x)[0].sum().backward() self.assertEqual(fired_hooks, expected + expected) @parametrize_test("named_tuple", (True, False)) def test_full_backward_pre_hooks(self, named_tuple): fired_hooks: List[int] = [] model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1) model.net1.register_full_backward_pre_hook(partial(hook, 0), prepend=True) model.net1.register_full_backward_pre_hook(partial(hook, 1), prepend=True) model.net1.register_full_backward_pre_hook(partial(hook, 2)) model.net1.register_full_backward_pre_hook(partial(hook, 3)) model.net1.register_full_backward_pre_hook(partial(hook, 4)) expected = [1, 0, 2, 3, 4] self.assertEqual(fired_hooks, []) out = model(x) self.assertEqual(fired_hooks, []) self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple) out[0].sum().backward() self.assertEqual(fired_hooks, expected) model(x)[0].sum().backward() self.assertEqual(fired_hooks, expected + expected) # Backward pre hook can affect subsequent gradient computation for rg in [True, False]: a = torch.ones(2, requires_grad=rg) model = nn.Linear(2, 2) def fn(_unused_module, grad_output): return (grad_output[0] * 0,) model.register_full_backward_pre_hook(fn) out = model(a) out.sum().backward() self.assertEqual(model.weight.grad, torch.zeros(2, 2)) if rg: self.assertEqual(a.grad, 
torch.zeros_like(a)) else: self.assertIsNone(a.grad) @parametrize_test("named_tuple", (True, False)) def test_mixed_hooks(self, named_tuple): fired_hooks: List[int] = [] model = ToyModel(named_tuple) x = torch.randn(10, 10) model.register_forward_pre_hook( partial(forward_pre_hook, self, fired_hooks, model, 0) ) model.register_forward_hook(partial(forward_hook, self, fired_hooks, model, 1)) model.register_full_backward_pre_hook( partial(full_backward_pre_hook, self, fired_hooks, model, 2) ) model.register_full_backward_hook( partial(full_backward_hook, self, fired_hooks, model, 3) ) self.assertEqual(fired_hooks, []) out = model(x) self.assertEqual(fired_hooks, [0, 1]) self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple) out[0].sum().backward() self.assertEqual(fired_hooks, [0, 1, 2, 3]) model(x)[0].sum().backward() self.assertEqual(fired_hooks, [0, 1, 2, 3, 0, 1, 2, 3]) def test_kwarg_hooks(self): # 1. test forward pre hook fired_hooks: List[int] = [] x: torch.Tensor = torch.ones(10, 10) bias: torch.Tensor = torch.ones(10, 10) model = KwargModel() model.register_forward_pre_hook( partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0), with_kwargs=True, ) # forward-pre: bias' = bias * 2 # So, out = x + bias * 2 self.assertEqual(fired_hooks, []) out = model(x, bias=bias) self.assertEqual(fired_hooks, [0]) self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5) # 2. 
test forward pre and forward hooks fired_hooks: List[int] = [] x: torch.Tensor = torch.ones(10, 10) bias: torch.Tensor = torch.ones(10, 10) model = KwargModel() model.register_forward_hook( partial(kwarg_forward_hook, self, fired_hooks, model, 1), with_kwargs=True, ) model.register_forward_pre_hook( partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0), with_kwargs=True, ) # forward-pre: bias' = bias * 2 # forward: out = x + bias' # forward-post: out = out + bias' # So, out = x + bias * 4 self.assertEqual(fired_hooks, []) out = model(x, bias=bias) self.assertEqual(fired_hooks, [0, 1]) self.assertEqual(out, x + 4 * bias, rtol=0, atol=1e-5) # 3. test nn.Module member method as forward-post hook x: torch.Tensor = torch.ones(10, 10) bias: torch.Tensor = torch.ones(10, 10) model = KwargModel() model.register_forward_hook(model.internal_forward_hook, with_kwargs=True) # forward: out = x + bias # forward-post: out = out + bias # So, out = x + bias * 2 out = model(x, bias=bias) self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5) def test_remove_kwarg_hooks(self): # test forward pre and forward hooks fired_hooks: List[int] = [] x: torch.Tensor = torch.ones(10, 10) bias: torch.Tensor = torch.ones(10, 10) model = KwargModel() forward_hook_handle = model.register_forward_hook( partial(kwarg_forward_hook, self, fired_hooks, model, 1), with_kwargs=True, ) forward_pre_hook_handle = model.register_forward_pre_hook( partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0), with_kwargs=True, ) # forward-pre: bias' = bias * 2 # forward: out = x + bias' # forward-post: out = out + bias' # So, out = x + bias * 4 self.assertEqual(fired_hooks, []) out = model(x, bias=bias) self.assertEqual(fired_hooks, [0, 1]) self.assertEqual(out, x + 4 * bias, rtol=0, atol=1e-5) # forward-pre: bias' = bias * 2 # forward: out = x + bias' # So, out = x + bias * 2 forward_hook_handle.remove() out = model(x, bias=bias) self.assertEqual(fired_hooks, [0, 1, 0]) self.assertEqual(out, x + 2 * 
bias, rtol=0, atol=1e-5) self.assertFalse(forward_hook_handle.id in model._forward_hooks_with_kwargs) # forward: out = x + bias # So, out = x + bias forward_pre_hook_handle.remove() out = model(x, bias=bias) self.assertEqual(fired_hooks, [0, 1, 0]) self.assertEqual(out, x + bias, rtol=0, atol=1e-5) self.assertFalse( forward_pre_hook_handle.id in model._forward_pre_hooks_with_kwargs ) def test_always_called_forward_hooks(self): x: torch.Tensor = torch.ones(10, 10) model = FailsInForwardModel() stack = [] ctx = None def setup_context(): nonlocal ctx ctx = DummyContextManager(stack) def ctx_setup_hook(m, i): setup_context() ctx.__enter__() def ctx_setup_failure_hook(m, i): setup_context() ctx.__enter__() raise RuntimeError("failing in ctx setup") def ctx_shutdown_hook(m, i, o): ctx.__exit__() def ctx_shutdown_failure_hook(m, i, o): ctx.__exit__() raise RuntimeError("failing in ctx shutdown") def throw_hook(m, i, o): raise RuntimeError("failing in throw") forward_pre_hook_handle = model.register_forward_pre_hook(ctx_setup_hook) forward_hook_handle = model.register_forward_hook( ctx_shutdown_hook, always_call=True ) self.assertTrue(len(model._forward_hooks_always_called) == 1) # make sure always_called forward hook runs when model.forward raises RuntimeError with self.assertRaisesRegex(RuntimeError, "failing in forward"): model(x) self.assertEqual(stack, [2, -1]) # make sure that always_called forward hook does not run twice if there is no error model(x, fail=False) self.assertEqual(stack, [2, -1, 2, -1]) # make sure always_called forward hook runs when forward pre hook raises RuntimeError forward_pre_hook_handle.remove() model.register_forward_pre_hook(ctx_setup_failure_hook) with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"): model(x, fail=False) self.assertEqual(stack, [2, -1, 2, -1, 2, -1]) # make sure always_called hook runs when another always_called forward hook raises an error forward_hook_handle2 = model.register_forward_hook( throw_hook, 
prepend=True, always_call=True ) # error raised should not be error of the forced hook with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"): model(x, fail=False) self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1]) # make sure that always called forward hooks are properly removed forward_hook_handle.remove() forward_hook_handle2.remove() self.assertTrue(len(model._forward_hooks_always_called) == 0) # make sure that always called forward hook is not run twice if it fails while running forward_hook_handle3 = model.register_forward_hook( ctx_shutdown_failure_hook, always_call=True ) with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"): model(x, fail=False) self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1, 2, -1]) forward_hook_handle3.remove() global_forward_hook_handle = nn.modules.module.register_module_forward_hook( ctx_shutdown_hook, always_call=True ) self.assertTrue(len(nn.modules.module._global_forward_hooks_always_called) == 1) # make sure global forward hook runs when forward pre hook raises RuntimeError with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"): model(x, fail=False) self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1, 2, -1, 2, -1]) # make sure forced global forward hook is properly removed global_forward_hook_handle.remove() self.assertTrue(len(nn.modules.module._global_forward_hooks_always_called) == 0) with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"): model(x) self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1, 2, -1, 2, -1, 2]) def test_bw_hook_warning_for_non_tensor_or_tuple(self): # Test to verify that backward hook raises warning # if result is not a Tensor or tuple of Tensors. 
counter = {"forward": 0, "backward": 0} def fw_pre_hook(module: nn.Module, _inputs): counter["forward"] += 1 def fw_hook(module: nn.Module, _inputs, _outputs): counter["forward"] += 1 def bw_hook(module: nn.Module, _inputs, _outputs): counter["backward"] += 1 class TestModule(nn.Module): def forward(self, dict): inp = dict["x"] x = torch.nn.functional.softmax(inp, dim=0) return {"x": x} x = torch.ones(2, requires_grad=True) model = TestModule() model.register_forward_pre_hook(fw_pre_hook) model.register_forward_hook(fw_hook) model.register_full_backward_pre_hook(bw_hook) model.register_full_backward_hook(bw_hook) with warnings.catch_warnings(record=True) as w: y = model({"x": x})["x"] loss = y.sum() loss.backward() self.assertEqual(counter["forward"], 2) self.assertEqual(counter["backward"], 0) self.assertEqual(len(w), 1) self.assertTrue("should be a Tensor or a tuple of Tensors" in str(w[0].message)) def _hook_to_pickle(*args, **kwargs): pass class TestStateDictHooks(TestCase): @swap([True, False]) def test_load_state_dict_pre_hook(self): m = nn.Linear(10, 10) m_state_dict = m.state_dict() m_load = nn.Linear(10, 10) hook_called = 0 def hook_without_module( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): self.assertEqual(m_state_dict, state_dict) nonlocal hook_called hook_called += 1 def hook_with_module( module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): self.assertEqual(m_state_dict, state_dict) self.assertTrue(m_load is module) nonlocal hook_called hook_called += 1 hook_called = 0 # Test private API since this sets with_module=False which diverges from public API m_load._register_load_state_dict_pre_hook(hook_without_module) m_load.load_state_dict(m_state_dict) self.assertEqual(1, hook_called) hook_called = 0 m_load.register_load_state_dict_pre_hook(hook_with_module) m_load.load_state_dict(m_state_dict) self.assertEqual(2, hook_called) # Test private API with 
with_module=True hook_called = 0 m_load._register_load_state_dict_pre_hook(hook_with_module, True) m_load.load_state_dict(m_state_dict) self.assertEqual(3, hook_called) def test_no_extra_ref_to_module(self): try: gc.disable() m = nn.Linear(10, 10) m.register_load_state_dict_pre_hook(_hook_to_pickle) weak_m = weakref.ref(m) del m self.assertEqual(weak_m(), None) finally: gc.enable() def test_pickled_hook(self): m = nn.Linear(10, 10) m.register_load_state_dict_pre_hook(_hook_to_pickle) pickle.loads(pickle.dumps(m)) @swap([True, False]) def test_load_state_dict_module_pre_hook(self): hook_called = 0 # Test with module instance method as hook class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.foo = torch.nn.Parameter(torch.rand(10)) def my_pre_load_hook( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): assert [] == error_msgs assert [] == unexpected_keys assert [] == missing_keys assert strict nonlocal hook_called hook_called += 1 def my_pre_load_hook_with_module( self, module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): assert [] == error_msgs assert [] == unexpected_keys assert [] == missing_keys assert strict assert self is module nonlocal hook_called hook_called += 1 # Test that hooks registered on a submodule are also called # appropriately, i.e. with the submodule as module argument in # my_pre_load_hook_with_module. 
class MyModuleContainer(nn.Module): def __init__(self, mod): super().__init__() self.mod = mod for ctor in [MyModuleContainer, lambda x: x]: m = ctor(MyModule()) state_dict = m.state_dict() if isinstance(m, MyModuleContainer): mod = m.mod else: mod = m hook_called = 0 # Test private API since this sets with_module=False which diverges from public API mod._register_load_state_dict_pre_hook(mod.my_pre_load_hook) m.load_state_dict(state_dict) self.assertEqual(1, hook_called) hook_called = 0 mod.register_load_state_dict_pre_hook(mod.my_pre_load_hook_with_module) m.load_state_dict(state_dict) self.assertEqual(2, hook_called) @swap([True, False]) def test_load_state_dict_post_hook(self): hook_called = 0 class MyModule(nn.Module): def __init__(self) -> None: super().__init__() self.foo = torch.nn.Parameter(torch.rand(10)) def my_post_load_hook(self, module, incompatible_keys): assert module is self nonlocal hook_called incompatible_keys.missing_keys.append("foo") incompatible_keys.unexpected_keys.append("bar") hook_called += 1 nested = MyModule() wrapped = nn.ModuleList([nested]) handle = nested.register_load_state_dict_post_hook( nested.my_post_load_hook, ) # Hook must be called even if it is wrapped ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False) self.assertEqual(hook_called, 1) # Ensure that the hook modified missing_keys and unexpected_keys missing = ret.missing_keys unexpected = ret.unexpected_keys self.assertEqual(missing, ["foo"]) self.assertEqual(unexpected, ["bar"]) # When called with strict=True, the error raised should mention the # missing and unexpected keys the hook added. with self.assertRaisesRegex(RuntimeError, "foo.*\n.*bar"): wrapped.load_state_dict(wrapped.state_dict(), strict=True) self.assertEqual(hook_called, 2) # Removing the hook via handle.remove() should cause it not to # fire anymore. 
handle.remove() # Hook did not run so it should not have added any keys ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False) self.assertEqual(ret.missing_keys, []) self.assertEqual(ret.unexpected_keys, []) # hook_called should not have been incremented self.assertEqual(hook_called, 2) def load_hook_clear_incompatible(module, incompatible_keys): incompatible_keys.missing_keys.clear() incompatible_keys.unexpected_keys.clear() nested.register_load_state_dict_post_hook(load_hook_clear_incompatible) state_dict = wrapped.state_dict() state_dict["extra"] = torch.ones(1) # load state_dict with strict=True should not throw. ret = wrapped.load_state_dict(state_dict, strict=True) # explicitly ensure that the post hook clearned out incompatible_keys self.assertEqual([], ret.missing_keys) self.assertEqual([], ret.unexpected_keys) @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") @swap([True, False]) def test_load_state_dict_post_hook_backward_compatibility(self): def my_post_load_hook(mod, _): nonlocal called called = True for m in [nn.Softmin(10), nn.Softmax(10), nn.LogSoftmax(10)]: called = False sd = deepcopy(m.state_dict()) self.assertTrue(hasattr(m, "_load_state_dict_post_hooks")) # Simulate an older model that did not have this attr delattr(m, "_load_state_dict_post_hooks") # Save and load, and ensure that load_state_dict works (without proper # BC we would run into errors because this attribute would be expected). # In particular, Softmax runs into the issue described here: # https://github.com/pytorch/pytorch/issues/77280 with NamedTemporaryFile() as f: # Note that torch.save / torch.load is not recommended to save/load # modules. torch.save(m, f.name) # weights_only=False as this is legacy code that saves the model m = torch.load(f.name, weights_only=False) m.load_state_dict(sd) self.assertFalse(called) # Ensure hooks can be registered and called. 
m.register_load_state_dict_post_hook(my_post_load_hook) m.load_state_dict(sd) self.assertTrue(called) def _test_register_state_dict_pre_hook(self, model, submodule): _state_dict_prefix = "foo." state_dict_pre_hook_count = 0 keep_var_setting = False def my_state_dict_pre_hook(module, prefix, keep_vars): self.assertEqual(keep_vars, keep_var_setting) nonlocal state_dict_pre_hook_count state_dict_pre_hook_count += 1 self.assertTrue(prefix.startswith(_state_dict_prefix)) model.register_state_dict_pre_hook(my_state_dict_pre_hook) # Test to ensure submodules run the hook as well. submodule.register_state_dict_pre_hook(my_state_dict_pre_hook) def check_results(model): nonlocal state_dict_pre_hook_count, keep_var_setting for keep_var_setting in [True, False]: _ = model.state_dict( prefix=_state_dict_prefix, keep_vars=keep_var_setting ) self.assertEqual(2, state_dict_pre_hook_count) state_dict_pre_hook_count = 0 # Test state dict works as expected after model construction check_results(model) # Test state dict works as expected after forward model(torch.ones(10, 3)) check_results(model) def test_register_state_dict_pre_hook(self): class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.a = nn.Sequential( nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3) ) def forward(self, x): return self.a(x) mod = MyModule() self._test_register_state_dict_pre_hook(mod, mod.a) def test_register_state_dict_pre_hook_lazy_module(self): class MyLazyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.layer1 = nn.LazyLinear(8) self.layer2 = nn.LazyLinear(5) def forward(self, x): return self.layer2(self.layer1(x)) mod = MyLazyModule() self._test_register_state_dict_pre_hook(mod, mod.layer1) @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") def test_register_state_dict_pre_hook_backward_compat(self): called = False def my_state_dict_pre_hook(*args, **kwargs): nonlocal called called = True m = nn.Linear(1, 1) 
self.assertTrue(hasattr(m, "_state_dict_pre_hooks")) delattr(m, "_state_dict_pre_hooks") # Save and load, ensure we can still call state_dict # without running into issues. with NamedTemporaryFile() as f: # Note that torch.save / torch.load is not recommended # to save / load modules. torch.save(m, f.name) # weights_only=False as this is legacy code that saves the model m = torch.load(f.name, weights_only=False) # Ensure we can run state_dict without issues _ = m.state_dict() self.assertFalse(called) m.register_state_dict_pre_hook(my_state_dict_pre_hook) _ = m.state_dict() self.assertTrue(called) @parametrize_test("private", [True, False]) def test_register_state_dict_post_hook(self, private): m = nn.Transformer( d_model=4, nhead=2, num_encoder_layers=2, num_decoder_layers=2 ) def linear_state_dict_post_hook(module, state_dict, prefix, local_metadata): for name, param in module.named_parameters(recurse=False): state_dict[prefix + name] = torch.nn.Parameter( state_dict[prefix + name] ) def register_linear_hook(module): if isinstance(module, nn.Linear): hook_registration_fn = ( module._register_state_dict_hook if private else module.register_state_dict_post_hook ) hook_registration_fn(linear_state_dict_post_hook) def _check_sd(state_dict): for k, v in m.state_dict().items(): if "linear" in k or "out_proj" in k: self.assertTrue(isinstance(v, torch.nn.Parameter)) else: self.assertFalse(isinstance(v, torch.nn.Parameter)) # verify that return type of hook registered on child submodules has no effect # regardless of whether using public or private API m.apply(register_linear_hook) _check_sd(m.state_dict()) # verify that return type of hook registered root module has no effect # for public API but has effect for private API hook_registration_fn = ( m._register_state_dict_hook if private else m.register_state_dict_post_hook ) def fn(m, s, p, l): return OrderedDict() handle = hook_registration_fn(fn) if private: self.assertFalse(hasattr(fn, "_from_public_api")) 
self.assertTrue(len(m.state_dict()) == 0) else: self.assertTrue(hasattr(fn, "_from_public_api")) with self.assertRaisesRegex( RuntimeError, "state_dict post-hook must return None" ): sd = m.state_dict() with self.assertRaisesRegex( RuntimeError, "previously registered via register_state_dict_post_hook" ): m._register_state_dict_hook(fn) class TestModuleGlobalHooks(TestCase): def tearDown(self): nn.modules.module._global_backward_hooks = OrderedDict() nn.modules.module._global_forward_hooks = OrderedDict() nn.modules.module._global_forward_pre_hooks = OrderedDict() @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_module_global_hooks(self): module = nn.Sigmoid module_1 = module() module_2 = module() module_3 = module() input = torch.ones(5, 5, requires_grad=True) counter = {"forwards": 0, "backwards": 0} def fw_hook(inc, h_module, input, output): self.assertIsInstance(input, tuple) self.assertTrue(isinstance(output, torch.Tensor)) self.assertTrue(isinstance(h_module, module)) self.assertEqual(input[0], torch.ones(5, 5)) self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e))) counter["forwards"] += inc def bw_hook(inc, h_module, grad_input, grad_output): self.assertIsInstance(grad_input, tuple) self.assertIsInstance(grad_output, tuple) self.assertTrue(isinstance(h_module, module)) self.assertEqual(grad_output[0], torch.ones(5, 5) * 2) counter["backwards"] += inc test_fwd = nn.modules.module.register_module_forward_hook( lambda *args: fw_hook(1, *args) ) module_1(input) module_2(input) module_3(input) self.assertEqual(counter["forwards"], 3) self.assertEqual(counter["backwards"], 0) test_bwd = nn.modules.module.register_module_backward_hook( lambda *args: bw_hook(1, *args) ) output_1 = module_1(input) output_2 = module_2(input) output_3 = module_3(input) self.assertEqual(counter["forwards"], 6) self.assertEqual(counter["backwards"], 0) output_1.backward(torch.ones(5, 5) * 2, retain_graph=True) output_2.backward(torch.ones(5, 5) * 
2, retain_graph=False) output_3.backward(torch.ones(5, 5) * 2, retain_graph=False) self.assertEqual(counter["forwards"], 6) self.assertEqual(counter["backwards"], 3) output_1.backward(torch.ones(5, 5) * 2, retain_graph=True) self.assertEqual(counter["forwards"], 6) self.assertEqual(counter["backwards"], 4) test2_fwd = nn.modules.module.register_module_forward_hook( lambda *args: fw_hook(2, *args) ) output = module_1(input) output = module_2(input) output = module_3(input) self.assertEqual(counter["forwards"], 15) self.assertEqual(counter["backwards"], 4) test2_bwd = nn.modules.module.register_module_backward_hook( lambda *args: bw_hook(2, *args) ) module_1(input).backward(torch.ones(5, 5) * 2) self.assertEqual(counter["forwards"], 18) self.assertEqual(counter["backwards"], 7) test2_bwd.remove() module_2(input).backward(torch.ones(5, 5) * 2) self.assertEqual(counter["forwards"], 21) self.assertEqual(counter["backwards"], 8) test2_fwd.remove() module_3(input).backward(torch.ones(5, 5) * 2) self.assertEqual(counter["forwards"], 22) self.assertEqual(counter["backwards"], 9) test_fwd.remove() test_bwd.remove() def test_module_global_hook_invalid_outputs(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) def bw_fail1(self, grad_input, grad_output): return grad_input[:-1] def bw_fail2(self, grad_input, grad_output): return grad_input + (torch.randn(2, 2),) with nn.modules.module.register_module_backward_hook(bw_fail1): with self.assertRaisesRegex(RuntimeError, "got 0, but expected 1"): module(input).sum().backward() with nn.modules.module.register_module_backward_hook(bw_fail2): with self.assertRaisesRegex(RuntimeError, "got 2, but expected 1"): module(input).sum().backward() def test_module_backward_global_hook_writeable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) sig_x = torch.sigmoid(input) def bw_hook(module, grad_input, grad_output): for grad in grad_input: self.assertTrue(isinstance(grad, torch.Tensor)) for grad 
in grad_output: self.assertTrue(isinstance(grad, torch.Tensor)) return tuple(gi * 2 for gi in grad_input) nn.modules.module.register_module_backward_hook(bw_hook) module(input).backward(torch.ones(5, 5)) expected_grad = sig_x * (1 - sig_x) * 2 self.assertEqual(input.grad, expected_grad) @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_module_global_forward_preforward_hook_writeable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) sig_x = torch.sigmoid(input) def forward_pre_hook(m, input): return torch.nn.functional.relu(input[0]) def forward_hook(m, input, output): return -output nn.modules.module.register_module_forward_pre_hook(forward_pre_hook) nn.modules.module.register_module_forward_hook(forward_hook) output = module(input) expected_res = -torch.sigmoid(torch.nn.functional.relu(input)) self.assertEqual(output, expected_res) output.backward(torch.ones(5, 5) * 2, retain_graph=True) mask = input > 0 expected_grad = -sig_x * (1 - sig_x) * 2 * mask self.assertEqual(input.grad, expected_grad) def test_module_forward_preforward_hook_removable(self): """ This test is to test when multiple pre-forward hook functions can be registered successfully and used correctly, if the handle can be removable during the pre-forward hook function call. 
""" module = nn.Sigmoid() def removable_hook(m, input): nonlocal handle handle.remove() return input def removable_hook_2(m, input): nonlocal handle_2 handle_2.remove() return input handle = module.register_forward_pre_hook(removable_hook) handle_2 = module.register_forward_pre_hook(removable_hook_2) # make sure hook register is successful self.assertEqual(len(handle.hooks_dict_ref()), 2) self.assertEqual(len(handle_2.hooks_dict_ref()), 2) input = torch.randn(2, 2) output = module(input) self.assertEqual(torch.sigmoid(input), output) # make sure hook removal is successful self.assertFalse(handle.id in handle.hooks_dict_ref()) self.assertFalse(handle_2.id in handle.hooks_dict_ref()) self.assertEqual(len(handle.hooks_dict_ref()), 0) self.assertEqual(len(handle_2.hooks_dict_ref()), 0) def test_module_forward_forward_hook_removable(self): """ This test is to test when multiple forward hook functions can be registered successfully and used correctly, if the handle can be removable during the forward hook function call. 
""" module = nn.Sigmoid() def removable_hook(m, input, output): nonlocal handle handle.remove() return output def removable_hook_2(m, input, output): nonlocal handle_2 handle_2.remove() return output handle = module.register_forward_hook(removable_hook) handle_2 = module.register_forward_hook(removable_hook_2) # make sure hook register is successful self.assertEqual(len(handle.hooks_dict_ref()), 2) self.assertEqual(len(handle_2.hooks_dict_ref()), 2) input = torch.randn(2, 2) output = module(input) self.assertEqual(torch.sigmoid(input), output) # make sure hook removal is successful self.assertFalse(handle.id in handle.hooks_dict_ref()) self.assertFalse(handle_2.id in handle.hooks_dict_ref()) self.assertEqual(len(handle.hooks_dict_ref()), 0) self.assertEqual(len(handle_2.hooks_dict_ref()), 0) @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_global_and_local_hooks_order(self): module = nn.Sigmoid() global_forward_pre_called = False local_forward_pre_called = False global_forward_called = False local_forward_called = False global_backward_called = False local_backward_called = False def global_forward_pre_hook(m, input): nonlocal global_forward_pre_called self.assertTrue(not local_forward_pre_called) global_forward_pre_called = True return input def local_forward_pre_hook(m, input): nonlocal local_forward_pre_called self.assertTrue(global_forward_pre_called) local_forward_pre_called = True return input def global_forward_hook(m, input, output): nonlocal global_forward_called self.assertTrue(not local_forward_called) global_forward_called = True return output def local_forward_hook(m, input, output): nonlocal local_forward_called self.assertTrue(global_forward_called) local_forward_called = True return output def global_backward_hook(m, input, output): nonlocal global_backward_called self.assertTrue(not local_backward_called) global_backward_called = True return input def local_backward_hook(m, input, output): nonlocal local_backward_called 
def _test_hooks(self, backward_register_fn):
    """Exercise forward and backward hooks on ``nn.Sigmoid`` and count calls.

    Forward and backward hooks are registered and removed step by step; after
    every step the accumulated ``forwards`` / ``backwards`` counters are
    compared with the expected totals. Each hook adds its own increment
    (1 or 2), so the totals reveal exactly which hooks fired.

    Args:
        backward_register_fn: Name of the module method used to register the
            backward hook; it is looked up with ``getattr``. When it equals
            ``"register_full_backward_pre_hook"`` a two-argument callback
            ``(module, grad_output)`` is used, otherwise a three-argument
            callback ``(module, grad_input, grad_output)``.
    """
    sigmoid = nn.Sigmoid()
    x = torch.ones(5, 5, requires_grad=True)
    counter = {"forwards": 0, "backwards": 0}

    def upstream_grad():
        # Fresh gradient tensor for every backward call / comparison.
        return torch.ones(5, 5) * 2

    def expect_counts(forwards, backwards):
        self.assertEqual(counter["forwards"], forwards)
        self.assertEqual(counter["backwards"], backwards)

    def bind_increment(fn, inc):
        # Adapt `fn(inc, ...)` to the signature the hook machinery calls.
        return lambda *hook_args: fn(inc, *hook_args)

    def fw_hook(inc, h_module, args, result):
        self.assertIsInstance(args, tuple)
        self.assertTrue(isinstance(result, torch.Tensor))
        self.assertTrue(h_module is sigmoid)
        self.assertEqual(args[0], torch.ones(5, 5))
        # sigmoid(1) == 1 / (1 + e^-1)
        self.assertEqual(result, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e)))
        counter["forwards"] += inc

    def bw_hook(inc, h_module, grad_input, grad_output):
        self.assertIsInstance(grad_input, tuple)
        self.assertIsInstance(grad_output, tuple)
        self.assertTrue(h_module is sigmoid)
        self.assertEqual(grad_output[0], upstream_grad())
        counter["backwards"] += inc

    # backward_pre_hook expects callback with only `module` and `grad_output`
    # as arguments.
    def bw_pre_hook(inc, h_module, grad_output):
        self.assertIsInstance(grad_output, tuple)
        self.assertTrue(h_module is sigmoid)
        self.assertEqual(grad_output[0], upstream_grad())
        counter["backwards"] += inc

    # One forward hook (+1), no backward hook yet.
    first_fwd = sigmoid.register_forward_hook(bind_increment(fw_hook, 1))
    sigmoid(x)
    sigmoid(x)
    expect_counts(2, 0)

    if backward_register_fn == "register_full_backward_pre_hook":
        bw_hook_fn = bw_pre_hook
    else:
        bw_hook_fn = bw_hook
    register_backward = getattr(sigmoid, backward_register_fn)
    first_bwd = register_backward(bind_increment(bw_hook_fn, 1))

    # Registering a backward hook does not fire it; only backward() does.
    out = sigmoid(x)
    expect_counts(3, 0)

    out.backward(upstream_grad(), retain_graph=True)
    expect_counts(3, 1)

    out.backward(upstream_grad(), retain_graph=True)
    expect_counts(3, 2)

    # Second forward hook (+2): each forward now adds 3.
    second_fwd = sigmoid.register_forward_hook(bind_increment(fw_hook, 2))
    out = sigmoid(x)
    expect_counts(6, 2)

    # Second backward hook (+2): each backward now adds 3.
    second_bwd = register_backward(bind_increment(bw_hook_fn, 2))
    sigmoid(x).backward(upstream_grad())
    expect_counts(9, 5)

    second_bwd.remove()
    sigmoid(x).backward(upstream_grad())
    expect_counts(12, 6)

    second_fwd.remove()
    sigmoid(x).backward(upstream_grad())
    expect_counts(13, 7)

    first_fwd.remove()
    first_bwd.remove()
self.assertEqual(len(grad_outputs), 1) self.assertEqual(module, bn) bn.register_full_backward_hook(hook) output = bn(torch.randn(5, 5, requires_grad=True)) output.sum().backward() def test_backward_hooks_interaction(self): # Test to make sure that the grad_outputs # updated by full_backward_pre_hook are received by # the full_backward_hook module = torch.nn.Sigmoid() cnt = {"backward_cnt": 0} def bw_pre_hook(m, grad_output): cnt["backward_cnt"] += 1 return (grad_output[0] * 0.5,) def bw_hook(m, grad_in, grad_output): self.assertEqual(torch.full_like(grad_output[0], 0.5), grad_output[0]) cnt["backward_cnt"] += 1 return grad_output module.register_full_backward_pre_hook(bw_pre_hook) module.register_full_backward_hook(bw_hook) t = torch.ones(1, 2, requires_grad=True) module(t).sum().backward() self.assertEqual(cnt["backward_cnt"], 2) def test_hook_invalid_outputs(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) def bw_fail1(self, grad_input, grad_output): return grad_input[:-1] def bw_fail2(self, grad_input, grad_output): return grad_input + (torch.randn(2, 2),) with module.register_backward_hook(bw_fail1): with self.assertRaisesRegex(RuntimeError, "got 0, but expected 1"): module(input).sum().backward() with module.register_backward_hook(bw_fail2): with self.assertRaisesRegex(RuntimeError, "got 2, but expected 1"): module(input).sum().backward() def bw_pre_fail1(self, grad_output): return () def bw_pre_fail2(self, grad_output): return grad_output + (torch.randn(2, 2),) with module.register_full_backward_pre_hook(bw_pre_fail1): with self.assertRaisesRegex(RuntimeError, "got 0, but expected 1"): module(input).sum().backward() with module.register_full_backward_pre_hook(bw_pre_fail2): with self.assertRaisesRegex(RuntimeError, "got 2, but expected 1"): module(input).sum().backward() def test_hook_requires_grad(self): test_self = self class MyModule(nn.Module): def forward(self, arg1, arg2, arg3): test_self.assertTrue(arg1.requires_grad) 
test_self.assertFalse(arg2.requires_grad) test_self.assertTrue(arg3.requires_grad) return arg1.sum() + arg2.sum() + arg3.sum() inp = torch.rand(2, requires_grad=True) mod = MyModule() mod(inp, inp.detach(), inp) # Ensure that requires grad is properly propagated mod.register_full_backward_hook(lambda mod, gI, gO: None) mod(inp, inp.detach(), inp) def test_hook_no_requires_grad(self): mod = nn.Linear(2, 3) inp = torch.rand(1, 2) return_val = "None" hook_called = [0] def hook(mod, grad_input, grad_output): hook_called[0] += 1 for gI in grad_input: self.assertIsNone(gI) for gO in grad_output: self.assertEqual(gO.size(), (1, 3)) if return_val == "grad_input": return grad_input elif return_val == "invalid": # If the inputs were requiring gradients, this would be # a valid return return inp elif return_val == "None": return None else: raise RuntimeError("Invalid return_val string") mod.register_full_backward_hook(hook) # This should run and trigger the hook properly mod(inp).sum().backward() self.assertEqual(hook_called[0], 1) return_val = "grad_input" mod(inp).sum().backward() self.assertEqual(hook_called[0], 2) return_val = "invalid" with self.assertRaisesRegex(RuntimeError, "where no input requires gradient"): mod(inp).sum().backward() def test_hook_last_arg_requires_grad(self): mod = nn.L1Loss() inp = torch.rand(1, requires_grad=True) mod.register_full_backward_hook(lambda m, gI, gO: None) try: mod(inp.detach(), inp) except Exception as ex: self.fail(f"Unexpected exception: {ex}") def test_hook_extra_input(self): class MyModule(nn.Module): def forward(self, non_tensor, tensor): return tensor.clone(), non_tensor inp = torch.rand(2, requires_grad=True) mod = MyModule() def hook(mod, grad_input, grad_output): self.assertIsNone(grad_input[0]) self.assertIsInstance(grad_input[1], torch.Tensor) self.assertIsInstance(grad_output[0], torch.Tensor) self.assertIsNone(grad_output[1]) mod.register_full_backward_hook(hook) out, _ = mod(True, inp) out.sum().backward() def 
def test_hook_inplace(self):
    """In-place modification around full backward (pre-)hooks must raise.

    The same scenario is run twice: once with a full backward hook and once
    with a full backward pre-hook. With the hook registered, the test checks
    that (1) a plain forward/backward fires the hook once, (2) modifying the
    module input in place inside ``forward`` raises, (3) modifying the tensor
    that was passed in and then using the input stored by ``forward`` raises,
    and (4) modifying the module output in place raises. The expected error
    messages all refer to ``BackwardHookFunctionBackward`` views.
    """

    class MyModule(nn.Module):
        def forward(self, inp, do_inplace):
            # Keep the tensor as seen inside forward so the test can use it later.
            self.inp = inp
            if do_inplace:
                inp += 1
            return inp.clone()

    # Single-element list so the nested hooks can mutate the counter.
    hook_called = [0]

    def hook(mod, grad_input, grad_output):
        hook_called[0] += 1

    def hook_pre(mod, grad_output):
        hook_called[0] += 1

    inp = torch.rand(10, requires_grad=True)
    mod = MyModule()
    for hook_fn, register_fn in [
        (hook, mod.register_full_backward_hook),
        (hook_pre, mod.register_full_backward_pre_hook),
    ]:
        hook_called[0] = 0
        # The returned handle is used as a context manager for this iteration.
        with register_fn(hook_fn):
            # No inplace should work
            mod(inp, False).sum().backward()
            self.assertEqual(hook_called[0], 1)

            # Modifying the input in place inside forward should raise
            with self.assertRaisesRegex(
                RuntimeError,
                "Output 0 of BackwardHookFunctionBackward is "
                "a view and is being modified inplace.",
            ):
                mod(inp.clone(), True)

            # Modifying the original tensor in place after the forward should
            # raise as soon as the input stored by forward is used again
            local_inp = inp.clone()
            out = mod(local_inp, False)
            local_inp[0] *= 1
            with self.assertRaisesRegex(
                RuntimeError,
                "Output 0 of BackwardHookFunctionBackward is "
                "a view and its base or another view",
            ):
                # Any operation involving the view will fail here
                mod.inp + 2

            # Modifying the module output in place should raise
            out = mod(inp, False)
            with self.assertRaisesRegex(
                RuntimeError,
                "BackwardHookFunctionBackward is a view "
                "and is being modified inplace.",
            ):
                out += 1
return a single Tensor or a tuple of Tensors" ): m(a, b) # Check invalid output from different Nodes class MyModule(nn.Module): def forward(self, a, b): return a.clone(), b.clone() m = MyModule() m.register_backward_hook(noop) with self.assertWarnsRegex( FutureWarning, "outputs are generated by different autograd Nodes" ): m(a, b) # Check invalid forward with multiple Nodes class MyModule(nn.Module): def forward(self, a): return a.clone().clone() m = MyModule() m.register_backward_hook(noop) with self.assertWarnsRegex( FutureWarning, "the forward contains multiple autograd Nodes" ): m(a) def test_hook_backward_size(self): # Make module with multiple operations in forward # And different size for input and outputs class MyModule(nn.Module): def forward(self, arg1, arg2): tmp = arg1.sum() * arg2 tmp = tmp + arg2.sum() * arg1.sum() tmp = tmp.sum().view(1) tmp = tmp.expand(8).contiguous() return tmp module = MyModule() inp1 = torch.randn(5, 5, requires_grad=True) inp2 = torch.randn(10, 10, requires_grad=True) def bw_hook(module, grad_input, grad_output): self.assertEqual(len(grad_input), 2) self.assertEqual(grad_input[0].size(), torch.Size([5, 5])) self.assertEqual(grad_input[1].size(), torch.Size([10, 10])) self.assertEqual(len(grad_output), 1) self.assertEqual(grad_output[0].size(), torch.Size([8])) with module.register_full_backward_hook(bw_hook): module(inp1, inp2).sum().backward() def test_hook_backward_writeable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) sig_x = torch.nn.functional.sigmoid(input) def bw_hook(module, grad_input, grad_output): for grad in grad_input: self.assertTrue(isinstance(grad, torch.Tensor)) for grad in grad_output: self.assertTrue(isinstance(grad, torch.Tensor)) return tuple(gi * 2 for gi in grad_input) module.register_backward_hook(bw_hook) module(input).backward(torch.ones(5, 5)) expected_grad = sig_x * (1 - sig_x) * 2 self.assertEqual(input.grad, expected_grad) def 
def test_hook_buffer_registration(self):
    """Global buffer-registration hook is invoked for every registered buffer.

    The hook tags each buffer with a ``registered`` attribute. It is run once
    returning the buffer and once returning ``None``; in both cases every
    buffer of the freshly built network must carry the tag. The hook is
    removed again after each iteration, even if an assertion fails.
    """
    register_hook = torch.nn.modules.module.register_module_buffer_registration_hook

    for return_buffer in (True, False):

        def tag_buffer(module, name, buffer):
            buffer.registered = True
            return buffer if return_buffer else None

        # Leaving the with-block removes the hook again.
        with register_hook(tag_buffer):
            _, _, net = _create_basic_net()
            for buf in net.buffers():
                self.assertTrue(getattr(buf, "registered", False))
# Post-process the parametrized test classes.
# NOTE(review): instantiate_parametrized_tests comes from
# torch.testing._internal.common_utils; presumably it expands the
# parametrize-decorated methods into concrete test methods -- confirm.
instantiate_parametrized_tests(TestModuleHooks)
instantiate_parametrized_tests(TestStateDictHooks)

# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()