xref: /aosp_15_r20/external/pytorch/torch/jit/_dataclass_impls.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Functions for synthesizing magic methods for JIT-compiled dataclasses
3import ast
4import dataclasses
5import inspect
6import os
7from functools import partial
8from typing import Callable, Dict, List
9
10from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional
11from torch._sources import ParsedDef, SourceContext
12
13
14def _get_fake_filename(cls, method_name):
15    return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)
16
17
def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
    """Assemble source text for a synthesized method and parse it into a ParsedDef.

    Args:
        cls: Dataclass the method is being synthesized for; used for the fake
            filename and for error messages.
        name: Name of the generated method (e.g. ``"__init__"``).
        body_lines: Body statements, one per entry, without indentation.
        signature: Full parameter list text, including the parentheses.

    Raises:
        RuntimeError: If the assembled source fails to parse. This indicates a
            bug in the synthesis logic (or an unforeseen change in the
            ``dataclasses`` module), not a user error.
    """
    indented_body = "\n".join("  " + line for line in body_lines)
    decl = f"def {name}{signature}:\n{indented_body}"

    try:
        py_ast = ast.parse(decl)
    except SyntaxError as e:
        # Only reachable if the code we synthesize is malformed; surface it
        # as an internal error pointing at the issue tracker.
        raise RuntimeError(
            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
            "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
        ) from e

    fake_filename = _get_fake_filename(cls, name)
    ctx = SourceContext(
        source=decl, filename=fake_filename, file_lineno=0, leading_whitespace_len=0
    )
    return ParsedDef(
        py_ast,
        ctx=ctx,
        source=decl,
        filename=fake_filename,
        file_lineno=0,
    )
43
44
def synthesize__init__(cls) -> ParsedDef:
    """Generate a TorchScript-compatible ``__init__`` for a dataclass.

    Raises:
        NotImplementedError: If any field uses ``default_factory``; supporting
            that would require compiling lambdas, which TorchScript cannot do.
    """
    has_default_factory = any(
        f.default_factory is not dataclasses.MISSING for f in dataclasses.fields(cls)
    )
    if has_default_factory:
        raise NotImplementedError(
            "Default factory initializers are not supported in TorchScript dataclasses"
        )

    # CPython's generated __init__ already carries the signature we want, with
    # one exception: InitVar annotations must be handled specially.
    signature = inspect.signature(cls.__init__)

    # Unwrap InitVar annotations to their underlying type, since the
    # TorchScript interpreter can't handle InitVar itself. This relies on the
    # `type` attribute InitVar gained in Python 3.8; see CPython commit
    # https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
    init_vars: List[str] = []
    params = []
    for param_name, param in signature.parameters.items():
        annotation = param.annotation
        if isinstance(annotation, dataclasses.InitVar):
            init_vars.append(param_name)
            params.append(param.replace(annotation=annotation.type))  # type: ignore[attr-defined]
        else:
            params.append(param)
    signature = signature.replace(parameters=params)

    # Assign each non-InitVar init field onto self.
    body = []
    for field in dataclasses.fields(cls):
        if field.init and field.name not in init_vars:
            body.append(f"self.{field.name} = {field.name}")

    # Forward the InitVars to the user's __post_init__, if one is defined.
    if hasattr(cls, "__post_init__"):
        body.append("self.__post_init__(" + ", ".join(init_vars) + ")")

    return compose_fn(cls, "__init__", body or ["pass"], signature=str(signature))
87
88
89# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
def synthesize__repr__(cls) -> ParsedDef:
    """Generate a stub ``__repr__`` listing the repr-enabled fields.

    The TorchScript interpreter does not currently invoke ``__repr__``, so
    this only needs to compile; the returned string is a static literal.
    """
    shown_fields = ", ".join(
        f"{field.name}=self.{field.name}"
        for field in dataclasses.fields(cls)
        if field.repr
    )
    return compose_fn(
        cls,
        "__repr__",
        [f"return '{cls.__name__}(" + shown_fields + ")'"],
        signature="(self) -> str",
    )
107
108
def synthesize__hash__(cls) -> ParsedDef:
    """Generate a stub ``__hash__`` that raises if it is ever invoked.

    This exists only so compilation succeeds; the TorchScript interpreter
    does not call custom ``__hash__`` implementations at present.
    """
    stub = "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
    return compose_fn(cls, "__hash__", [stub], signature="(self) -> int")
120
121
122# Implementation for __eq__ and __ne__
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
    """Generate ``__eq__``/``__ne__``: compare each field pair with the
    converse operator, returning False on the first mismatch."""
    inner = [f"if val1 {converse} val2: return False"]
    return synthesize_comparison(
        cls, name, allow_eq=True, raise_on_none=False, inner=inner
    )
131
132
def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
    """Generate an ordering method (``__lt__``/``__le__``/``__gt__``/``__ge__``)
    that compares fields lexicographically: the first strictly ordered pair
    decides the result; ``allow_eq`` decides the all-equal case."""
    inner = [
        f"if val1 {op} val2: return True",
        f"elif val2 {op} val1: return False",
    ]
    return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=inner)
144
145
def synthesize_comparison(
    cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]
) -> ParsedDef:
    """Generate a field-by-field comparison method.

    Args:
        cls: Dataclass being compiled.
        name: Name of the comparison method being generated.
        allow_eq: Result returned when all compared fields are equal.
        raise_on_none: Whether a None/non-None mismatch on an Optional field
            raises (ordering methods) or simply returns False (equality).
        inner: Statements comparing the locals ``val1`` and ``val2``.
    """
    body: List[str] = []
    for field in dataclasses.fields(cls):
        if not field.compare:
            continue

        body.append(f"val1 = self.{field.name}")
        body.append(f"val2 = other.{field.name}")

        if is_optional(field.type):
            # Optional fields need explicit None refinement, otherwise the
            # interpreter reports type errors on the comparison itself.
            none_mismatch = (
                f"  raise TypeError('Cannot compare {cls.__name__} with None')"
                if raise_on_none
                else "  return False"
            )
            body.append("if val1 is not None and val2 is not None:")
            body.extend("  " + line for line in inner)
            body.append("elif (val1 is None) != (val2 is None):")
            body.append(none_mismatch)
        else:
            body.extend(inner)

    # Every compared field was equal; the method's answer is allow_eq.
    body.append(f"return {allow_eq}")
    return compose_fn(
        cls, name, body, signature=f"(self, other: {cls.__name__}) -> bool"
    )
178
179
# Registry mapping each dataclass magic-method name to the synthesizer that
# produces its TorchScript implementation. `partial` pins the per-method
# parameters: the comparison operator (or its converse, for __eq__/__ne__)
# and whether equal values satisfy the comparison (__le__/__ge__ vs.
# __lt__/__gt__).
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
    "__init__": synthesize__init__,
    "__repr__": synthesize__repr__,
    "__hash__": synthesize__hash__,
    "__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
    "__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
    "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
    "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
    "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
    "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
}
191