# mypy: allow-untyped-defs
# Functions for synthesizing magic methods for JIT-compiled dataclasses
import ast
import dataclasses
import inspect
import os
from functools import partial
from typing import Callable, Dict, List

from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional
from torch._sources import ParsedDef, SourceContext


def _get_fake_filename(cls, method_name):
    """Return the placeholder "source path" for a synthesized method.

    The path is ``FAKE_FILENAME_PREFIX/<class name>/<method name>``; it only
    labels the generated source and is not read from disk here.
    """
    return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)


def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
    """Build ``def <name><signature>:`` from ``body_lines`` and parse it.

    Args:
        cls: The dataclass the method is synthesized for; its ``__name__`` is
            used in the fake filename and in the error message.
        name: The method name, e.g. ``"__init__"``.
        body_lines: Body statements without indentation; each one is indented
            by four spaces under the ``def`` line.
        signature: Parenthesized parameter list plus optional return
            annotation, e.g. ``"(self) -> str"``.

    Returns:
        A ``ParsedDef`` holding the parsed module AST, a ``SourceContext``,
        the generated source text and the fake filename (line number 0).

    Raises:
        RuntimeError: If the generated source is not valid Python.
    """
    body = "\n".join(f"    {b}" for b in body_lines)
    decl = f"def {name}{signature}:\n{body}"

    # Parse the function declaration
    try:
        py_ast = ast.parse(decl)
    except SyntaxError as e:
        # This should only happen if there's some unforeseeable change
        # in the dataclasses module that makes our synthesized code fail
        raise RuntimeError(
            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
            "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
        ) from e
    fake_filename = _get_fake_filename(cls, name)
    # The generated source has no real file behind it, so it is anchored at
    # line 0 with no leading whitespace.
    return ParsedDef(
        py_ast,
        ctx=SourceContext(
            source=decl, filename=fake_filename, file_lineno=0, leading_whitespace_len=0
        ),
        source=decl,
        filename=fake_filename,
        file_lineno=0,
    )


def synthesize__init__(cls) -> ParsedDef:
    """Synthesize a TorchScript-compatible ``__init__`` for dataclass ``cls``.

    The signature is copied from the ``__init__`` already present on ``cls``,
    with ``InitVar[T]`` annotations unwrapped to ``T``. The body assigns each
    field with ``init=True`` from the argument of the same name and, when
    ``cls`` has a ``__post_init__``, calls it with the InitVar arguments in
    signature order.

    NOTE(review): fields declared with ``init=False`` are not assigned here.

    Raises:
        NotImplementedError: If any field uses ``default_factory``.
    """
    all_fields = dataclasses.fields(cls)

    # Supporting default factories in the way that people expect would sort of require us to
    # allow compiling lambda functions, which is not currently supported.
    if any(field.default_factory is not dataclasses.MISSING for field in all_fields):
        raise NotImplementedError(
            "Default factory initializers are not supported in TorchScript dataclasses"
        )

    # Simply read off the generated __init__ signature from CPython's implementation. It'll be
    # almost correct except for InitVar annotations, which we need to handle specially.
    signature = inspect.signature(cls.__init__)

    # Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
    # see CPython commit here https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
    init_vars: List[str] = []
    params = []
    for name, param in signature.parameters.items():
        ann = param.annotation

        if isinstance(ann, dataclasses.InitVar):
            # The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here
            init_vars.append(name)
            params.append(param.replace(annotation=ann.type))  # type: ignore[attr-defined]
        else:
            params.append(param)

    signature = signature.replace(parameters=params)

    body = [
        # Assign all attributes to self
        f"self.{field.name} = {field.name}"
        for field in all_fields
        if field.init and field.name not in init_vars
    ]
    # Call user's impl of __post_init__ if it exists
    if hasattr(cls, "__post_init__"):
        body.append("self.__post_init__(" + ", ".join(init_vars) + ")")

    # An empty body is not valid Python, hence the `pass` fallback.
    return compose_fn(cls, "__init__", body or ["pass"], signature=str(signature))


# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
def synthesize__repr__(cls) -> ParsedDef:
    """Synthesize a placeholder ``__repr__`` for dataclass ``cls``.

    The generated method returns a constant string such as
    ``'Point(x=self.x, y=self.y)'``, listing every field with ``repr=True``.
    Field values are not interpolated.
    """
    return compose_fn(
        cls,
        "__repr__",
        [
            f"return '{cls.__name__}("
            + ", ".join(
                [
                    f"{field.name}=self.{field.name}"
                    for field in dataclasses.fields(cls)
                    if field.repr
                ]
            )
            + ")'"
        ],
        signature="(self) -> str",
    )


def synthesize__hash__(cls) -> ParsedDef:
    """Synthesize a ``__hash__`` stub for ``cls`` that raises NotImplementedError."""
    return compose_fn(
        cls,
        "__hash__",
        [
            # This is just a placeholder to prevent compilation from failing; this won't even get called at
            # all right now because the TorchScript interpreter doesn't call custom __hash__ implementations
            "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
        ],
        signature="(self) -> int",
    )


# Implementation for __eq__ and __ne__
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
    """Synthesize ``__eq__`` (``converse="!="``) or ``__ne__`` (``converse="=="``).

    ``converse`` is the operator opposite to the one the method implements.
    Both methods walk the ``compare=True`` fields in order and stop at the
    first pair that differs:

    * ``__eq__`` returns False there, and True if nothing differs.
    * ``__ne__`` returns True there, and False if nothing differs.

    This makes ``__ne__`` the exact negation of ``__eq__``. An Optional field
    that is None on exactly one side counts as differing.
    """
    # What the method answers as soon as one field differs:
    # True for __ne__ (converse "=="), False for __eq__ (converse "!=").
    result_on_difference = converse == "=="
    return synthesize_comparison(
        cls,
        name,
        # Reached only when no field differed.
        allow_eq=not result_on_difference,
        raise_on_none=False,
        inner=[f"if val1 != val2: return {result_on_difference}"],
        none_mismatch_result=result_on_difference,
    )


def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
    """Synthesize an ordering method (``__lt__``/``__le__``/``__gt__``/``__ge__``).

    Fields are compared lexicographically with ``op``:

    * The first field where ``val1 op val2`` holds returns True.
    * The first field where ``val2 op val1`` holds returns False.
    * If every field is tied, the method returns ``allow_eq``.

    An Optional field that is None on exactly one side raises ``TypeError``.
    """
    return synthesize_comparison(
        cls,
        name,
        allow_eq,
        raise_on_none=True,
        inner=[
            f"if val1 {op} val2: return True",
            f"elif val2 {op} val1: return False",
        ],
    )


def synthesize_comparison(
    cls,
    name: str,
    allow_eq: bool,
    raise_on_none: bool,
    inner: List[str],
    none_mismatch_result: bool = False,
) -> ParsedDef:
    """Synthesize a binary comparison method ``name(self, other: cls) -> bool``.

    For every field with ``compare=True``, the body binds ``val1``/``val2``
    to the field on ``self``/``other`` and then runs the ``inner`` statements.
    For Optional fields, ``inner`` only runs when both values are not None.
    When exactly one value is None, the method either raises ``TypeError``
    (if ``raise_on_none``) or returns ``none_mismatch_result``. If no
    statement returned, the method returns ``allow_eq``.

    Args:
        cls: The dataclass being compiled.
        name: Method name to generate.
        allow_eq: Value returned when every compared field falls through.
        raise_on_none: Raise ``TypeError`` on a one-sided None instead of
            returning.
        inner: Per-field statements that may ``return`` early.
        none_mismatch_result: Value returned on a one-sided None when
            ``raise_on_none`` is False.

    Returns:
        The parsed method from ``compose_fn``.
    """
    body = []
    for field in dataclasses.fields(cls):
        if not field.compare:
            continue

        body.extend(
            [
                f"val1 = self.{field.name}",
                f"val2 = other.{field.name}",
            ]
        )
        body.extend(
            inner
            if not is_optional(field.type)
            else [
                # Type refinement for optional fields; we need this to avoid type errors from the interpreter
                "if val1 is not None and val2 is not None:",
                *["    " + line for line in inner],
                "elif (val1 is None) != (val2 is None):",
                f"    raise TypeError('Cannot compare {cls.__name__} with None')"
                if raise_on_none
                else f"    return {none_mismatch_result}",
            ]
        )

    body.append(f"return {allow_eq}")
    return compose_fn(
        cls, name, body, signature=f"(self, other: {cls.__name__}) -> bool"
    )


# Maps each dataclass magic-method name to the function that synthesizes its
# TorchScript-compatible source from the dataclass.
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
    "__init__": synthesize__init__,
    "__repr__": synthesize__repr__,
    "__hash__": synthesize__hash__,
    "__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
    "__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
    "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
    "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
    "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
    "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
}