# Owner(s): ["module: dynamo"]

import collections
import contextlib
import copy
import itertools
import os
import tempfile
import traceback
import types
import unittest
from copy import deepcopy
from functools import partial
from typing import Dict, NamedTuple, Tuple
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.nn.functional as F
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.eval_frame import unsupported
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.testing import expectedFailureDynamic, same
from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import Parameter, UninitializedParameter


try:
    from . import test_functions
except ImportError:
    import test_functions


# Module-level integer counters; both are incremented by update_global().
_variable = 0
_variable1 = 0


def update_global():
    """Increment the module-level counters ``_variable`` and ``_variable1`` by one."""
    # Both names are declared global so the augmented assignments rebind the
    # module attributes instead of creating locals.
    global _variable, _variable1
    _variable += 1
    _variable1 += 1


class BasicModule(torch.nn.Module):
    """Linear(10, 10) followed by ReLU, multiplied by a fixed random tensor.

    ``scale`` is assigned as a plain tensor attribute of shape (1, 10); it is
    not wrapped in ``Parameter`` nor registered through ``register_buffer``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        # relu(linear1(x)) multiplied element-wise by the stored scale tensor.
        return F.relu(self.linear1(x)) * self.scale


class FnMember(torch.nn.Module):
    """Module that stores a plain function (``F.relu``) as an instance attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.activation = F.relu

    def forward(self, x):
        x = self.linear1(x)
        # Truthiness check on the stored callable; it is applied only if truthy.
        if self.activation:
            x = self.activation(x)
        return x


class FnMemberCmp(torch.nn.Module):
    """Module whose activation attribute is compared against ``None``.

    Args:
        activation: callable applied after ``linear1``, or ``None`` to fall
            back to ``torch.sigmoid``.
    """

    def __init__(self, activation):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.activation = activation

    def forward(self, x):
        x = self.linear1(x)
        # Two separate identity checks (``is not None`` / ``is None``) rather
        # than an if/else; exactly one of them fires.
        if self.activation is not None:
            x = self.activation(x)
        if self.activation is None:
            x = torch.sigmoid(x)
        return x


class SubmoduleExample(torch.nn.Module):
    """Chains two ``BasicModule`` submodules and scales the result.

    ``scale`` is a plain (1, 10) tensor attribute, independent of the
    ``scale`` tensors held by the two submodules.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x * self.scale


class IsTrainingCheck(torch.nn.Module):
    """Selects between two linear layers based on ``self.training``.

    The constructor explicitly puts the module in training mode.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)
        self.train(True)

    def forward(self, x):
        # linear1 is used in training mode, linear2 otherwise.
        if self.training:
            mod = self.linear1
        else:
            mod = self.linear2
        return F.relu(mod(x))


class IsEvalCheck(IsTrainingCheck):
    """Same as ``IsTrainingCheck`` but the constructor ends in eval mode."""

    def __init__(self) -> None:
        super().__init__()
        # Overrides the ``train(True)`` done by the parent constructor.
        self.train(False)


class ModuleMethodCall(torch.nn.Module):
    """Calls a regular instance method that takes a submodule as an argument."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    def call_and_scale(self, mod, x):
        """Apply ``mod`` to ``x`` and multiply the result by ``self.scale``."""
        x = mod(x)
        return x * self.scale

    def forward(self, x):
        # Both branches start from the same input ``x``; results are summed.
        x1 = self.call_and_scale(self.layer1, x)
        x2 = self.call_and_scale(self.layer2, x)
        return x1 + x2


class UnsupportedMethodCall(torch.nn.Module):
    """Module whose helper method ends with a call to ``unsupported``.

    NOTE(review): ``unsupported`` is imported from ``torch._dynamo.eval_frame``;
    presumably it is a function Dynamo cannot trace through - confirm there.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.scale = torch.randn(1, 10)

    def call_and_scale(self, mod, x):
        """Apply ``mod``, scale by ``self.scale``, then pass the result twice to ``unsupported``."""
        x = mod(x)
        x = x * self.scale
        return unsupported(x, x)

    def forward(self, x):
        x1 = self.call_and_scale(self.layer1, x)
        return x + x1


class UnsupportedModule(torch.nn.Module):
    """Module whose ``forward`` itself returns the result of ``unsupported``.

    NOTE(review): ``unsupported`` comes from ``torch._dynamo.eval_frame``;
    its exact semantics are not visible here - confirm against that module.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        x = self.layer1(x) * self.scale
        return unsupported(x, x)


class UnsupportedModuleCall(torch.nn.Module):
    """Wraps ``UnsupportedModule`` with tensor arithmetic before and after the call."""

    def __init__(self) -> None:
        super().__init__()
        self.mod = UnsupportedModule()

    def forward(self, x):
        # Input is scaled by 1.5 before the submodule; 1 is added afterwards.
        return 1 + self.mod(x * 1.5)


class ModuleWithStaticForward(torch.nn.Module):
    """Module whose ``forward`` is a ``staticmethod`` (takes no ``self``)."""

    @staticmethod
    def forward(x):
        # x * sigmoid(x)
        return x * torch.sigmoid(x)


class ModuleCallModuleWithStaticForward(torch.nn.Module):
    """Delegates to a submodule whose ``forward`` is a static method."""

    def __init__(self) -> None:
        super().__init__()
        self.mod = ModuleWithStaticForward()

    def forward(self, x):
        return self.mod(x)


class ModuleStaticMethodCall(torch.nn.Module):
    """Calls a ``staticmethod`` helper through ``self``."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    @staticmethod
    def call_and_scale(scale, mod, x):
        """Apply ``mod`` to ``x`` and multiply the result by ``scale``."""
        x = mod(x)
        return x * scale

    def forward(self, x):
        # The static helper receives the scale explicitly since it has no self.
        x1 = self.call_and_scale(self.scale, self.layer1, x)
        x2 = self.call_and_scale(self.scale, self.layer2, x)
        return x1 + x2


class ModuleClassMethodCall(torch.nn.Module):
    """Calls a ``classmethod`` helper through ``self``."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    @classmethod
    def call_and_scale(cls, scale, mod, x):
        """Apply ``mod`` to ``x`` and multiply the result by ``scale``; ``cls`` is unused."""
        x = mod(x)
        return x * scale

    def forward(self, x):
        # The class-level helper receives the scale explicitly.
        x1 = self.call_and_scale(self.scale, self.layer1, x)
        x2 = self.call_and_scale(self.scale, self.layer2, x)
        return x1 + x2


class ModuleProperty(torch.nn.Module):
    """Reads a tensor attribute through a ``@property`` alias."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.randn(1, 10)

    @property
    def scale_alias(self):
        """Return ``self.scale`` unchanged."""
        return self.scale

    def forward(self, x):
        return x * self.scale_alias


class NestedModuleList(torch.nn.Module):
    """``ModuleList`` containing three inner ``ModuleList`` pairs of [Linear, ReLU]."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList([])
        for _ in range(3):
            self.layers.append(
                torch.nn.ModuleList(
                    [
                        torch.nn.Linear(10, 10),
                        torch.nn.ReLU(),
                    ]
                )
            )

    def forward(self, x):
        # Each inner ModuleList is tuple-unpacked into (linear, activation).
        for layer, act in self.layers:
            x = act(layer(x))
        return x


class ConstLoop(torch.nn.Module):
    """Applies ``sigmoid(linear1(x))`` a number of times given by an int attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.count = 3

    def forward(self, x):
        # Loop bound comes from the Python int attribute ``count``; the loop
        # index itself is unused.
        for i in range(self.count):
            x = torch.sigmoid(self.linear1(x))
        return x


class ViaModuleCall(torch.nn.Module):
    """Calls a function looked up on the imported ``test_functions`` module.

    NOTE(review): the behaviour of ``test_functions.constant3`` is defined in
    another file and is not visible here.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)

    def forward(self, x):
        # Passes sigmoid(linear1(x)) and the raw input to the external helper.
        return test_functions.constant3(torch.sigmoid(self.linear1(x)), x)


class IsNoneLayer(torch.nn.Module):
    """Module with one real layer and one layer attribute set to ``None``."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = None
        self.train(True)

    def forward(self, x):
        # Each layer is applied only if it is not None; as constructed,
        # layer1 is a Linear and layer2 is None.
        if self.layer1 is not None:
            x = self.layer1(x)
        if self.layer2 is not None:
            x = self.layer2(x)
        return x


class LayerList(torch.nn.Module):
    """Holds submodules in a plain Python ``list`` rather than ``nn.ModuleList``."""

    def __init__(self) -> None:
        super().__init__()
        # Plain list attribute: Linear -> ReLU -> Linear.
        self.layers = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
        ]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class ModuleList(torch.nn.Module):
    """Iterates an ``nn.ModuleList`` of [Linear, ReLU, Linear, ReLU] in several ways."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )

    def forward(self, x):
        # 1) integer indexing over range(len(...))
        for i in range(len(self.layers)):
            x = self.layers[i](x)

        # 2) direct iteration
        for layer in self.layers:
            x = layer(x)

        # 3) zip with a tuple of tensors (all four entries are the x bound
        #    when the tuple is built)
        for layer, val in zip(self.layers, (x, x, x, x)):
            x = layer(x) + val

        # 4) zip with a tuple of Python ints
        for layer, val in zip(self.layers, (1, 2, 3, 4)):
            x = layer(x) + val

        # 5) enumerate; the first iteration multiplies by idx == 0
        for idx, layer in enumerate(self.layers):
            x = layer(x) * idx

        # 6) enumerate over a reversed slice of the list
        for idx, layer in enumerate(self.layers[::-1]):
            x = layer(x) * idx

        return x


class CustomGetItemModuleList(torch.nn.Module):
    """Module defining ``__getitem__`` / ``__len__`` that forward to an inner ``ModuleList``."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )

    def __getitem__(self, idx: int):
        """Return the layer at position ``idx`` of ``self.layers``."""
        return self.layers[idx]

    def __len__(self) -> int:
        """Return the number of layers in ``self.layers``."""
        return len(self.layers)

    def forward(self, x):
        # Goes through the user-defined __len__ and __getitem__ on self.
        for i in range(len(self)):
            x = self[i](x)

        return x


class ModuleDict(torch.nn.Module):
    """Looks up a single Linear layer in an ``nn.ModuleDict`` by string key."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )

    def forward(self, x):
        # TODO(future PR): handle more logic
        x = self.layers["0"](x)
        return x


class ParameterDict(torch.nn.Module):
    """Looks up a (10, 10) parameter in an ``nn.ParameterDict`` by string key."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ParameterDict(
            {
                "0": torch.nn.Parameter(torch.randn(10, 10)),
            }
        )

    def forward(self, x):
        # Matrix-multiplies the stored parameter with x (parameter on the left).
        x = self.layers["0"].mm(x)
        return x


class CustomGetItemParameterDict(torch.nn.Module):
    """Module defining ``__getitem__`` that forwards to an inner ``ParameterDict``."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ParameterDict(
            {
                "0": torch.nn.Parameter(torch.randn(10, 10)),
            }
        )

    def __getitem__(self, key: str) -> torch.nn.Parameter:
        """Return the entry stored under ``key`` in ``self.layers``."""
        return self.layers[key]

    def forward(self, x):
        # Goes through the user-defined __getitem__ on self, then parameter.mm(x).
        x = self["0"].mm(x)
        return x


class CustomGetItemModuleDict(torch.nn.Module):
    """Module defining ``__getitem__`` that forwards to an inner ``ModuleDict``."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )

    def __getitem__(self, key: str) -> torch.nn.Module:
        """Return the submodule stored under ``key`` in ``self.layers``."""
        return self.layers[key]

    def forward(self, x):
        # Goes through the user-defined __getitem__ on self.
        x = self["0"](x)
        return x


class TensorList(torch.nn.Module):
    """Holds a tuple of plain tensors and multiplies the input by each in turn."""

    def __init__(self) -> None:
        super().__init__()
        # Tuple of tensors alternating between shapes (1, 10) and (10, 1).
        self.layers = (
            torch.randn((1, 10)),
            torch.randn((10, 1)),
            torch.randn((1, 10)),
            torch.randn((10, 1)),
        )

    def forward(self, x):
        for layer in self.layers:
            x = x * layer
        return x


class Children(torch.nn.Module):
    """Applies its submodules in the order yielded by ``self.children()``."""

    def __init__(self) -> None:
        super().__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(10, 10)
        self.l4 = torch.nn.ReLU()

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x


class NamedChildren(torch.nn.Module):
    """Applies its submodules in the order yielded by ``self.named_children()``."""

    def __init__(self) -> None:
        super().__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(10, 10)
        self.l4 = torch.nn.ReLU()

    def forward(self, x):
        # The name half of each (name, module) pair is discarded.
        for _, block in self.named_children():
            x = block(x)
        return x


class IntArg(torch.nn.Module):
    """Module whose ``forward`` takes an extra argument with an int default."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)

    def forward(self, x, offset=1):
        # ``offset`` (default 1) is added to relu(layer1(x)).
        x = F.relu(self.layer1(x)) + offset
        return x


class Seq(torch.nn.Module):
    """Wraps an ``nn.Sequential`` of Linear -> ReLU -> Linear -> ReLU."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)


class Cfg:
    """Plain (non-Module) configuration holder with fixed values.

    Attributes set by the constructor:
        val: float constant ``0.5``.
        count: int constant ``3``.
    """

    def __init__(self) -> None:
        # Targets are assigned left to right, so ``val`` is still created
        # before ``count`` in the instance dict.
        self.val, self.count = 0.5, 3


class CfgModule(torch.nn.Module):
    """Module that reads loop count and an additive constant from a ``Cfg`` object."""

    def __init__(self) -> None:
        super().__init__()
        self.cfg = Cfg()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        # Both the loop bound and the offset are nested attribute reads on
        # the plain-Python config object; the loop index is unused.
        for i in range(self.cfg.count):
            x = self.layer(x + self.cfg.val)
        return x


class StringMember(torch.nn.Module):
    """Module that branches on a string attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.mode = "some_string"

    def forward(self, x):
        # There is no else branch: if ``mode`` were changed, the function
        # would fall off the end and return None.
        if self.mode == "some_string":
            return F.relu(self.linear1(x))


class _Block(torch.nn.Module):
    """Concatenates its input along dim 1 and scales the result by 1.5.

    ``x`` is passed straight to ``torch.cat``, so it is a sequence of tensors.
    """

    def forward(self, x):
        return 1.5 * torch.cat(x, 1)


class _DenseBlock(torch.nn.ModuleDict):
    """``ModuleDict`` subclass that feeds each block all features produced so far.

    Args:
        num_layers: number of ``_Block`` submodules to register, under the
            names ``denselayer1`` .. ``denselayer<num_layers>``.
    """

    _version = 2

    def __init__(
        self,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            self.add_module("denselayer%d" % (i + 1), _Block())

    def forward(self, init_features):
        # ``features`` grows by one entry per layer; every layer receives the
        # whole list accumulated so far.
        features = [init_features]
        for layer in self.values():
            new_features = layer(features)
            features.append(new_features)
        # Final output concatenates the input and every layer output on dim 1.
        return torch.cat(features, 1)


class DenseNetBlocks(torch.nn.Module):
    """Thin wrapper that forwards to a default-constructed ``_DenseBlock``."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = _DenseBlock()

    def forward(self, x):
        return self.layers(x)


class MaterializedModule(torch.nn.Module):
    """Once the below lazy module is initialized with its first input,
    it is transformed into this module."""

    param: Parameter

    def __init__(self) -> None:
        super().__init__()
        # Registers the name ``param`` with no value yet.
        self.register_parameter("param", None)

    def forward(self, x):
        # Identity: the parameter is not used in the computation.
        return x


class LazyModule(LazyModuleMixin, MaterializedModule):
    """Lazy counterpart of ``MaterializedModule``.

    ``param`` starts as an ``UninitializedParameter`` and is materialized in
    ``initialize_parameters``. ``cls_to_become`` names ``MaterializedModule``
    as the class this module turns into (see that class's docstring).
    """

    param: UninitializedParameter
    cls_to_become = MaterializedModule

    def __init__(self) -> None:
        super().__init__()
        self.param = UninitializedParameter()

    def initialize_parameters(self, x):
        """Materialize ``param`` with the shape of ``x``."""
        # force graph break to ensure this was not inlined
        torch._dynamo.graph_break()
        self.param.materialize(x.shape)


class LazyMLP(torch.nn.Module):
    """Two-layer MLP built from ``LazyLinear`` layers: LazyLinear(10) -> ReLU -> LazyLinear(1) -> ReLU."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.LazyLinear(10)
        self.relu1 = torch.nn.ReLU()
        self.fc2 = torch.nn.LazyLinear(1)
        self.relu2 = torch.nn.ReLU()

    def forward(self, input):
        x = self.relu1(self.fc1(input))
        y = self.relu2(self.fc2(x))
        return y


class MyInput(NamedTuple):
    """Named-tuple input container with a nested-dict field and a tensor field."""

    # NOTE(review): LazyLayerWithNamedTupleInput in this file indexes
    # ``x["a"]`` by integer position and takes ``len()`` of it, so the inner
    # value may in practice be a sequence of tensors - confirm against callers.
    x: Dict[str, Dict[str, torch.Tensor]]
    y: torch.Tensor


class LazyLayerWithNamedTupleInput(LazyModuleMixin, torch.nn.Module):
    """Lazy layer whose input is an object with an ``x`` mapping (e.g. ``MyInput``)."""

    def __init__(self) -> None:
        super().__init__()

    def initialize_parameters(self, input):
        """Create ``_param`` filled with 0.5, shaped like ``input.x["a"][0]``."""
        with torch.no_grad():
            self._param = torch.nn.Parameter(
                torch.empty(input.x["a"][0].shape).fill_(0.5)
            )

    def forward(self, input):
        # Sums the entries of input.x["a"] by integer index; ``_param`` is
        # not used in the computation.
        input = input.x["a"]
        x = 0
        for i in range(len(input)):
            x = x + input[i]
        return x


class LazyModuleWithNamedTupleInput(torch.nn.Module):
    """Non-lazy wrapper that passes its input unchanged to ``LazyLayerWithNamedTupleInput``."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = LazyLayerWithNamedTupleInput()

    def forward(self, input):
        return self.layer(input)


class LazyLayerWithListInput(LazyModuleMixin, torch.nn.Module):
    """Lazy layer whose input is an indexable sequence of tensors."""

    def __init__(self) -> None:
        super().__init__()

    def initialize_parameters(self, input):
        """Create ``_param`` filled with 0.5, shaped like ``input[0]``."""
        with torch.no_grad():
            self._param = torch.nn.Parameter(torch.empty(input[0].shape).fill_(0.5))

    def forward(self, input):
        # Sums the entries of ``input`` by integer index; ``_param`` is not
        # used in the computation.
        x = 0
        for i in range(len(input)):
            x = x + input[i]
        return x


class LazyModuleWithListInput(torch.nn.Module):
    """Non-lazy wrapper that passes a slice of its input to ``LazyLayerWithListInput``."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = LazyLayerWithListInput()

    def forward(self, input):
        # The last element of ``input`` is dropped before the call.
        return self.layer(input[:-1])


class LazyModuleWithLazySubmodule(LazyModuleMixin, torch.nn.Module):
    """Lazy module whose submodule is only created in ``initialize_parameters``."""

    def __init__(self) -> None:
        super().__init__()

    def initialize_parameters(self, input):
        """Create ``self.layer`` (itself a lazy layer); ``input`` is not inspected."""
        with torch.no_grad():
            self.layer = LazyLayerWithListInput()

    def forward(self, x):
        return self.layer(x)


class LazyLayerWithInputs(LazyModuleMixin, torch.nn.Module):
    """Lazy layer taking two indexable sequences of tensors, ``x`` and ``y``."""

    def __init__(self) -> None:
        super().__init__()

    def initialize_parameters(self, x, y):
        """Create ``_param_x`` / ``_param_y`` filled with 0.5, shaped like ``x[0]`` / ``y[0]``."""
        with torch.no_grad():
            self._param_x = torch.nn.Parameter(torch.empty(x[0].shape).fill_(0.5))
            self._param_y = torch.nn.Parameter(torch.empty(y[0].shape).fill_(0.5))

    def forward(self, x, y):
        # Sum the entries of each sequence separately, then add the two sums.
        # The parameters created above are not used in the computation.
        res_x = 0
        for i in range(len(x)):
            res_x = res_x + x[i]
        res_y = 0
        for i in range(len(y)):
            res_y = res_y + y[i]
        return res_x + res_y


class LazyModuleKwArgs(LazyModuleMixin, torch.nn.Module):
    """Lazy module that calls its lazily created submodule with a keyword argument."""

    def __init__(self) -> None:
        super().__init__()

    def initialize_parameters(self, *args, **kwargs):
        """Create ``self.layer``; the received arguments are ignored."""
        with torch.no_grad():
            self.layer = LazyLayerWithInputs()

    def forward(self, x, y):
        # ``x`` is passed positionally and ``y`` by keyword.
        return self.layer(x, y=y)


class LazyParentModule(LazyModuleMixin, torch.nn.Module):
    """Lazy base class providing ``impl``; it defines neither ``forward`` nor ``_val``."""

    def __init__(self) -> None:
        super().__init__()

    def impl(self, x):
        """Return ``x.cos() + self._val``; ``_val`` must be set by a subclass."""
        return x.cos() + self._val


class LazyChildModuleNoClsToBecome(LazyParentModule):
    """Lazy subclass that sets ``_val`` lazily and defines no ``cls_to_become``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        # Calls the parent's helper through super() with sin(x) as input.
        return super().impl(x.sin())

    def initialize_parameters(self, input):
        """Create ``_val`` as a (2, 2) parameter of ones; ``input`` is not inspected."""
        self._val = torch.nn.Parameter(torch.ones(2, 2))


def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool:
    """Return True if any parameter yielded by ``module.parameters(recurse)`` requires grad.

    Args:
        module: module whose parameters are inspected.
        recurse: forwarded positionally to ``module.parameters``; defaults to
            False here.
    """
    requires_grad = any(p.requires_grad for p in module.parameters(recurse))
    return requires_grad


def requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool:
    """Return True if any parameter yielded by ``module.parameters(recurse)`` requires grad.

    The body is the same as ``requires_grad1``; this is a separate function
    object used by ``ParametersModule2``.

    Args:
        module: module whose parameters are inspected.
        recurse: forwarded positionally to ``module.parameters``; defaults to
            False here.
    """
    requires_grad = any(p.requires_grad for p in module.parameters(recurse))
    return requires_grad


class ParametersModule1(torch.nn.Module):
    """Branches in ``forward`` on whether its own parameters require grad.

    Unlike ``BasicModule``, ``scale`` here is a real ``nn.Parameter``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.nn.Parameter(torch.randn(1, 10))

    def forward(self, x):
        # requires_grad1 uses its default recurse=False, so the parameters of
        # the ``linear1`` submodule are not inspected.
        if not requires_grad1(self):
            return F.relu(self.linear1(x)) * self.scale
        else:
            return x + 1


class ParametersModule2(ParametersModule1):
    """Same as ``ParametersModule1`` but the check goes through ``requires_grad2``."""

    def forward(self, x):
        if not requires_grad2(self):
            return F.relu(self.linear1(x)) * self.scale
        else:
            return x + 1


class ParametersModule3(ParametersModule1):
    """Reads the dtype of the first parameter from ``self.parameters()`` inside ``forward``."""

    def forward(self, x):
        # next() pulls only the first item from the parameters() iterator.
        ones = torch.ones(10, dtype=next(self.parameters()).dtype)
        return F.relu(self.linear1(x)) * self.scale + ones


class ParametersModule4(ParametersModule1):
    """Like ``ParametersModule3`` but calls ``self.parameters(recurse=False)``."""

    def forward(self, x):
        # recurse=False is passed by keyword; next() pulls only the first item.
        ones = torch.ones(10, dtype=next(self.parameters(recurse=False)).dtype)
        return F.relu(self.linear1(x)) * self.scale + ones


class ParametersModule5(torch.nn.Module):
    """Counts the items yielded by ``self.parameters()`` inside ``forward``.

    The same ``Parameter`` object is assigned under two attribute names
    (``scale`` and ``scale_dup``).
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.nn.Parameter(torch.randn(10, 10))
        self.scale_dup = self.scale

    def forward(self, x):
        # NOTE(review): nn.Module.parameters() is expected to yield a shared
        # parameter only once, so the alias should not add to the count -
        # confirm against the nn.Module documentation.
        counter = 0
        for param in self.parameters():
            counter += 1

        # ``linear1`` is not applied; only its parameters contribute to the count.
        return x * self.scale * counter


class SuperModule(BasicModule):
    """Calls the parent ``forward`` through zero-argument ``super()`` and adds 10.0."""

    def forward(self, x):
        x = super().forward(x)
        return x + 10.0


class SuperModule2(BasicModule):
    """Calls the parent ``forward`` explicitly as ``BasicModule.forward(self, x)``."""

    def forward(self, x):
        return BasicModule.forward(self, x)


class ComplicatedSuperParent(torch.nn.Module):
    """Base class that provides a ``classmethod`` helper and no ``forward``."""

    @classmethod
    def custom_add(cls, x):
        """Return ``x + x``; ``cls`` is unused."""
        x = x + x
        return x


class SuperChildCallsClassMethod(ComplicatedSuperParent):
    """Uses ``super()`` inside a classmethod to reach the parent's classmethod."""

    @classmethod
    def child_func(cls, x):
        """Delegate to the parent's ``custom_add`` through ``super()``."""
        x = super().custom_add(x)
        return x

    def forward(self, x):
        # The classmethod is invoked through the instance.
        x = self.child_func(x)
        return x


class HasAttrModule(torch.nn.Module):
    """Branches on ``hasattr`` for one attribute that is set and one that is not."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.randn(1, 10))

    def forward(self, x):
        x = F.relu(x)
        # ``scale`` is set in __init__; ``scale2`` is never set by this class,
        # so the second branch only runs if the attribute is added externally.
        if hasattr(self, "scale"):
            x *= self.scale
        if hasattr(self, "scale2"):
            x *= self.scale2
        return x


class EnumValues(torch.nn.ModuleDict):
    """``ModuleDict`` subclass that iterates ``enumerate(self.values())`` in ``forward``.

    Args:
        num_layers: number of ``_Block`` submodules to register, under the
            names ``denselayer1`` .. ``denselayer<num_layers>``.
    """

    def __init__(
        self,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            self.add_module("denselayer%d" % (i + 1), _Block())

    def forward(self, init_features):
        # Each layer receives the whole list of features accumulated so far;
        # the enumerate index is unused.
        features = [init_features]
        for idx, layer in enumerate(self.values()):
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)


class AccessByKeys(torch.nn.ModuleDict):
    """``ModuleDict`` subclass that iterates ``self.keys()`` and indexes ``self[k]``.

    Args:
        num_layers: number of ``_Block`` submodules to register, under the
            names ``denselayer1`` .. ``denselayer<num_layers>``.
    """

    def __init__(
        self,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            self.add_module("denselayer%d" % (i + 1), _Block())

    def forward(self, init_features):
        # Each layer is looked up by key and receives the whole list of
        # features accumulated so far.
        features = [init_features]
        for k in self.keys():
            new_features = self[k](features)
            features.append(new_features)
        return torch.cat(features, 1)


class CallForwardDirectly(torch.nn.Module):
    """Invokes submodules through ``.forward(...)`` rather than calling the module."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        # One user-defined submodule and one built-in Linear, both via .forward.
        x = self.layer1.forward(x)
        x = self.layer2.forward(x)
        return x


class ConvCallForwardDirectly(torch.nn.Module):
    """Invokes a bias-free ``Conv2d(3, 64, 3, 1, 1)`` through ``.forward(...)``."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Conv2d(3, 64, 3, 1, 1, bias=False)

    def forward(self, x):
        return self.layer.forward(x)


class ConvTransposeCallForwardDirectly(torch.nn.Module):
    """Invokes a ``ConvTranspose2d(4, 4, 4)`` through ``.forward(...)``."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.ConvTranspose2d(4, 4, 4)

    def forward(self, x):
        return self.layer.forward(x)


class ConvCallSuperForwardDirectly(torch.nn.Conv1d):
    """``Conv1d`` subclass whose ``forward`` delegates to ``super().forward``.

    Args:
        in_channels, out_channels, kernel_size: forwarded by keyword to
            ``Conv1d.__init__``.
        **kwargs: any further ``Conv1d`` keyword arguments.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            **kwargs,
        )

    def forward(self, inputs, mask=None):
        # ``mask`` is accepted but not used.
        outputs = super().forward(inputs)
        return outputs


class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d):
    """``ConvTranspose2d`` subclass with a separate code path for empty inputs.

    Args:
        in_channels, out_channels, kernel_size: forwarded by keyword to
            ``ConvTranspose2d.__init__``.
        **kwargs: any further ``ConvTranspose2d`` keyword arguments.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            **kwargs,
        )

    def forward(self, x):
        # Non-empty input: defer to the stock implementation.
        if x.numel() > 0:
            return super().forward(x)
        # Empty input: compute the output size per spatial dim from the last
        # two input dims and the layer's padding/dilation/kernel/stride/
        # output_padding settings.
        output_shape = [
            ((i - 1) * d - 2 * p + (di * (k - 1) + 1) + op)
            for i, p, di, k, d, op in zip(
                x.shape[-2:],
                self.padding,
                self.dilation,
                self.kernel_size,
                self.stride,
                self.output_padding,
            )
        ]
        output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
        # NOTE(review): ``_NewEmptyTensorOp`` is not defined in the visible
        # part of this file (hence the F821 suppression); reaching this line
        # would presumably raise NameError - confirm the branch is never hit.
        return _NewEmptyTensorOp.apply(x, output_shape)  # noqa: F821


class ModuleNameString(torch.nn.Module):
    """Branches on ``__class__.__name__`` of itself and of a submodule."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)

    def forward(self, x):
        # This class is not named "ABC", so the first branch is not taken
        # for direct instances.
        if self.__class__.__name__ == "ABC":
            return 10
        # ``linear1`` is constructed as torch.nn.Linear in __init__.
        if self.linear1.__class__.__name__ == "Linear":
            return F.relu(self.linear1(x) + 10)
        return 11


class SelfMutatingModule(torch.nn.Module):
    """Module that mutates a Python int attribute on every ``forward`` call.

    Args:
        layer: callable/module applied to the input.
    """

    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.counter = 0

    def forward(self, x):
        # The counter value before the increment is added to the layer output,
        # so successive calls produce different results for the same input.
        result = self.layer(x) + self.counter
        self.counter += 1
        return F.relu(result)


class ModuleAttributePrecedenceBase(torch.nn.Module):
    """Base class defining a method named ``linear``.

    The subclass ``ModuleAttributePrecedence`` assigns a submodule under the
    same name.
    """

    def linear(self, x, flag=None):
        """Return ``x * 2.0`` if ``flag`` is truthy, else ``x * 3.0``."""
        if flag:
            return x * 2.0
        return x * 3.0


class ModuleAttributePrecedence(ModuleAttributePrecedenceBase):
    """Assigns instance attributes whose names collide with methods.

    ``activation``, ``initializer`` and ``scale`` are methods on this class
    and ``linear`` is a method on the base class; ``__init__`` assigns a
    submodule, a tensor or a float under each of those names.
    """

    def __init__(self) -> None:
        super().__init__()
        self.activation = torch.nn.ReLU()
        self.linear = torch.nn.Linear(10, 10)
        self.initializer = torch.ones([10, 10])
        self.scale = 0.5

    def activation(self, x):
        """Method sharing its name with the ``activation`` attribute set in ``__init__``."""
        return x * 1.2

    def initializer(self):
        """Method sharing its name with the ``initializer`` attribute set in ``__init__``."""
        return torch.zeros([10, 10])

    def scale(self):
        """Method sharing its name with the ``scale`` attribute set in ``__init__``."""
        return 2.0

    def forward(self, x):
        # object attribute takes precedence unless it's a nn.Module
        return self.activation(self.linear(self.initializer + x)) * self.scale


class ModuleForwardHasGraphBreak(torch.nn.Module):
    """Collects parameter/buffer/module views, graph-breaks, then uses them."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule())
        self.layer4 = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )
        self.layer5 = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        """
        This is used to test if the results of functions like `named_parameters`
        can be reconstructed correctly after graph break.

        https://github.com/pytorch/torchdynamo/issues/1931
        """
        x = self.layer1(x)
        # All six collections are built before the graph break ...
        params1 = dict(self.named_parameters())
        params2 = list(self.parameters())
        buffers1 = dict(self.named_buffers())
        buffers2 = list(self.buffers())
        modules1 = dict(self.named_modules())
        modules2 = list(self.modules())
        torch._dynamo.graph_break()
        # ... and each one is referenced after it. ``y`` is rebound in turn;
        # only the last binding (``params1``, a name -> parameter dict) is
        # used in the arithmetic below.
        y = modules2
        y = modules1
        y = buffers2
        y = buffers1
        y = params2
        y = params1
        x = (
            self.layer2(x)
            + y["layer3.1.linear1.weight"]
            + y["layer4.2.weight"]
            + y["layer5.0.weight"]
        )
        return x * self.scale


class ModuleGuardNameIsValid(torch.nn.ModuleDict):
    """``ModuleDict`` whose submodule names are not valid Python identifiers."""

    # Guard names should be valid Python identifiers, because eval() is used
    # to get the corresponding guard value. Some guard names come from the
    # source (module path), where special symbols are allowed even though
    # they are not valid in Python identifiers; such patterns should be
    # identified and rewritten with getattr.
    def __init__(self) -> None:
        super().__init__()
        # Registers "l@yer-1" and "l@yer-2".
        for i in range(2):
            self.add_module("l@yer-%d" % (i + 1), BasicModule())

    def forward(self, x):
        for layer in self.values():
            x = layer(x)
        return x


class SequentialWithDuplicatedModule(torch.nn.Module):
    """``Sequential`` that contains the same ReLU instance at three positions."""

    # The Sequential module (self.layer) contains the same ReLU module three
    # times; that instance is also registered on this module as self.relu.
    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.layer = torch.nn.Sequential(
            torch.nn.Linear(10, 20),
            self.relu,
            torch.nn.Linear(20, 20),
            self.relu,
            torch.nn.Linear(20, 10),
            self.relu,
        )

    def forward(self, x):
        return self.layer(x)


class SequentialWithDuplicatedModule2(torch.nn.Module):
    """Named-entry variant of ``SequentialWithDuplicatedModule``.

    The ``Sequential`` is built from an ``OrderedDict``, so the single shared
    ReLU instance appears under three names (``relu1``, ``relu2``, ``relu3``).
    """

    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.layer = torch.nn.Sequential(
            collections.OrderedDict(
                [
                    ("linear1", torch.nn.Linear(10, 20)),
                    ("relu1", self.relu),
                    ("linear2", torch.nn.Linear(20, 20)),
                    ("relu2", self.relu),
                    ("linear3", torch.nn.Linear(20, 10)),
                    ("relu3", self.relu),
                ]
            )
        )

    def forward(self, x):
        return self.layer(x)


class ModuleComparison(torch.nn.Module):
    """Compares submodules with ``is None`` and ``==`` inside ``forward``."""

    def __init__(self) -> None:
        super().__init__()
        self.layer0 = torch.nn.Linear(10, 10)
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = torch.nn.Linear(10, 10)

    @property
    def encoder_layers(self):
        """Return a new list ``[layer0, layer1, layer2]`` on every access."""
        return [self.layer0, self.layer1, self.layer2]

    def forward(self, x):
        # Every iteration applies its layer to the original ``x`` and
        # overwrites ``output``, so the returned value comes from the last
        # layer in the list.
        for layer in self.encoder_layers:
            output = layer(x)
            # relu6 for layer0 (or a None layer), relu for the others.
            if layer is None or layer == self.layer0:
                output = F.relu6(output)
            else:
                output = F.relu(output)
        return output


class ModulePatch1(torch.nn.Module):
    """Empty ``nn.Module`` subclass that defines no ``forward`` of its own."""

    pass


class ModulePatch2(torch.nn.Module):
    """Module whose ``forward`` returns ``x - 1``."""

    def forward(self, x):
        return x - 1


class UnspecNonInlinableModule(torch.nn.Module):
    """Module with a data-dependent branch in ``forward``."""

    torchdynamo_force_dynamic = True  # forced to be a UnspecializedNNModule

    def forward(self, x):
        # The branch depends on the runtime value of x.sum().
        if x.sum() > 0:
            return x + 1
        else:
            return x - 1


class UnspecNonInlinableToplevelModule(torch.nn.Module):
    """Top-level wrapper that forwards to an ``UnspecNonInlinableModule`` submodule."""

    def __init__(self) -> None:
        super().__init__()
        self.m = UnspecNonInlinableModule()

    def forward(self, x):
        return self.m(x)


def make_test(fn, expected_ops=None):
    """Build a test method that runs ``fn`` through the Dynamo standard test.

    ``fn.eval()`` is invoked once, right here, when the test is created -
    not when the returned test method runs.

    Args:
        fn: object under test; must provide ``eval()``.
        expected_ops: forwarded unchanged to
            ``torch._dynamo.testing.standard_test``.

    Returns:
        A one-argument function (``self``) suitable for assignment as a
        test-case method.
    """
    # The closure below only captures ``fn``; it never calls ``eval`` itself,
    # so switching modes before defining it is equivalent to doing so after.
    fn.eval()

    def test_fn(self):
        return torch._dynamo.testing.standard_test(
            self,
            fn=fn,
            nargs=1,
            expected_ops=expected_ops,
        )

    return test_fn


@contextlib.contextmanager
def temporary_tensor_subclass(torch_function=None):
    """Context manager yielding a throwaway ``torch.Tensor`` subclass.

    A fresh ``TensorProxy`` class is defined on every call, added to
    ``torch._dynamo.config.traceable_tensor_subclasses`` for the duration of
    the ``with`` block, and removed again on exit (also when the block raises).

    Args:
        torch_function: optional zero-argument callable invoked at the start
            of every ``TensorProxy.__torch_function__`` call.

    Yields:
        The ``TensorProxy`` class.
    """

    class TensorProxy(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            # Run the optional hook, then defer to the default implementation.
            if torch_function is not None:
                torch_function()
            return super().__torch_function__(func, types, args, kwargs)

    torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
    try:
        yield TensorProxy
    finally:
        # Always unregister so the subclass does not stay in the config set.
        torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)


class NNModuleTests(torch._dynamo.test_case.TestCase):
    """Dynamo tests for ``torch.nn.Module`` handling.

    The class attributes below are test methods generated by ``make_test``
    from one module instance each.
    """

    # Generated tests: each attribute wraps a freshly constructed module.
    test_seq = make_test(Seq())
    test_basicmodule1 = make_test(BasicModule())
    test_basicmodule2 = make_test(BasicModule())
    test_submodules1 = make_test(SubmoduleExample())
    test_submodules2 = make_test(SubmoduleExample())
    test_modulemethod1 = make_test(ModuleMethodCall())
    test_modulemethod2 = make_test(ModuleMethodCall())
    test_module_call_module_with_static_forward = make_test(
        ModuleCallModuleWithStaticForward()
    )
    test_module_static_method = make_test(ModuleStaticMethodCall())
    test_fnmember = make_test(FnMember())
    test_fnmembercmp1 = make_test(FnMemberCmp(F.relu))
    test_fnmembercmp2 = make_test(FnMemberCmp(None))
    test_constloop = make_test(ConstLoop())
    test_istraining1 = make_test(IsTrainingCheck())
    test_istraining2 = make_test(IsTrainingCheck())
    test_iseval1 = make_test(IsEvalCheck())
    test_iseval2 = make_test(IsEvalCheck())
    test_viamodulecall = make_test(ViaModuleCall())
    test_isnonelayer = make_test(IsNoneLayer())
    test_layerlist = make_test(LayerList())
    test_tensorlist = make_test(TensorList())
    test_intarg = make_test(IntArg())
    test_cfgmod = make_test(CfgModule())
    test_stringmember = make_test(StringMember())
    test_modulelist = make_test(ModuleList())
    test_modulelist_nested = make_test(NestedModuleList())
    test_modulelist_custom = make_test(CustomGetItemModuleList())
    test_moduledict = make_test(ModuleDict())
    test_moduledict_custom = make_test(CustomGetItemModuleDict())
    test_parameterdict = make_test(ParameterDict())
    test_parameterdict_custom = make_test(CustomGetItemParameterDict())
    test_super1 = make_test(SuperModule())
    test_super2 = make_test(SuperModule2())
    test_super_class_method = make_test(SuperChildCallsClassMethod())
    test_children = make_test(Children())
    test_named_children = make_test(NamedChildren())
    test_densenet = make_test(DenseNetBlocks())
    test_parameters1 = make_test(ParametersModule1())
    test_parameters2 = make_test(ParametersModule2())
    # The only generated test that pins an explicit op count.
    test_parameters3 = make_test(ParametersModule3(), expected_ops=5)
    test_parameters4 = make_test(ParametersModule4())
    test_parameters5 = make_test(ParametersModule5())
    test_hasattr = make_test(HasAttrModule())
    test_enumvalues = make_test(EnumValues())
    test_access_by_keys = make_test(AccessByKeys())
    test_module_class_method = make_test(ModuleClassMethodCall())
    test_module_property = make_test(ModuleProperty())
    test_forward_directly = make_test(CallForwardDirectly())
    test_module_name_string = make_test(ModuleNameString())
    test_module_attribute_precedence = make_test(ModuleAttributePrecedence())
    test_module_guard_name_is_valid = make_test(ModuleGuardNameIsValid())
    test_sequential_with_duplicated_module = make_test(SequentialWithDuplicatedModule())
    test_sequential_with_duplicated_module2 = make_test(
        SequentialWithDuplicatedModule2()
    )
    test_module_comparison = make_test(ModuleComparison())

    def test_module_forward_has_graph_break(self):
        # The eager result is the reference for the "eager"-backend compiled module.
        module = ModuleForwardHasGraphBreak()
        inp = torch.rand([10, 10])
        expected = module(inp)
        compiled = torch._dynamo.optimize("eager")(module)
        actual = compiled(inp)
        self.assertTrue(torch.allclose(expected, actual))

    def test_unsupportedmethod(self):
        # Compiled output must match eager, and exactly 5 ops must be captured.
        module = UnsupportedMethodCall()
        inp = torch.randn(10)
        counter = torch._dynamo.testing.CompileCounter()
        compiled = torch._dynamo.optimize(counter)(module)
        compiled_out = compiled(inp)
        eager_out = module(inp)
        self.assertTrue(same(compiled_out, eager_out))
        self.assertEqual(counter.op_count, 5)

    def test_unsupportedmodule(self):
        # Compiled output must match eager, and exactly 6 ops must be captured.
        module = UnsupportedModuleCall()
        inp = torch.randn(10)
        counter = torch._dynamo.testing.CompileCounter()
        compiled = torch._dynamo.optimize(counter)(module)
        compiled_out = compiled(inp)
        eager_out = module(inp)
        self.assertTrue(same(compiled_out, eager_out))
        self.assertEqual(counter.op_count, 6)

    def test_self_mutating1(self):
        # Three wrappers share one Linear: the first runs eagerly as the
        # reference, the other two are compiled with a shared counter.
        shared_linear = torch.nn.Linear(10, 10)
        eager_mod = SelfMutatingModule(shared_linear)
        first_mod = SelfMutatingModule(shared_linear)
        second_mod = SelfMutatingModule(shared_linear)
        inp = torch.randn(10)
        eager_outs = [eager_mod(inp) for _ in range(3)]
        counter = torch._dynamo.testing.CompileCounter()
        opt_first = torch._dynamo.optimize_assert(counter)(first_mod)
        opt_second = torch._dynamo.optimize_assert(counter)(second_mod)
        first_outs = [opt_first(inp) for _ in range(3)]
        second_outs = [opt_second(inp) for _ in range(3)]
        self.assertTrue(same(eager_outs, first_outs))
        self.assertTrue(same(eager_outs, second_outs))
        # The expected frame count depends on the static-by-default setting.
        if not torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(counter.frame_count, """1""")
        else:
            self.assertExpectedInline(counter.frame_count, """2""")

    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
    def test_generation_tag(self):
        # Checks the generation value GenerationTracker reports for modules
        # created before, inside, and after an `optimize_assert` context.
        cnt = torch._dynamo.testing.CompileCounter()

        # guarantee that we have installed
        # the generation tagging function
        with torch._dynamo.optimize_assert(cnt):
            pass

        # m1 is created outside any context; its generation is the baseline.
        m1 = torch.nn.Linear(10, 10)
        prev_generation = GenerationTracker.get_generation_value(m1)
        cur_generation = prev_generation + 1

        # m2 is created inside a context and is expected one generation later.
        with torch._dynamo.optimize_assert(cnt):
            m2 = torch.nn.Linear(10, 10)

        self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation)
        self.assertEqual(GenerationTracker.get_generation_value(m2), cur_generation)
        # check that newly constructed instances
        # also have the same generation (even if copied from an old instance)
        m3 = deepcopy(m1)
        self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation)

    def test_simple_torch_function(self):
        # Runs `foo` on a TensorProxy input eagerly and compiled (nopython),
        # expecting four captured ops and matching outputs.
        def foo(x):
            # function call, twice to test wrapping
            x = F.sigmoid(x)
            x = F.sigmoid(x)
            # method call, twice to test wrapping
            x = x.sigmoid()
            x = x.sigmoid()
            return x

        with temporary_tensor_subclass() as TensorProxy:
            x = torch.randn(1).as_subclass(TensorProxy)
            cnt = torch._dynamo.testing.CompileCounter()
            out1 = foo(x)
            opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
            out2 = opt_foo(x)

            # One op per sigmoid call in `foo`.
            self.assertEqual(cnt.op_count, 4)
            self.assertTrue(torch._dynamo.testing.same(out1, out2))

    def test_torch_function_with_closure(self):
        # Same shape as test_simple_torch_function, but the subclass is built
        # with a `__torch_function__` hook that reads a closure variable.
        def run():
            def foo(x):
                # function call, twice to test wrapping
                x = F.sigmoid(x)
                x = F.sigmoid(x)
                # method call, twice to test wrapping
                x = x.sigmoid()
                x = x.sigmoid()
                return x

            counter = 0

            def function():
                nonlocal counter
                # for now, only support reads from closure cells
                # TODO(future PR): support writes as well
                # The expression below reads `counter` and discards the result.
                counter + 1

            with temporary_tensor_subclass(function) as TensorProxy:
                x = torch.randn(1).as_subclass(TensorProxy)
                # NOTE(review): this rebinding discards the TensorProxy
                # instance created on the previous line, so `foo` below runs
                # on a plain tensor -- confirm whether that is intended.
                x = torch.randn(1)
                cnt = torch._dynamo.testing.CompileCounter()
                out1 = foo(x)
                opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
                out2 = opt_foo(x)

                self.assertEqual(cnt.op_count, 4)
                self.assertTrue(torch._dynamo.testing.same(out1, out2))

        run()

    def test_torch_mangled_class_name(self):
        # Temporarily wraps `global_mangled_class_name` to record every value
        # it returns, then checks the format of the recorded names.
        original = TensorWithTFOverrideVariable.global_mangled_class_name
        results = []

        def instrumented(self, tx):
            # Delegate to the real implementation and remember its result.
            result = original(self, tx)
            results.append(result)
            return result

        TensorWithTFOverrideVariable.global_mangled_class_name = instrumented

        def one_break(x):
            x = F.sigmoid(x)
            print()  # force break
            x = x.sigmoid()
            return x

        try:
            with temporary_tensor_subclass() as TensorProxy:
                x = torch.randn(1).as_subclass(TensorProxy)
                x1 = one_break(x)

                cnt = torch._dynamo.testing.CompileCounter()
                opt_one_break = torch._dynamo.optimize(cnt)(one_break)
                x2 = opt_one_break(x)

                self.assertTrue(torch._dynamo.testing.same(x1, x2))
                # Two frames with two ops in total.
                self.assertEqual(cnt.frame_count, 2)
                self.assertEqual(cnt.op_count, 2)

                compile_ids = set()
                for r in results:
                    # A mangled classname looks like __subclass_TensorProxy_94524181138240_c0
                    # where the last segment contains the compile_id.
                    prefix = "__subclass_TensorProxy_"
                    before, sep, after = r.partition(prefix)
                    self.assertEqual(before, "")
                    self.assertEqual(sep, prefix)

                    # The remainder must be exactly "<numeric id>_c<numeric id>".
                    class_type_id, compile_id = after.split("_")
                    self.assertTrue(class_type_id.isnumeric())
                    self.assertTrue(compile_id.startswith("c"))

                    cid = compile_id[1:]
                    self.assertTrue(cid.isnumeric())
                    compile_ids.add(cid)

                self.assertEqual(len(compile_ids), 3)

        finally:
            # Restore the unpatched method regardless of the test outcome.
            TensorWithTFOverrideVariable.global_mangled_class_name = original

    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
    def test_nn_moduledict_contains(self):
        # Exercises `"key" in module_dict` inside a compiled forward, with the
        # key present, with it absent, and under the context-manager API.
        class M(torch.nn.Module):
            def __init__(self, module_dict):
                super().__init__()
                self.module_dict = module_dict

            def forward(self, x):
                if "foo" in self.module_dict:
                    x = torch.mul(x, 1.0)
                x = torch.add(x, 1.0)
                return x

        # Case 1: "foo" is present, so both mul and add are expected (2 ops).
        module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)})
        m = M(module_dict)
        data = torch.randn(1)
        out1 = m(data)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
        out2 = opt_m(data)
        self.assertEqual(cnt.op_count, 2)
        self.assertTrue(torch._dynamo.testing.same(out1, out2))

        # Case 2: "foo" is absent, so only add is expected (1 op).
        module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)})
        m = M(module_dict)
        data = torch.randn(1)
        out1 = m(data)
        cnt = torch._dynamo.testing.CompileCounter()
        torch._dynamo.reset()
        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
        out2 = opt_m(data)

        self.assertEqual(cnt.op_count, 1)
        self.assertTrue(torch._dynamo.testing.same(out1, out2))

        # Case 3: `m` is still the "bar" module when `pre` is computed; a new
        # module built from the "cat" dict replaces it inside the context.
        module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
        pre = m(data)
        cnt.clear()

        with torch._dynamo.optimize(cnt, nopython=False):
            opt_pre = m(data)
            m = M(module_dict)
            data = torch.randn(1)
            out1 = m(data)

        out_post = m(data)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)
        self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
        self.assertTrue(torch._dynamo.testing.same(out1, out_post))

    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
    @expectedFailureDynamic
    def test_lazy_module1(self):
        # Three scenarios: a lazy module captured from the enclosing scope, a
        # lazy module created inside the compiled function, and a lazy module
        # from torch.nn (LazyBatchNorm3d).
        input_shape = (16, 3, 6, 7, 8)

        cnt = torch._dynamo.testing.CompileCounter()
        module = LazyModule()

        def test_static_module():
            input = torch.ones(*input_shape)
            module(input)

        # test no graph break
        opt_test_static_module = torch._dynamo.optimize(cnt, nopython=True)(
            test_static_module
        )
        opt_test_static_module()

        self.assertTrue(
            isinstance(module, MaterializedModule),
            "Module should be transformed to an instance of MaterializedModule.",
        )
        self.assertEqual(module.param.shape, input_shape)

        # test when mapped to UnspecializedNNModule
        module = LazyModule()

        def test_unspecialized():
            # Rebinds the enclosing `module` so the assertions below inspect
            # the instance created inside the compiled function.
            nonlocal module
            module = LazyModule()
            input = torch.ones(*input_shape)
            module(input)

        opt_test_unspecialized = torch._dynamo.optimize(cnt)(test_unspecialized)
        opt_test_unspecialized()

        self.assertTrue(
            isinstance(module, MaterializedModule),
            "Module should be transformed to an instance of MaterializedModule.",
        )
        self.assertEqual(module.param.shape, input_shape)

        # test with a static module in torch.*
        module = torch.nn.modules.LazyBatchNorm3d(
            affine=False, track_running_stats=False
        )

        cnt = torch._dynamo.testing.CompileCounter()

        torch._dynamo.reset()

        def test_torch_static():
            input = torch.ones(*input_shape)
            return module(input)  # fully materialized

        # test no graph break
        opt_test_torch_static = torch._dynamo.optimize(cnt, nopython=True)(
            test_torch_static
        )
        # Called twice; the frame count asserted below must still be 1.
        opt_test_torch_static()
        out = opt_test_torch_static()

        self.assertTrue(same(out, module(torch.ones(*input_shape))))

        self.assertTrue(
            isinstance(module, torch.nn.modules.batchnorm.BatchNorm3d),
            "Module should be transformed to an instance of BatchNorm3d.",
        )
        self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.")

    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
    @expectedFailureDynamic
    def test_lazy_module2(self):
        # Test FX graph 'call_module' works well if argument is lazy module
        lazy_mlp = LazyMLP()
        inp = torch.rand([10, 10])
        compiled = torch._dynamo.optimize("eager", nopython=True)(lazy_mlp)
        # The compiled call has to come first: an eager call would
        # initialize the lazy module before it is ever traced.
        actual = compiled(inp)
        expected = lazy_mlp(inp)
        self.assertTrue(torch.allclose(expected, actual))

    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
    @expectedFailureDynamic
    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_lazy_module3(self):
        lazy_mlp = LazyMLP()
        inp = torch.rand([10, 10])
        counter = torch._dynamo.testing.CompileCounter()
        compiled = torch._dynamo.optimize(counter, nopython=True)(lazy_mlp)

        # Round 1 on the original device: compiled call first, then eager.
        actual = compiled(inp)
        expected = lazy_mlp(inp)
        self.assertTrue(torch.allclose(expected, actual))

        # Round 2 after moving both module and input to cuda.
        lazy_mlp = lazy_mlp.to("cuda")
        inp = inp.to("cuda")
        actual = compiled(inp)
        expected = lazy_mlp(inp)
        self.assertTrue(torch.allclose(expected, actual))

        # One frame per round is expected.
        self.assertEqual(counter.frame_count, 2)

    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
    @expectedFailureDynamic
    def test_lazy_module4(self):
        # Materializes the lazy MLP with one input shape, then calls the
        # compiled module again with a different shape.
        m = LazyMLP()
        x = torch.rand([10, 10])
        cnt = torch._dynamo.testing.CompileCounter()
        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
        # first iteration
        res = opt_m(x)
        ref = m(x)
        self.assertTrue(torch.allclose(ref, res))
        # input shape changed and second iteration
        x = torch.rand([20, 20])
        # NOTE(review): if `opt_m(x)` raises no RuntimeError, nothing is
        # asserted here and the test passes -- confirm that is intended.
        try:
            opt_m(x)
        except RuntimeError:
            self.assertIn("must have same reduction dim", traceback.format_exc())

    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
    @expectedFailureDynamic
    def test_lazy_module5(self):
        # Test lazy module works well with list/tuple input
        lazy_mod = LazyModuleWithListInput()
        # The same tensor object appears three times, followed by a None entry.
        shared = torch.rand([5, 5])
        inputs = [shared, shared, shared, None]
        compiled = torch._dynamo.optimize("eager", nopython=True)(lazy_mod)
        actual = compiled(inputs)
        expected = lazy_mod(inputs)
        self.assertTrue(torch.allclose(expected, actual))

    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
    @expectedFailureDynamic
    def test_lazy_module6(self):
        # Test new lazy submodule in lazy module's initialize_parameters
        lazy_mod = LazyModuleWithLazySubmodule()
        # The same tensor object appears three times in the list.
        shared = torch.rand([5, 5])
        inputs = [shared, shared, shared]
        compiled = torch._dynamo.optimize("eager", nopython=True)(lazy_mod)
        actual = compiled(inputs)
        expected = lazy_mod(inputs)
        self.assertTrue(torch.allclose(expected, actual))

    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
    @expectedFailureDynamic
    def test_lazy_module7(self):
        # Test lazy module works well with namedtuple/dict input
        lazy_mod = LazyModuleWithNamedTupleInput()
        # Tensors are created in the same order as the fields they fill:
        # the repeated "a" entry, then "b", then `y`.
        repeated = torch.rand([5, 5])
        b_value = torch.rand([5, 5])
        y_value = torch.rand([5, 5])
        inputs = MyInput(
            x={"a": [repeated, repeated, repeated], "b": b_value},
            y=y_value,
        )
        compiled = torch.compile(lazy_mod, backend="eager", fullgraph=True)
        actual = compiled(inputs)
        expected = lazy_mod(inputs)
        self.assertTrue(torch.allclose(expected, actual))

    def test_lazy_module_no_cls_to_become(self):
        # make sure super() works in the case where cls_to_become is None
        lazy_mod = LazyChildModuleNoClsToBecome()
        inp = torch.rand(2, 2)
        compiled = torch._dynamo.optimize("eager", nopython=True)(lazy_mod)
        # Compiled call first, eager reference second.
        actual = compiled(inp)
        expected = lazy_mod(inp)
        self.assertTrue(torch.allclose(expected, actual))

    def test_lazy_module_kwargs(self):
        lazy_mod = LazyModuleKwArgs()
        # Each list repeats a single tensor object (three times, then twice).
        first = torch.rand([5, 5])
        xs = [first, first, first]
        second = torch.rand([5, 5])
        ys = [second, second]
        compiled = torch.compile(lazy_mod, backend="eager", fullgraph=True)
        # Here the eager call runs before the compiled one.
        expected = lazy_mod(xs, ys)
        actual = compiled(xs, ys)
        self.assertTrue(torch.allclose(expected, actual))

    def test_call_fn_with_non_const_inputs_safe(self):
        # Exports a module whose forward calls the conv submodule's private
        # `_conv_forward` with the weight and bias passed in explicitly.
        class ModuleSpecialFwd(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=20, kernel_size=(5, 5)
                )

            def _conv_forward(self, x):
                return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)

            def forward(self, x):
                return self._conv_forward(x)

        mod = ModuleSpecialFwd()
        rx = torch.randn([3, 10, 10])
        real = mod(rx)
        # The exported graph must reproduce the eager result on the same input.
        graph, _ = torch._dynamo.export(mod)(rx)
        self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))

    def test_conv_call_forward_directly(self):
        # Eager reference first, then the fullgraph-compiled module on the same input.
        conv_mod = ConvCallForwardDirectly()
        inp = torch.rand([4, 3, 9, 9])
        expected = conv_mod(inp)
        compiled = torch.compile(conv_mod, backend="eager", fullgraph=True)
        actual = compiled(inp)
        self.assertTrue(torch.allclose(expected, actual))

    def test_conv_transpose_call_forward_directly(self):
        # Eager reference first, then the fullgraph-compiled module on the same input.
        conv_mod = ConvTransposeCallForwardDirectly()
        inp = torch.rand([4, 4, 4, 4])
        expected = conv_mod(inp)
        compiled = torch.compile(conv_mod, backend="eager", fullgraph=True)
        actual = compiled(inp)
        self.assertTrue(torch.allclose(expected, actual))

    def test_conv_call_super_forward_directly(self):
        # The input is drawn before the module is constructed, as in the
        # eager-then-compiled comparison below.
        inp = torch.randn(4, 4)
        conv_mod = ConvCallSuperForwardDirectly(4, 4, 4)
        expected = conv_mod(inp)
        compiled = torch.compile(conv_mod, backend="eager", fullgraph=True)
        actual = compiled(inp)
        self.assertTrue(torch.allclose(expected, actual))

    def test_conv_transpose_call_super_forward_directly(self):
        # The input is drawn before the module is constructed; the eager
        # result is the reference for the fullgraph-compiled module.
        inp = torch.randn(4, 4, 4)
        conv_mod = ConvTransposeCallSuperForwardDirectly(4, 4, 4)
        expected = conv_mod(inp)
        compiled = torch.compile(conv_mod, backend="eager", fullgraph=True)
        actual = compiled(inp)
        self.assertTrue(torch.allclose(expected, actual))


class MockModule(torch.nn.Module):
    """Small fixture combining a linear layer, a buffer, and a ReLU."""

    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.linear = torch.nn.Linear(10, 10)
        self.buf0 = torch.nn.Buffer(torch.randn(10, 10))

    def forward(self, x):
        # relu(linear(x) + buf0)
        return self.relu(self.linear(x) + self.buf0)


class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
    def test_nn_module(self):
        # Wrapping a module must produce an OptimizedModule that matches
        # eager output and compiles a single frame.
        eager_mod = MockModule()
        counter = torch._dynamo.testing.CompileCounter()
        wrapped = torch._dynamo.optimize(counter)(eager_mod)
        self.assertIsInstance(wrapped, torch._dynamo.OptimizedModule)

        inp = torch.randn(10, 10)
        eager_out = eager_mod(inp)
        compiled_out = wrapped(inp)
        self.assertTrue(same(eager_out, compiled_out))
        self.assertEqual(counter.frame_count, 1)

    @torch._dynamo.config.patch(guard_nn_modules=True)
    def test_attr_precedence(self):
        # Instance attributes set in `MyMod.__init__` shadow same-named
        # methods defined on the classes; the compiled module must agree
        # with eager and compile only once.
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3

            def forward(self, x, c=4):
                return x * c

            def linear(self, x):
                return x

            def b(self, x):
                raise RuntimeError("Should not be called")

        class MyMod(Mod):
            def __init__(self) -> None:
                super().__init__()
                # Each assignment below shadows a method of the same name
                # (`linear` and `b` on Mod, `scale` on MyMod) or overrides
                # the value of `a` set by Mod.__init__.
                self.linear = torch.nn.Linear(11, 11)
                self.a = 2
                self.b = 2
                self.scale = 1

            def scale(self, x):
                # Should not be called because it is shadowed by the instance
                # attribute
                raise RuntimeError("Should not be called")

            def forward(self, x, c=None):
                return self.linear(x) * self.a * self.b * self.scale

        mod = MyMod()
        x = torch.ones(3, 3)
        ref = mod(x)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_mod = torch.compile(mod, backend=cnts)
        # Called twice; the second call must not add a frame.
        opt_mod(torch.ones(3, 3))
        res = opt_mod(torch.ones(3, 3))

        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(ref, res)

    def test_to(self):
        eager_mod = MockModule()
        counter = torch._dynamo.testing.CompileCounter()
        opt_mod = torch._dynamo.optimize(counter)(eager_mod)
        inp = torch.randn(10, 10)
        eager_out = eager_mod(inp)
        compiled_out = opt_mod(inp)
        self.assertTrue(same(eager_out, compiled_out))
        self.assertEqual(counter.frame_count, 1)

        # A repeated call with the same input must not recompile.
        opt_mod(inp)
        self.assertEqual(counter.frame_count, 1)

        # Convert the device, then the dtype; the result must still be an
        # OptimizedModule.
        opt_mod = opt_mod.to(device="cpu")
        opt_mod = opt_mod.to(dtype=torch.float64)
        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
        inp = torch.randn(10, 10).to(dtype=torch.float64)
        opt_mod(inp)
        # The float64 call is expected to recompile once...
        self.assertEqual(counter.frame_count, 2)

        # ...and a second float64 call must not recompile.
        opt_mod(inp)
        self.assertEqual(counter.frame_count, 2)

        # After a reset, the next call compiles again.
        torch._dynamo.reset()
        opt_mod(inp)
        self.assertEqual(counter.frame_count, 3)

    @torch._dynamo.config.patch(guard_nn_modules=True)
    def test_param_order(self):
        # Reordering the entries of `_parameters` must trigger a recompile
        # and keep the compiled result equal to eager.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param1 = torch.nn.Parameter(torch.ones([1]))
                self.param2 = torch.nn.Parameter(torch.ones([2]))

            def forward(self, x):
                return x

        mod = MyModule()
        coeffs = [2, 3]

        def fn(x):
            # The per-index coefficient makes the result depend on iteration order.
            for idx, p in enumerate(mod.parameters()):
                x += p.sum() * coeffs[idx]

            for idx, p in enumerate(mod.named_parameters()):
                x += p[1].sum() * coeffs[idx]

            return x

        ref = fn(torch.ones(1))
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)

        # Pop and re-insert "param1" so it moves to the end of the dict.
        mod._parameters["param1"] = mod._parameters.pop("param1")
        ref = fn(torch.ones(1))
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @torch._dynamo.config.patch(guard_nn_modules=True)
    def test_buffer_order(self):
        # Reordering the entries of `_buffers` must trigger a recompile and
        # keep the compiled result equal to eager.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b1 = torch.nn.Buffer(torch.ones([1]))
                self.b2 = torch.nn.Buffer(torch.ones([2]))

            def forward(self, x):
                return x

        mod = MyModule()
        coeffs = [2, 3]

        def fn(x):
            # The per-index coefficient makes the result depend on iteration order.
            for idx, p in enumerate(mod.buffers()):
                x += p.sum() * coeffs[idx]

            for idx, p in enumerate(mod.named_buffers()):
                x += p[1].sum() * coeffs[idx]

            return x

        ref = fn(torch.ones(1))
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)

        # Pop and re-insert "b1" so it moves to the end of the dict.
        mod._buffers["b1"] = mod._buffers.pop("b1")
        ref = fn(torch.ones(1))
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @torch._dynamo.config.patch(guard_nn_modules=True)
    def test_module_order(self):
        # Reordering the entries of `_modules` must trigger a recompile and
        # keep the compiled result equal to eager.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.linear2 = torch.nn.Linear(10, 10)

            def forward(self, x):
                return x

        mod = MyModule()
        coeffs = [2, 3, 4]

        # A distinct weight per module, looked up by module identity.
        coeffs_for_mod = {mod: 10, mod.linear1: 20, mod.linear2: 30}

        # Check order of _modules
        def fn(x):
            for idx, p in enumerate(mod.modules()):
                # Something silly to force dependency on the order
                x += coeffs_for_mod[p] * coeffs[idx]
            for idx, p in enumerate(mod.named_modules()):
                x += coeffs_for_mod[p[1]] * coeffs[idx]
            for idx, p in enumerate(mod.children()):
                x += coeffs_for_mod[p] * coeffs[idx]
            for idx, p in enumerate(mod.named_children()):
                x += coeffs_for_mod[p[1]] * coeffs[idx]
            return x

        ref = fn(torch.ones(1))
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)

        # Pop and re-insert "linear1" so it moves to the end of the dict.
        mod._modules["linear1"] = mod._modules.pop("linear1")
        ref = fn(torch.ones(1))
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    def test_attr(self):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.buf0 = torch.nn.Buffer(torch.randn(10, 10))

            def forward(self, x):
                return self.r(torch.sin(x)) + self.buf0

        mod = MockModule()
        opt_mod = torch._dynamo.optimize("eager")(mod)

        # Check parameters and buffers
        for p1, p2 in zip(mod.parameters(), opt_mod.parameters()):
            self.assertTrue(id(p1) == id(p2))
        for b1, b2 in zip(mod.buffers(), opt_mod.buffers()):
            self.assertTrue(id(b1) == id(b2))

        def get_parameter_dtype(mod: torch.nn.Module):
            parameters_and_buffers = itertools.chain(mod.parameters(), mod.buffers())
            return next(parameters_and_buffers).dtype

        opt_mod = torch._dynamo.optimize("eager")(get_parameter_dtype)
        out_dtype = opt_mod(mod)
        self.assertEqual(out_dtype, torch.float32)

    def test_dir(self):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.buf0 = torch.nn.Buffer(torch.nn.Buffer(torch.randn(10, 10)))
                self.register_parameter(
                    name="param0", param=torch.nn.Parameter(torch.randn(10, 10))
                )

            def forward(self, x):
                return self.r(torch.sin(x)) + self.buf0

        mod = MockModule()
        mod_keys = dir(mod)
        opt_mod = torch._dynamo.optimize("eager")(mod)
        opt_mod_keys = dir(opt_mod)

        # Check user-defined attributes, parameters and buffers
        self.assertIn("linear", opt_mod_keys)
        self.assertIn("buf0", opt_mod_keys)
        self.assertIn("param0", opt_mod_keys)

        # Check all attributes, parameters and buffers
        self.assertTrue(len(set(mod_keys).difference(opt_mod_keys)) == 0)

    def test_no_recompile_on_nn_guarded_modules(self):
        # Runs four separately compiled submodules in sequence with
        # `error_on_recompile` enabled and a cache size limit of 1.
        size = (10, 10)
        cache_size_limit = 1
        num_submodules = 4
        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")

        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(*size)

            def forward(self, x):
                a = torch.sin(torch.cos(x))
                return self.linear(a)

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # `mods` is a plain Python list; each entry is replaced by
                # its individually compiled wrapper.
                self.mods = [SubModule() for _ in range(num_submodules)]
                self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods]

            def forward(self, x):
                for mod in self.mods:
                    x = mod(x)
                return x

        mod = MockModule()
        # Each submod is compiled separately and has a different nn module
        # guard. Ensure that recompilation logic is handled correctly.
        with unittest.mock.patch(
            "torch._dynamo.config.error_on_recompile", True
        ), unittest.mock.patch(
            "torch._dynamo.config.cache_size_limit",
            cache_size_limit,
        ):
            x = torch.randn(*size, requires_grad=True)
            mod(x)
            # The expected frame count depends on the inlining setting.
            if torch._dynamo.config.inline_inbuilt_nn_modules:
                self.assertEqual(cnts.frame_count, 1)
            else:
                self.assertEqual(cnts.frame_count, num_submodules)

    @patch.object(torch._dynamo.config, "accumulated_cache_size_limit", 2)
    @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False)
    def test_recompile_limit_on_freed_module(self):
        # Calls the compiled function with eight freshly constructed modules
        # while the accumulated cache size limit is patched to 2.
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.lin(x)

        def fn(x, mod):
            return mod(x)

        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
        opt_mod = torch.compile(fn, backend=cnts)
        for i in range(8):
            # `mod` is rebound on every iteration, dropping the previous instance.
            mod = Mod()
            opt_mod(torch.randn(5, 5), mod)

        # fn compiles twice
        self.assertEqual(cnts.frame_count, 2)

    @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True)
    def test_inline_inbuilt_nn_modules(self):
        # With `inline_inbuilt_nn_modules` forced on, four separately
        # compiled submodules are expected to produce a single frame.
        size = (10, 10)
        cache_size_limit = 1
        num_submodules = 4
        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")

        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(*size)

            def forward(self, x):
                a = torch.sin(torch.cos(x))
                return self.linear(a)

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # `mods` is a plain Python list; each entry is replaced by
                # its individually compiled wrapper.
                self.mods = [SubModule() for _ in range(num_submodules)]
                self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods]

            def forward(self, x):
                for mod in self.mods:
                    x = mod(x)
                return x

        mod = MockModule()
        # Each submod is compiled separately and has a different nn module
        # guard. Ensure that recompilation logic is handled correctly.
        with unittest.mock.patch(
            "torch._dynamo.config.error_on_recompile", True
        ), unittest.mock.patch(
            "torch._dynamo.config.cache_size_limit",
            cache_size_limit,
        ):
            x = torch.randn(*size, requires_grad=True)
            mod(x)
            self.assertEqual(cnts.frame_count, 1)

    def test_cache_size_limit_on_guarded_nn_modules(self):
        # Runs a chain of separately compiled submodules on inputs of three
        # different ranks under a cache size limit of two and checks the total
        # number of compiled frames.
        cache_size_limit = 2
        num_submodules = 4
        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")

        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                a = torch.sin(torch.cos(x))
                return self.relu(a)

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # A plain Python list of separately compiled submodules.
                eager_mods = [SubModule() for _ in range(num_submodules)]
                self.mods = [torch.compile(sub, backend=cnts) for sub in eager_mods]

            def forward(self, x):
                for mod in self.mods:
                    x = mod(x)
                return x

        mod = MockModule()
        shapes = ((4,), (4, 4), (4, 4, 4))
        # For the third iteration, we would reach the cache size limit, and
        # therefore the total number of expected frame count is 2 *
        # num_submodules.
        with unittest.mock.patch(
            "torch._dynamo.config.cache_size_limit",
            cache_size_limit,
        ):
            for shape in shapes:
                x = torch.randn(shape)
                mod(x)

        if torch._dynamo.config.inline_inbuilt_nn_modules:
            expected_frames = 2
        else:
            expected_frames = 2 * num_submodules
        self.assertEqual(cnts.frame_count, expected_frames)

    def test_recursion(self):
        # Wrapping an already optimized module again and again (six layers in
        # total) must still result in exactly one compiled frame.
        counter = torch._dynamo.testing.CompileCounter()
        wrapped = MockModule()
        for _ in range(6):
            wrapped = torch._dynamo.optimize(counter)(wrapped)

        wrapped(torch.randn(10, 10))
        self.assertEqual(counter.frame_count, 1)

    def test_composition(self):
        # A plain (un-optimized) inner module nested inside an optimized outer
        # module: outputs must match eager and one frame is compiled.
        class InnerModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(torch.sin(x))

        inner_mod = InnerModule()

        class OuterModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod = inner_mod

            def forward(self, x):
                return self.mod(torch.cos(x))

        outer_mod = OuterModule()
        counter = torch._dynamo.testing.CompileCounter()
        opt_outer_mod = torch._dynamo.optimize(counter)(outer_mod)

        x = torch.randn(4)
        self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
        eager_out = outer_mod(x)
        compiled_out = opt_outer_mod(x)
        self.assertTrue(torch._dynamo.testing.same(eager_out, compiled_out))
        self.assertEqual(counter.frame_count, 1)

    def test_composition_with_opt_mod(self):
        # Same nesting as test_composition, except that the inner module is
        # itself wrapped by dynamo before the outer module is optimized.
        class InnerModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(torch.sin(x))

        counter = torch._dynamo.testing.CompileCounter()
        opt_inner_mod = torch._dynamo.optimize(counter)(InnerModule())

        class OuterModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod = opt_inner_mod

            def forward(self, x):
                return self.mod(torch.cos(x))

        outer_mod = OuterModule()
        opt_outer_mod = torch._dynamo.optimize(counter)(outer_mod)

        x = torch.randn(4)
        self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
        eager_out = outer_mod(x)
        compiled_out = opt_outer_mod(x)
        self.assertTrue(torch._dynamo.testing.same(eager_out, compiled_out))
        # There will be a graph break for the inner mod being OptimizedModule
        self.assertEqual(counter.frame_count, 2)

    def test_module_patch(self):
        # Binds ModulePatch2.forward onto a ModulePatch1 instance and runs the
        # patched module through a nopython compile; the result is expected to
        # be (close to) zero.
        mod = ModulePatch1()
        mod.forward = types.MethodType(ModulePatch2.forward, mod)

        def fn(x):
            return mod(x)

        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        result = opt_fn(torch.ones(10))
        self.assertTrue(torch.allclose(result, torch.zeros(1)))

    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
    def test_hooks_outer(self):
        # Compiles a module that carries a forward hook directly (the hooked
        # module is the outermost compiled object) and checks the results both
        # with the hook installed and after it has been removed.
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x + 1

        m = TestModule()

        def forward_hook(
            module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
        ) -> torch.Tensor:
            return 2 * output + 1

        handle = m.register_forward_hook(forward_hook)
        inp = torch.tensor(1.0, requires_grad=True)

        failure_reason = None

        def guard_fail_fn(failure):
            # Remember the first entry of the reported guard failure.
            nonlocal failure_reason
            failure_reason = failure[0]

        compiled_m = torch._dynamo.optimize(
            guard_fail_fn=guard_fail_fn, backend="eager"
        )(m)

        self.assertEqual(compiled_m(inp), m(inp))
        # forward gives 2 * 1 + 1 = 3, the hook turns that into 2 * 3 + 1 = 7
        self.assertEqual(compiled_m(inp).item(), 7)
        self.assertIsNone(failure_reason)

        # what if we remove our hook? we should recompile?
        handle.remove()
        self.assertEqual(compiled_m(inp), m(inp))
        self.assertEqual(compiled_m(inp).item(), 3)
        # NOTE: `failure_reason` is intentionally not asserted to be "hook" here.
        #
        # Summary:
        #   - removing a hook doesn't fail a guard, because we weren't compiling
        #     the hook (at least into the same graph) as forward in the first
        #     place! We do correctly omit calling the removed hook, but since
        #     this hook is a post forward hook, the 'RETURN' from forward is
        #     breaking the graph.
        #
        #     Open question: why is 'forward' the entrypoint to an
        #     InstructionTranslator, after the eval_frame entrypoint was changed
        #     to Module.__call__?

    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
    def test_hooks_inner(self):
        # The hooked module is called from inside a compiled function. Checks
        # that the hook is compiled into the same graph and that removing or
        # swapping the hook fails a guard and triggers a recompile.
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x + 1

        m = TestModule()

        def forward_hook(
            module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
        ) -> torch.Tensor:
            return 2 * output + 1

        handle = m.register_forward_hook(forward_hook)

        def outer_func(tensor):
            x = tensor * 2 + 1
            y = m(x)
            return y

        inp = torch.tensor(1.0, requires_grad=True)

        failure_reason = None

        def guard_fail_fn(failure):
            # Remember the first entry of the reported guard failure.
            nonlocal failure_reason
            failure_reason = failure[0]

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_func = torch._dynamo.optimize(
            guard_fail_fn=guard_fail_fn,
            backend=cc,
        )(outer_func)

        self.assertEqual(compiled_func(inp), outer_func(inp))
        # outer_func: 1 * 2 + 1 = 3, forward: 2 * 3 + 1 = 7, hook: 2 * 7 + 1 = 15
        self.assertEqual(compiled_func(inp).item(), 15)

        # We are compiling 1 big graph for all 3 functions including the hook.
        self.assertEqual(cc.frame_count, 1)
        self.assertEqual(cc.op_count, 6)

        # If we remove the hook, we should recompile
        handle.remove()
        self.assertEqual(compiled_func(inp), outer_func(inp))
        self.assertEqual(compiled_func(inp).item(), 7)
        # assertIn reports the actual failure text when the check does not hold.
        self.assertIn("forward_hooks", failure_reason)
        self.assertEqual(cc.frame_count, 1 + 1)
        self.assertEqual(cc.op_count, 6 + 4)

        # what if instead of removing, we alter our hook?
        torch._dynamo.reset()
        # `outer_func` closes over `m`, so rebinding it here swaps the module
        # that both the eager and the compiled call use.
        m = TestModule()
        handle = m.register_forward_hook(forward_hook)
        failure_reason = None
        self.assertEqual(compiled_func(inp), outer_func(inp))
        self.assertEqual(compiled_func(inp).item(), 15)

        def new_forward_hook(
            module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
        ) -> torch.Tensor:
            return 2 * output + 2

        # Replace the registered hook in place, keeping the same handle id.
        m._forward_hooks[handle.id] = new_forward_hook
        self.assertEqual(compiled_func(inp), outer_func(inp))
        self.assertEqual(compiled_func(inp).item(), 16)
        self.assertRegex(failure_reason, r"___check_obj_id\(L\['m'\]._forward_hooks")

    @patch.object(torch._dynamo.config, "guard_nn_modules", False)
    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
    @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False)
    def test_hooks_skip_guards(self):
        # With hook guards skipped (see the config patches on this test),
        # removing a forward hook after compilation must not trigger a
        # recompile: the compiled function keeps returning the hooked result.
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x + 1

        m = TestModule()

        def forward_hook(
            module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
        ) -> torch.Tensor:
            return 2 * output + 1

        handle = m.register_forward_hook(forward_hook)

        def outer_func(tensor):
            x = tensor * 2 + 1
            y = m(x)
            return y

        inp = torch.tensor(1.0, requires_grad=True)

        failure_reason = None

        def guard_fail_fn(failure):
            # Remember the first entry of the reported guard failure.
            nonlocal failure_reason
            failure_reason = failure[0]

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_func = torch._dynamo.optimize(
            guard_fail_fn=guard_fail_fn,
            backend=cc,
        )(outer_func)

        # `outer_func` closes over `m`; the module and hook created above are
        # replaced here before anything has been compiled or called.
        m = TestModule()
        handle = m.register_forward_hook(forward_hook)
        failure_reason = None
        self.assertEqual(compiled_func(inp), outer_func(inp))
        # outer_func: 1 * 2 + 1 = 3, forward: 2 * 3 + 1 = 7, hook: 2 * 7 + 1 = 15
        self.assertEqual(compiled_func(inp).item(), 15)
        self.assertEqual(cc.frame_count, 1)
        self.assertEqual(cc.op_count, 6)

        # if we remove the hook, dynamo shouldn't notice
        handle.remove()
        # Eager now skips the hook while the compiled function still applies it.
        self.assertNotEqual(compiled_func(inp), outer_func(inp))
        self.assertEqual(compiled_func(inp).item(), 15)
        self.assertEqual(cc.frame_count, 1)

    def _forward_hook_test_helper(self, model):
        # Registers an input-recording forward hook on every module returned by
        # `model.named_modules()`, runs the model compiled with the "aot_eager"
        # backend and then eagerly, and checks that the hooks recorded the same
        # set of module names in both modes.
        #
        # Args:
        #     model: module to exercise; it is called with (20, 10) inputs.
        forward_handles = {}
        compiled_activations = {}
        eager_activations = {}
        # Rebound below to whichever dict the hook should currently write to.
        activations = None

        def save_activations(name, mod, inp, out):
            activations[name] = inp

        for name, module in model.named_modules():
            forward_handles[name] = module.register_forward_hook(
                partial(save_activations, name)
            )

        compiled_model = torch.compile(model, backend="aot_eager")

        def run_twice(net, recorded):
            # One forward/backward step, executed two times on fresh inputs.
            for _ in range(2):
                # second iteration is key, hooks would have fired during aot trace
                # on first iter
                recorded.clear()
                x = torch.randn((20, 10))
                pred = net(x)
                loss = pred.sum()
                loss.backward()

        activations = compiled_activations
        run_twice(compiled_model, compiled_activations)

        activations = eager_activations
        run_twice(model, eager_activations)

        # The recorded names go into the failure message instead of being
        # printed unconditionally on every run.
        self.assertTrue(
            compiled_activations.keys() == eager_activations.keys(),
            msg=(
                f"Recorded Layers: {compiled_activations.keys()}; "
                f"Expected Layers: {eager_activations.keys()}"
            ),
        )
        self.assertTrue(
            activations.keys() == forward_handles.keys(),
            msg=(
                f"Recorded Layers: {activations.keys()}; "
                f"Hooked Layers: {forward_handles.keys()}"
            ),
        )

    def test_hooks_allowed_modules(self):
        # this test shouldn't care whether hook guards are enabled or not
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net = torch.nn.Sequential(
                    *[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
                    + [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
                )

            def forward(self, x):
                return self.net(x)

        model = ToyModel()
        self._forward_hook_test_helper(model)

    def test_hooks_allowed_modules_compiles(self):
        # Forward hooks on every module of a small Sequential model, compiled
        # with nopython=True: checks how many hook calls are recorded and that
        # a single frame is compiled.
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layers = [
                    torch.nn.Linear(10, 10000),
                    torch.nn.ReLU(),
                    torch.nn.Linear(10000, 5),
                    torch.nn.ReLU(),
                ]
                self.net = torch.nn.Sequential(*layers)

            def forward(self, x):
                return self.net(x)

        model = ToyModel()
        activations = []

        def save_activations(mod, inp, out):
            activations.append(inp)

        for _, module in model.named_modules():
            module.register_forward_hook(save_activations)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize(cnt, nopython=True)(model)
        for _ in range(2):
            # second iteration is key, hooks would have fired during aot trace
            # on first iter
            activations.clear()
            x = torch.randn((20, 10))
            pred = opt_model(x)
            pred.sum().backward()
        self.assertEqual(len(activations), 6)
        self.assertEqual(cnt.frame_count, 1)

    def test_hooks_allowed_modules_compiles_self_contained(self):
        # Output-modifying forward hooks on every module of a small model:
        # compares the eager and compiled forward results and tracks the number
        # of compiled frames across an input-rank change.
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net = torch.nn.Sequential(
                    *[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
                    + [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
                )

            def forward(self, x):
                return self.net(x) * self.net(x)

        model = ToyModel()
        forward_handles = {}

        def output_modifying_hook(mod, inp, out):
            return 2 * out + 1

        for name, module in model.named_modules():
            forward_handles[name] = module.register_forward_hook(output_modifying_hook)

        cnt = torch._dynamo.testing.CompileCounter()

        x = torch.randn((20, 10))
        pred_eager = model(x)
        loss_eager = pred_eager.sum()
        loss_eager.backward()

        model = torch._dynamo.optimize(cnt, nopython=True)(model)
        pred = model(x)

        loss = pred.sum()
        loss.backward()

        # Tensor.backward() returns None, so comparing the two return values
        # (as this test used to) could never fail. Compare the hooked forward
        # outputs of the eager and compiled runs instead.
        self.assertEqual(pred_eager, pred)
        self.assertEqual(cnt.frame_count, 2)

        # Ndim change, recompile
        pred = model(torch.randn([10, 10, 10]))
        self.assertEqual(cnt.frame_count, 4)

        # Stable
        pred = model(torch.randn([10, 10, 10]))
        self.assertEqual(cnt.frame_count, 4)

    def test_dunder_call_explicitly(self):
        # hooks should be triggered if explicit calling `__call__`
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10000)

            def forward(self, x):
                return self.linear.__call__(x)

        model = ToyModel()
        self._forward_hook_test_helper(model)

    def test_backward_hooks(self):
        # this test shouldn't care whether hook guards are enabled or not

        class CustomLinear(torch.nn.Module):
            # not an 'allowed module', so should not graph-break
            def __init__(self, a, b):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(a, b))

            def forward(self, x):
                return torch.mm(x, self.weight)

        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net = torch.nn.Sequential(
                    CustomLinear(10, 10),
                    CustomLinear(10, 10000),
                    CustomLinear(10000, 5),
                )

            def forward(self, x):
                return self.net(x)

        model = ToyModel()

        # Per-module records written by the two hooks below.
        grad_sizes = {}
        pre_grad_sizes = {}

        def backward_hook(name, mod, grad_inp, grad_out):
            grad_sizes[name] = (
                (gi.shape for gi in grad_inp),
                (go.shape for go in grad_out),
            )
            return None

        def backward_pre_hook(name, mod, grad_out):
            pre_grad_sizes[name] = (go.shape for go in grad_out)
            return None

        # Register both hook kinds on every module, keyed by module name.
        backward_hook_handles = {}
        pre_backward_hook_handles = {}
        for name, module in model.named_modules():
            full_hook = partial(backward_hook, name)
            backward_hook_handles[name] = module.register_full_backward_hook(full_hook)

            pre_hook = partial(backward_pre_hook, name)
            pre_backward_hook_handles[name] = module.register_full_backward_pre_hook(
                pre_hook
            )

        compiled_model = torch.compile(model, backend="aot_eager")

        for _ in range(2):
            # second iteration is key, hooks would have fired during aot trace
            # on first iter
            x = torch.randn((20, 10))
            pred = compiled_model(x)
            pred.sum().backward()

        # Every hooked module must have been seen by both hook kinds.
        self.assertTrue(grad_sizes.keys() == backward_hook_handles.keys())
        self.assertTrue(pre_grad_sizes.keys() == pre_backward_hook_handles.keys())

    def test_udo_instance_method_as_hook(self):
        # A bound method of a plain user-defined object is registered as a
        # forward pre-hook (with kwargs); compiling the object with
        # fullgraph=True must apply the hook's "+ 1" to the input.
        class CustomClass:
            def __init__(self, module):
                self.module = module
                self.handle = self.module.register_forward_pre_hook(
                    self.func1, prepend=True, with_kwargs=True
                )

            def func1(self, module, args, kwargs):
                return (args[0] + 1,), kwargs

            def __call__(self, x):
                return self.module(x)

        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x * x

        toy = ToyModel()
        x = torch.zeros((3, 4))
        hooked = CustomClass(toy)
        compiled = torch.compile(hooked, fullgraph=True)
        out = compiled(x)
        expected = (x + 1) * (x + 1)
        self.assertEqual(out, expected)

    def test_module_dict_iter_name(self):
        # Iterating a ModuleDict directly (which yields its keys) inside forward
        # must match eager and compile into a single frame.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.activations = torch.nn.ModuleDict(
                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
                )

            def forward(self, x):
                for activation_name in self.activations:
                    x = self.activations[activation_name](x)
                return x

        counter = torch._dynamo.testing.CompileCounter()
        # Eager
        eager_res = MyModule()(torch.ones(10, 10))

        # Compile
        opt_mod = torch._dynamo.optimize(counter)(MyModule())
        optim_res = opt_mod(torch.ones(10, 10))
        self.assertEqual(eager_res, optim_res)
        self.assertEqual(counter.frame_count, 1)

    def test_module_dict_iter_keys(self):
        # Iterating `ModuleDict.keys()` inside forward must match eager and
        # compile into a single frame.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.activations = torch.nn.ModuleDict(
                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
                )

            def forward(self, x):
                for activation_name in self.activations.keys():
                    x = self.activations[activation_name](x)
                return x

        counter = torch._dynamo.testing.CompileCounter()
        # Eager
        eager_res = MyModule()(torch.ones(10, 10))

        # Compile
        opt_mod = torch._dynamo.optimize(counter)(MyModule())
        optim_res = opt_mod(torch.ones(10, 10))
        self.assertEqual(eager_res, optim_res)
        self.assertEqual(counter.frame_count, 1)

    def test_module_setattr(self):
        # A plain Python attribute set on a submodule inside a full-graph
        # compiled function must be visible on the real module afterwards.
        models = torch.nn.Sequential(torch.nn.Linear(3, 3))
        models[0].abc = False

        def run():
            models[0].abc = True
            x = torch.randn(1, 3)
            return models(x)

        compiled_run = torch.compile(run, fullgraph=True)
        compiled_run()
        self.assertTrue(models[0].abc)

    def test_assign_does_not_exist(self):
        # forward assigns a brand-new attribute on the module; after a
        # full-graph compile the attribute on the real module must be the very
        # tensor object that was returned.
        class MyModule(torch.nn.Module):
            def forward(self, x):
                self.text_encoding = x + 1
                return self.text_encoding

        mod = MyModule()
        out = torch.compile(mod, fullgraph=True)(torch.randn(10))
        # Use the TestCase assertion instead of a bare `assert`, which is
        # stripped when Python runs with -O.
        self.assertIs(mod.text_encoding, out)

    def test_module_dict_iter_values(self):
        # Iterating `ModuleDict.values()` inside forward must match eager and
        # compile into a single frame.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.activations = torch.nn.ModuleDict(
                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
                )

            def forward(self, x):
                for activation in self.activations.values():
                    x = activation(x)
                return x

        counter = torch._dynamo.testing.CompileCounter()
        # Eager
        eager_res = MyModule()(torch.ones(10, 10))

        # Compile
        opt_mod = torch._dynamo.optimize(counter)(MyModule())
        optim_res = opt_mod(torch.ones(10, 10))
        self.assertEqual(eager_res, optim_res)
        self.assertEqual(counter.frame_count, 1)

    def test_unspecialized_seq(self):
        # Flipping `training` on a Sequential's child inside the function must
        # give the same result compiled as it does eagerly.
        models = torch.nn.Sequential(torch.nn.Linear(3, 3))

        def fn(x):
            models[0].training = False
            return models(x)

        x = torch.randn(1, 3)
        ref = fn(x)
        res = torch._dynamo.optimize("eager")(fn)(x)
        self.assertEqual(ref, res)

    def test_no_op_assignment(self):
        # Counts how many compiled graphs carry no buffers when forward
        # re-assigns a tensor attribute to `self.buffer.to(x)`.
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.rand([4])

            def forward(self, x):
                # should be a no-op, but causes dynamo to lose the static input
                x = x + 1
                self.buffer = self.buffer.to(x)
                return self.buffer + x

        compiles_without_buffers = 0

        def debug_compile(gm, *args, **kwargs):
            # Backend that only tallies graphs owning no buffers.
            nonlocal compiles_without_buffers
            if not list(gm.buffers()):
                compiles_without_buffers += 1
            return gm

        @torch.compile(backend=debug_compile)
        def foo(mod, x):
            return mod(x)

        mod = Mod()
        foo(mod, torch.rand([4]))
        self.assertEqual(
            compiles_without_buffers,
            1 if torch._dynamo.config.inline_inbuilt_nn_modules else 0,
        )

        foo(mod, torch.rand([4], dtype=torch.half))
        self.assertEqual(
            compiles_without_buffers,
            2 if torch._dynamo.config.inline_inbuilt_nn_modules else 1,
        )

        class Mod2(Mod):
            def __setattr__(self, name, value):
                return super().__setattr__(name, value)

        foo(Mod2(), torch.rand([4]))
        # causes two compilations, bc unimplemented custom setattr
        self.assertTrue(compiles_without_buffers >= 2)

    def test_unspec_non_inlinable_module(self):
        # Compiled and eager calls of UnspecNonInlinableModule must agree.
        mod = UnspecNonInlinableModule()
        x = torch.randn(100)
        actual = torch._dynamo.optimize("eager")(mod)(x)
        expected = mod(x)
        self.assertEqual(actual, expected)

    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    def test_mark_static_previously_seen_tensor(self):
        # This test verifies that dynamo will mark
        # the buffers/params of a module as static
        # even if this param was previously seen
        # (ex. as a different input)
        num_compiles = 0

        def debug_compiler(gm, _):
            # Backend: counts compilations and inspects the captured graph.
            nonlocal num_compiles
            num_compiles += 1

            # Placeholders named "l_b_" — presumably the graph input created
            # for the `b` argument of `fn` below; verify the naming scheme.
            input_nodes = [
                n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_"
            ]

            self.assertGreater(len(input_nodes), 0)
            for input_node in input_nodes:
                self.assertEqual(
                    input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
                    "unguarded",
                )

            return gm

        class TestModule(torch.nn.Module):
            def __init__(self, buf) -> None:
                super().__init__()
                # Changing this one to nn.Buffer fails because `nn.Buffer` does a .detach()
                # so the value in self.tx.output.side_effects will no longer evaluate to True
                self.register_buffer("buf", buf)

            def forward(self, x):
                return self.buf * x

        @torch._dynamo.optimize(backend=debug_compiler)
        def fn(x, b, mod):
            z = b + 1
            return z * mod(x)

        buf = torch.ones(2, 2)
        inp = torch.ones(2)
        mod = TestModule(buf)
        # The same tensor object is passed both as the plain argument `b` and as
        # the module's registered buffer.
        fn(inp, buf, mod)
        self.assertEqual(num_compiles, 1)

    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    def test_mark_static_nn_module_tensor(self):
        # This test verifies that dynamo will mark
        # the nn module tensor attributes as static
        num_compiles = 0

        def debug_compiler(gm, _):
            # Backend: counts compilations and checks the static-input tag on
            # every placeholder named "l_mod_buf".
            nonlocal num_compiles
            num_compiles += 1

            matching = []
            for node in gm.graph.nodes:
                if node.op == "placeholder" and node.name == "l_mod_buf":
                    matching.append(node)

            self.assertGreater(len(matching), 0)
            for node in matching:
                static_type = node.meta["tensor_dict"]["_dynamo_static_input_type"]
                self.assertEqual(static_type, "unguarded")

            return gm

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.ones(2, 2)

            def forward(self, x):
                return self.buf * x

        mod = TestModule()

        @torch._dynamo.optimize(backend=debug_compiler)
        def fn(x):
            return x * mod(x)

        fn(torch.ones(2))
        self.assertEqual(num_compiles, 1)

    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    @torch._inductor.config.patch("freezing", True)
    @torch.no_grad()
    def test_mark_static_with_freezing(self):
        # This test verifies that dynamo will
        # add buffers/params as attributes of the
        # graph w/ guards if freezing is enabled
        num_compiles = 0

        def debug_compiler(gm, _):
            # Backend: counts compilations and inspects the captured graph.
            nonlocal num_compiles
            num_compiles += 1

            # NOTE(review): `fn` below takes (x, mod) and has no `b` argument,
            # so this "l_b_" filter looks like it can never match — TODO confirm
            # which placeholder name was intended here.
            input_nodes = [
                n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_"
            ]
            self.assertEqual(len(input_nodes), 0)
            # The module's single buffer must be owned by the graph module.
            self.assertEqual(len(list(gm.buffers())), 1)
            return gm

        class TestModule(torch.nn.Module):
            def __init__(self, buf) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(buf)

            def forward(self, x):
                return self.buf * x

        @torch._dynamo.optimize(backend=debug_compiler)
        def fn(x, mod):
            return mod(x)

        buf = torch.ones(2, 2)
        inp = torch.ones(2)
        mod = TestModule(buf)
        fn(inp, mod)
        self.assertEqual(num_compiles, 1)
        # Swapping the buffer for a different tensor is expected to recompile.
        mod.buf = torch.rand_like(buf)
        fn(inp, mod)
        self.assertEqual(num_compiles, 2)

    @patch.object(torch._dynamo.config, "guard_nn_modules", True)
    def test_guard_on_torch_nn_modules(self):
        # https://github.com/pytorch/pytorch/issues/110048

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.multiplier = 10

            def forward(self, x):
                return self.linear(x) * self.multiplier

        mod = MockModule()

        counter = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=counter)
        def generate(x, c):
            return mod(x) + c

        # Alternating between the two constants must settle on two frames.
        for _ in range(10):
            for c in (0, 1):
                generate(torch.randn(10, 10), c)
        self.assertEqual(counter.frame_count, 2)

        # Ensure that modification in user module causes recompile
        mod.multiplier = 11
        generate(torch.randn(10, 10), 0)
        self.assertEqual(counter.frame_count, 3)

    def test_setattr_on_compiled_module(self):
        # https://github.com/pytorch/pytorch/issues/114844

        class ReplayMutation(torch.nn.Module):
            def __init__(self, inp_size, out_size, inner_size):
                super().__init__()
                self.Linear1 = torch.nn.Linear(inp_size, inner_size)
                self.Linear2 = torch.nn.Linear(inner_size, out_size)
                self.x = None

            def forward(self, inp):
                res = self.Linear1(inp)
                self.x = res
                return self.Linear2(res)

        N, D_in, H, inner = 2, 2, 2, 4
        model = ReplayMutation(D_in, H, inner)
        model2 = copy.deepcopy(model)
        inp = torch.ones(N, D_in)
        stale_value = [[100, 100, 100, 100], [200, 200, 200, 200]]

        # Keep some intermediate value in model.x
        model.x = torch.tensor(stale_value)
        model(inp)

        # Same setup on the compiled copy: the attribute set through the
        # compiled wrapper must be overwritten by forward as well.
        compiled_model = torch.compile(model2, backend="eager")
        compiled_model.x = torch.tensor(stale_value)
        compiled_model(inp)

        self.assertEqual(model.x, compiled_model.x)

    def test_globals_change_in_other_file(self):
        # Mutates globals of this module and of `test_functions` from inside one
        # full-graph compiled function, then checks the globals of both files.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            # Increments this module's `_variable` and `_variable1` by one each.
            update_global()
            # presumably updates `test_functions._variable` too — that function
            # lives in the other file, verify there
            a = test_functions.update_global(x)
            # Ensure that the updated global values are read
            return x * a * (_variable + _variable1 + test_functions._variable)

        res = fn(torch.ones(10))
        # NOTE(review): the exact-value checks below rely on these globals still
        # being 0 when the test starts.
        self.assertEqual(_variable, 1)
        self.assertEqual(_variable1, 1)
        # Ensure that the reconstructed bytecode updates the global value in the
        # other file.
        self.assertEqual(test_functions._variable, 1)
        self.assertEqual(res, 3 * torch.ones(10))

    @unittest.skipIf(
        "inductor" not in torch._dynamo.list_backends(),
        "inductor backend is not available",
    )
    def test_save_and_load_inductor(self):
        # Round-trips an inductor-compiled module through torch.save/torch.load
        # and checks that the loaded copy matches and still compiles.
        mod = MockModule()
        opt_mod = torch.compile(mod, backend="inductor")
        inp = torch.randn(10, 10)
        opt_mod(inp)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model_path = os.path.join(tmpdirname, "model.pt")
            torch.save(opt_mod, model_path)
            loaded_model = torch.load(model_path)
        loaded_model(inp)
        for reference in (mod, opt_mod):
            self.assertTrue(same_two_models(loaded_model, reference, [inp]))

        torch._dynamo.reset()  # force recompiles
        torch._inductor.metrics.generated_kernel_count = 0
        loaded_model(inp)
        self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0)

    def test_save_and_load_all_backends(self):
        # For every listed backend: a saved-and-reloaded compiled module must
        # agree with the original compiled module on whether running it
        # generates inductor kernels. Backends that fail to compile are skipped.
        mod = MockModule()
        inp = torch.randn(10, 10)

        def runs_without_kernels(module):
            torch._dynamo.reset()  # force recompiles
            torch._inductor.metrics.generated_kernel_count = 0
            module(inp)
            return torch._inductor.metrics.generated_kernel_count == 0

        for backend in torch._dynamo.list_backends():
            try:
                opt_mod = torch.compile(mod, backend=backend)
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model_path = os.path.join(tmpdirname, "model.pt")
                    torch.save(opt_mod, model_path)
                    loaded_model = torch.load(model_path)
                opt_success = runs_without_kernels(opt_mod)
                loaded_success = runs_without_kernels(loaded_model)
                self.assertEqual(opt_success, loaded_success)
            except torch._dynamo.exc.BackendCompilerFailed:
                pass

    def test_monkeypatching_forward(self):
        """Rebinding ``forward`` on a module instance must cause a recompile.

        The scenario is executed twice: once with the current dynamo config
        and once with ``inline_inbuilt_nn_modules=True``.
        """

        class FakeModule(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        class MyModule(torch.nn.Module):
            def __init__(self, x):
                super().__init__()

            def forward(self, x):
                return torch.cos(x)

        def check_patched_forward_recompiles():
            # Start every run from a clean dynamo state.
            torch._dynamo.reset()
            module = MyModule(3)

            def call_module(inp):
                return module(inp)

            counter = torch._dynamo.testing.CompileCounter()
            compiled_call = torch._dynamo.optimize(counter)(call_module)
            sample = torch.randn(10)

            # Two identical calls must be served by a single compiled frame.
            for _ in range(2):
                compiled_call(sample)
            self.assertEqual(counter.frame_count, 1)

            # Swap in FakeModule's forward on this one instance; the eager and
            # compiled paths must agree, and a second frame must be compiled.
            module.forward = types.MethodType(FakeModule.forward, module)
            expected = call_module(sample)
            actual = compiled_call(sample)
            self.assertEqual(expected, actual)
            self.assertEqual(counter.frame_count, 2)

        check_patched_forward_recompiles()
        with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
            check_patched_forward_recompiles()

    def test_user_defined_nn_module_dynamic(self):
        """Each differently-strided user-defined conv module gets its own compilation.

        Three instances of a ``torch.nn.Conv2d`` subclass that differ only in
        ``stride`` are compiled with a shared ``CompileCounter`` backend and
        run on the same input; the counter must report three frames.
        """

        class Conv2d(torch.nn.Conv2d):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def forward(self, x):
                # Call the functional op directly, forwarding the module's
                # own hyper-parameters positionally.
                return torch.nn.functional.conv2d(
                    x,
                    self.weight,
                    self.bias,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.groups,
                )

        counter = torch._dynamo.testing.CompileCounter()

        # Build all modules first, then wrap them, then run them, in that order.
        modules = [
            Conv2d(64, 64, kernel_size=(2, 2), stride=(step, step))
            for step in (1, 2, 3)
        ]
        compiled_modules = [
            torch.compile(module, backend=counter, fullgraph=True)
            for module in modules
        ]

        sample = torch.randn(1, 64, 64, 64)
        for compiled in compiled_modules:
            compiled(sample)

        # Three compilations are required. If the strides were not marked
        # static they would be converted to symints and only two
        # compilations would happen.
        self.assertEqual(counter.frame_count, 3)


if __name__ == "__main__":
    # Executed as a script: hand control to dynamo's test runner, reached
    # through the module that is already imported at the top of this file.
    torch._dynamo.test_case.run_tests()