1# mypy: allow-untyped-defs 2import ast 3import builtins 4import dis 5import enum 6import inspect 7import re 8import typing 9import warnings 10from textwrap import dedent 11from typing import Type 12 13import torch 14from torch._C import ( 15 _GeneratorType, 16 AnyType, 17 AwaitType, 18 BoolType, 19 ComplexType, 20 DeviceObjType, 21 DictType, 22 EnumType, 23 FloatType, 24 FutureType, 25 InterfaceType, 26 IntType, 27 ListType, 28 NoneType, 29 NumberType, 30 OptionalType, 31 StreamObjType, 32 StringType, 33 TensorType, 34 TupleType, 35 UnionType, 36) 37from torch._jit_internal import ( # type: ignore[attr-defined] 38 _Await, 39 _qualified_name, 40 Any, 41 BroadcastingList1, 42 BroadcastingList2, 43 BroadcastingList3, 44 Dict, 45 Future, 46 is_await, 47 is_dict, 48 is_future, 49 is_ignored_fn, 50 is_list, 51 is_optional, 52 is_tuple, 53 is_union, 54 List, 55 Optional, 56 Tuple, 57 Union, 58) 59from torch._sources import get_source_lines_and_file 60 61from ._state import _get_script_class 62 63 64if torch.distributed.rpc.is_available(): 65 from torch._C import RRefType 66 from torch._jit_internal import is_rref, RRef 67 68from torch._ops import OpOverloadPacket 69 70 71class Module: 72 def __init__(self, name, members): 73 self.name = name 74 self.members = members 75 76 def __getattr__(self, name): 77 try: 78 return self.members[name] 79 except KeyError: 80 raise RuntimeError( 81 f"Module {self.name} has no member called {name}" 82 ) from None 83 84 85class EvalEnv: 86 env = { 87 "torch": Module("torch", {"Tensor": torch.Tensor}), 88 "Tensor": torch.Tensor, 89 "typing": Module("typing", {"Tuple": Tuple}), 90 "Tuple": Tuple, 91 "List": List, 92 "Dict": Dict, 93 "Optional": Optional, 94 "Union": Union, 95 "Future": Future, 96 "Await": _Await, 97 } 98 99 def __init__(self, rcb): 100 self.rcb = rcb 101 if torch.distributed.rpc.is_available(): 102 self.env["RRef"] = RRef 103 104 def __getitem__(self, name): 105 if name in self.env: 106 return self.env[name] 107 if self.rcb is not None: 108 return self.rcb(name) 109 return getattr(builtins, name, None) 110 111 112def get_signature(fn, rcb, loc, is_method): 113 if isinstance(fn, OpOverloadPacket): 114 signature = try_real_annotations(fn.op, loc) 115 else: 116 signature = try_real_annotations(fn, loc) 117 if signature is not None and is_method: 118 # If this is a method, then the signature will include a type for 119 # `self`, but type comments do not contain a `self`. So strip it 120 # away here so everything is consistent (`inspect.ismethod` does 121 # not work here since `fn` is unbound at this point) 122 param_types, return_type = signature 123 param_types = param_types[1:] 124 signature = (param_types, return_type) 125 126 if signature is None: 127 type_line, source = None, None 128 try: 129 source = dedent("".join(get_source_lines_and_file(fn)[0])) 130 type_line = get_type_line(source) 131 except TypeError: 132 pass 133 # This might happen both because we failed to get the source of fn, or 134 # because it didn't have any annotations. 135 if type_line is not None: 136 signature = parse_type_line(type_line, rcb, loc) 137 138 return signature 139 140 141def is_function_or_method(the_callable): 142 # A stricter version of `inspect.isroutine` that does not pass for built-in 143 # functions 144 return inspect.isfunction(the_callable) or inspect.ismethod(the_callable) 145 146 147def is_vararg(the_callable): 148 if not is_function_or_method(the_callable) and callable(the_callable): # noqa: B004 149 # If `the_callable` is a class, de-sugar the call so we can still get 150 # the signature 151 the_callable = the_callable.__call__ 152 153 if is_function_or_method(the_callable): 154 return inspect.getfullargspec(the_callable).varargs is not None 155 else: 156 return False 157 158 159def get_param_names(fn, n_args): 160 if isinstance(fn, OpOverloadPacket): 161 fn = fn.op 162 163 if ( 164 not is_function_or_method(fn) 165 and callable(fn) 166 and is_function_or_method(fn.__call__) 167 ): # noqa: B004 168 # De-sugar calls to classes 169 fn = fn.__call__ 170 171 if is_function_or_method(fn): 172 if is_ignored_fn(fn): 173 fn = inspect.unwrap(fn) 174 return inspect.getfullargspec(fn).args 175 else: 176 # The `fn` was not a method or function (maybe a class with a __call__ 177 # method, so use a default param name list) 178 return [str(i) for i in range(n_args)] 179 180 181def check_fn(fn, loc): 182 # Make sure the function definition is not a class instantiation 183 try: 184 source = dedent("".join(get_source_lines_and_file(fn)[0])) 185 except (OSError, TypeError): 186 return 187 if source is None: 188 return 189 190 py_ast = ast.parse(source) 191 if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef): 192 raise torch.jit.frontend.FrontendError( 193 loc, 194 f"Cannot instantiate class '{py_ast.body[0].name}' in a script function", 195 ) 196 if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): 197 raise torch.jit.frontend.FrontendError( 198 loc, "Expected a single top-level function" 199 ) 200 201 202def _eval_no_call(stmt, glob, loc): 203 """Evaluate statement as long as it does not contain any method/function calls.""" 204 bytecode = compile(stmt, "", mode="eval") 205 for insn in dis.get_instructions(bytecode): 206 if "CALL" in insn.opname: 207 raise RuntimeError( 208 f"Type annotation should not contain calls, but '{stmt}' does" 209 ) 210 return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204 211 212 213def parse_type_line(type_line, rcb, loc): 214 """Parse a type annotation specified as a comment. 215 216 Example inputs: 217 # type: (Tensor, torch.Tensor) -> Tuple[Tensor] 218 # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor 219 """ 220 arg_ann_str, ret_ann_str = split_type_line(type_line) 221 222 try: 223 arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb)) 224 except (NameError, SyntaxError) as e: 225 raise RuntimeError( 226 "Failed to parse the argument list of a type annotation" 227 ) from e 228 229 if not isinstance(arg_ann, tuple): 230 arg_ann = (arg_ann,) 231 232 try: 233 ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb)) 234 except (NameError, SyntaxError) as e: 235 raise RuntimeError( 236 "Failed to parse the return type of a type annotation" 237 ) from e 238 239 arg_types = [ann_to_type(ann, loc) for ann in arg_ann] 240 return arg_types, ann_to_type(ret_ann, loc) 241 242 243def get_type_line(source): 244 """Try to find the line containing a comment with the type annotation.""" 245 type_comment = "# type:" 246 247 lines = source.split("\n") 248 lines = list(enumerate(lines)) 249 type_lines = list(filter(lambda line: type_comment in line[1], lines)) 250 # `type: ignore` comments may be needed in JIT'ed functions for mypy, due 251 # to the hack in torch/_VF.py. 252 253 # An ignore type comment can be of following format: 254 # 1) type: ignore 255 # 2) type: ignore[rule-code] 256 # This ignore statement must be at the end of the line 257 258 # adding an extra backslash before the space, to avoid triggering 259 # one of the checks in .github/workflows/lint.yml 260 type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$") 261 type_lines = list(filter(lambda line: not type_pattern.search(line[1]), type_lines)) 262 263 if len(type_lines) == 0: 264 # Catch common typo patterns like extra spaces, typo in 'ignore', etc. 265 wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):") 266 wrong_type_lines = list( 267 filter(lambda line: wrong_type_pattern.search(line[1]), lines) 268 ) 269 if len(wrong_type_lines) > 0: 270 raise RuntimeError( 271 "The annotation prefix in line " 272 + str(wrong_type_lines[0][0]) 273 + " is probably invalid.\nIt must be '# type:'" 274 + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" # noqa: B950 275 + "\nfor examples" 276 ) 277 return None 278 elif len(type_lines) == 1: 279 # Only 1 type line, quit now 280 return type_lines[0][1].strip() 281 282 # Parse split up argument types according to PEP 484 283 # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code 284 return_line = None 285 parameter_type_lines = [] 286 for line_num, line in type_lines: 287 if "# type: (...) -> " in line: 288 return_line = (line_num, line) 289 break 290 elif type_comment in line: 291 parameter_type_lines.append(line) 292 if return_line is None: 293 raise RuntimeError( 294 "Return type line '# type: (...) -> ...' not found on multiline " 295 "type annotation\nfor type lines:\n" 296 + "\n".join([line[1] for line in type_lines]) 297 + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" 298 ) 299 300 def get_parameter_type(line): 301 item_type = line[line.find(type_comment) + len(type_comment) :] 302 return item_type.strip() 303 304 types = map(get_parameter_type, parameter_type_lines) 305 parameter_types = ", ".join(types) 306 307 return return_line[1].replace("...", parameter_types) 308 309 310def split_type_line(type_line): 311 """Split the comment with the type annotation into parts for argument and return types. 312 313 For example, for an input of: 314 # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor] 315 316 This function will return: 317 ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]") 318 319 """ 320 start_offset = len("# type:") 321 try: 322 arrow_pos = type_line.index("->") 323 except ValueError: 324 raise RuntimeError( 325 "Syntax error in type annotation (couldn't find `->`)" 326 ) from None 327 return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2 :].strip() 328 329 330def try_real_annotations(fn, loc): 331 """Try to use the Py3.5+ annotation syntax to get the type.""" 332 try: 333 # Note: anything annotated as `Optional[T]` will automatically 334 # be returned as `Union[T, None]` per 335 # https://github.com/python/typing/blob/master/src/typing.py#L850 336 sig = inspect.signature(fn) 337 except ValueError: 338 return None 339 340 all_annots = [sig.return_annotation] + [ 341 p.annotation for p in sig.parameters.values() 342 ] 343 if all(ann is sig.empty for ann in all_annots): 344 return None 345 346 arg_types = [ann_to_type(p.annotation, loc) for p in sig.parameters.values()] 347 return_type = ann_to_type(sig.return_annotation, loc) 348 return arg_types, return_type 349 350 351# Finds common type for enum values belonging to an Enum class. If not all 352# values have the same type, AnyType is returned. 353def get_enum_value_type(e: Type[enum.Enum], loc): 354 enum_values: List[enum.Enum] = list(e) 355 if not enum_values: 356 raise ValueError(f"No enum values defined for: '{e.__class__}'") 357 358 types = {type(v.value) for v in enum_values} 359 ir_types = [try_ann_to_type(t, loc) for t in types] 360 361 # If Enum values are of different types, an exception will be raised here. 362 # Even though Python supports this case, we chose to not implement it to 363 # avoid overcomplicate logic here for a rare use case. Please report a 364 # feature request if you find it necessary. 365 res = torch._C.unify_type_list(ir_types) 366 if not res: 367 return AnyType.get() 368 return res 369 370 371def is_tensor(ann): 372 if issubclass(ann, torch.Tensor): 373 return True 374 375 if issubclass( 376 ann, 377 ( 378 torch.LongTensor, 379 torch.DoubleTensor, 380 torch.FloatTensor, 381 torch.IntTensor, 382 torch.ShortTensor, 383 torch.HalfTensor, 384 torch.CharTensor, 385 torch.ByteTensor, 386 torch.BoolTensor, 387 ), 388 ): 389 warnings.warn( 390 "TorchScript will treat type annotations of Tensor " 391 "dtype-specific subtypes as if they are normal Tensors. " 392 "dtype constraints are not enforced in compilation either." 393 ) 394 return True 395 396 return False 397 398 399def _fake_rcb(inp): 400 return None 401 402 403def try_ann_to_type(ann, loc, rcb=None): 404 ann_args = typing.get_args(ann) # always returns a tuple! 405 406 if ann is inspect.Signature.empty: 407 return TensorType.getInferred() 408 if ann is None: 409 return NoneType.get() 410 if inspect.isclass(ann) and is_tensor(ann): 411 return TensorType.get() 412 if is_tuple(ann): 413 # Special case for the empty Tuple type annotation `Tuple[()]` 414 if len(ann_args) == 1 and ann_args[0] == (): 415 return TupleType([]) 416 return TupleType([try_ann_to_type(a, loc) for a in ann_args]) 417 if is_list(ann): 418 elem_type = try_ann_to_type(ann_args[0], loc) 419 if elem_type: 420 return ListType(elem_type) 421 if is_dict(ann): 422 key = try_ann_to_type(ann_args[0], loc) 423 value = try_ann_to_type(ann_args[1], loc) 424 # Raise error if key or value is None 425 if key is None: 426 raise ValueError( 427 f"Unknown type annotation: '{ann_args[0]}' at {loc.highlight()}" 428 ) 429 if value is None: 430 raise ValueError( 431 f"Unknown type annotation: '{ann_args[1]}' at {loc.highlight()}" 432 ) 433 return DictType(key, value) 434 if is_optional(ann): 435 if issubclass(ann_args[1], type(None)): 436 contained = ann_args[0] 437 else: 438 contained = ann_args[1] 439 valid_type = try_ann_to_type(contained, loc) 440 msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}" 441 assert valid_type, msg.format(repr(ann), repr(contained), repr(loc)) 442 return OptionalType(valid_type) 443 if is_union(ann): 444 # TODO: this is hack to recognize NumberType 445 if set(ann_args) == {int, float, complex}: 446 return NumberType.get() 447 inner: List = [] 448 # We need these extra checks because both `None` and invalid 449 # values will return `None` 450 # TODO: Determine if the other cases need to be fixed as well 451 for a in typing.get_args(ann): 452 if a is None: 453 inner.append(NoneType.get()) 454 maybe_type = try_ann_to_type(a, loc) 455 msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}" 456 assert maybe_type, msg.format(repr(ann), repr(maybe_type), repr(loc)) 457 inner.append(maybe_type) 458 return UnionType(inner) # type: ignore[arg-type] 459 if torch.distributed.rpc.is_available() and is_rref(ann): 460 return RRefType(try_ann_to_type(ann_args[0], loc)) 461 if is_future(ann): 462 return FutureType(try_ann_to_type(ann_args[0], loc)) 463 if is_await(ann): 464 elementType = try_ann_to_type(ann_args[0], loc) if ann_args else AnyType.get() 465 return AwaitType(elementType) 466 if ann is float: 467 return FloatType.get() 468 if ann is complex: 469 return ComplexType.get() 470 if ann is int or ann is torch.SymInt: 471 return IntType.get() 472 if ann is str: 473 return StringType.get() 474 if ann is bool: 475 return BoolType.get() 476 if ann is Any: 477 return AnyType.get() 478 if ann is type(None): 479 return NoneType.get() 480 if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"): 481 return InterfaceType(ann.__torch_script_interface__) 482 if ann is torch.device: 483 return DeviceObjType.get() 484 if ann is torch.Generator: 485 return _GeneratorType.get() 486 if ann is torch.Stream: 487 return StreamObjType.get() 488 if ann is torch.dtype: 489 return IntType.get() # dtype not yet bound in as its own type 490 if inspect.isclass(ann) and issubclass(ann, enum.Enum): 491 if _get_script_class(ann) is None: 492 scripted_class = torch.jit._script._recursive_compile_class(ann, loc) 493 name = scripted_class.qualified_name() 494 else: 495 name = _qualified_name(ann) 496 return EnumType(name, get_enum_value_type(ann, loc), list(ann)) 497 if inspect.isclass(ann): 498 maybe_script_class = _get_script_class(ann) 499 if maybe_script_class is not None: 500 return maybe_script_class 501 if torch._jit_internal.can_compile_class(ann): 502 return torch.jit._script._recursive_compile_class(ann, loc) 503 504 # Maybe resolve a NamedTuple to a Tuple Type 505 if rcb is None: 506 rcb = _fake_rcb 507 return torch._C._resolve_type_from_object(ann, loc, rcb) 508 509 510def ann_to_type(ann, loc, rcb=None): 511 the_type = try_ann_to_type(ann, loc, rcb) 512 if the_type is not None: 513 return the_type 514 raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}") 515 516 517__all__ = [ 518 "Any", 519 "List", 520 "BroadcastingList1", 521 "BroadcastingList2", 522 "BroadcastingList3", 523 "Tuple", 524 "is_tuple", 525 "is_list", 526 "Dict", 527 "is_dict", 528 "is_optional", 529 "is_union", 530 "TensorType", 531 "TupleType", 532 "FloatType", 533 "ComplexType", 534 "IntType", 535 "ListType", 536 "StringType", 537 "DictType", 538 "AnyType", 539 "Module", 540 # TODO: Consider not exporting these during wildcard import (reserve 541 # that for the types; for idiomatic typing code.) 542 "get_signature", 543 "check_fn", 544 "get_param_names", 545 "parse_type_line", 546 "get_type_line", 547 "split_type_line", 548 "try_real_annotations", 549 "try_ann_to_type", 550 "ann_to_type", 551] 552