xref: /aosp_15_r20/external/pytorch/torch/jit/annotations.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import ast
3import builtins
4import dis
5import enum
6import inspect
7import re
8import typing
9import warnings
10from textwrap import dedent
11from typing import Type
12
13import torch
14from torch._C import (
15    _GeneratorType,
16    AnyType,
17    AwaitType,
18    BoolType,
19    ComplexType,
20    DeviceObjType,
21    DictType,
22    EnumType,
23    FloatType,
24    FutureType,
25    InterfaceType,
26    IntType,
27    ListType,
28    NoneType,
29    NumberType,
30    OptionalType,
31    StreamObjType,
32    StringType,
33    TensorType,
34    TupleType,
35    UnionType,
36)
37from torch._jit_internal import (  # type: ignore[attr-defined]
38    _Await,
39    _qualified_name,
40    Any,
41    BroadcastingList1,
42    BroadcastingList2,
43    BroadcastingList3,
44    Dict,
45    Future,
46    is_await,
47    is_dict,
48    is_future,
49    is_ignored_fn,
50    is_list,
51    is_optional,
52    is_tuple,
53    is_union,
54    List,
55    Optional,
56    Tuple,
57    Union,
58)
59from torch._sources import get_source_lines_and_file
60
61from ._state import _get_script_class
62
63
64if torch.distributed.rpc.is_available():
65    from torch._C import RRefType
66    from torch._jit_internal import is_rref, RRef
67
68from torch._ops import OpOverloadPacket
69
70
class Module:
    """Minimal stand-in for a module whose attributes come from a mapping.

    Only the names present in ``members`` can be reached through attribute
    access; everything else is reported as a ``RuntimeError``.
    """

    def __init__(self, name, members):
        # `name` is only used when building the error message below;
        # `members` maps attribute names to the objects they resolve to.
        self.name, self.members = name, members

    def __getattr__(self, name):
        """Return ``members[name]``; raise ``RuntimeError`` for unknown names."""
        registered = self.members
        try:
            return registered[name]
        except KeyError:
            # `from None` suppresses the KeyError so that only the
            # descriptive message is shown to the user.
            raise RuntimeError(
                f"Module {self.name} has no member called {name}"
            ) from None
83
84
class EvalEnv:
    """Name-lookup table used as the namespace when evaluating type comments.

    A name is answered from the shared ``env`` table first. Unknown names are
    delegated to the resolution callback ``rcb`` when one was supplied, and
    are otherwise looked up in ``builtins`` (yielding ``None`` if absent).
    """

    # Class-level table: it is shared by every instance.
    env = {
        "torch": Module("torch", {"Tensor": torch.Tensor}),
        "Tensor": torch.Tensor,
        "typing": Module("typing", {"Tuple": Tuple}),
        "Tuple": Tuple,
        "List": List,
        "Dict": Dict,
        "Optional": Optional,
        "Union": Union,
        "Future": Future,
        "Await": _Await,
    }

    def __init__(self, rcb):
        self.rcb = rcb
        # `RRef` is only imported when the RPC package is available, so it is
        # registered lazily here (in the shared table) instead of in `env`.
        if torch.distributed.rpc.is_available():
            self.env["RRef"] = RRef

    def __getitem__(self, name):
        """Resolve ``name`` via ``env``, then ``rcb``, then ``builtins``."""
        known = self.env
        if name in known:
            return known[name]
        resolver = self.rcb
        if resolver is None:
            # No callback: fall back to builtins, with None for unknown names.
            return getattr(builtins, name, None)
        return resolver(name)
110
111
def get_signature(fn, rcb, loc, is_method):
    """Return ``(param_types, return_type)`` for ``fn``, or ``None``.

    Real annotations are tried first. If there are none, the source of ``fn``
    is searched for a ``# type:`` comment, which is parsed with ``rcb`` as the
    resolution callback. ``None`` is returned when neither yields a signature.
    """
    annotated = fn.op if isinstance(fn, OpOverloadPacket) else fn
    signature = try_real_annotations(annotated, loc)

    if signature is not None:
        if is_method:
            # If this is a method, then the signature will include a type for
            # `self`, but type comments do not contain a `self`. So strip it
            # away here so everything is consistent (`inspect.ismethod` does
            # not work here since `fn` is unbound at this point)
            param_types, return_type = signature
            signature = (param_types[1:], return_type)
        return signature

    type_line = None
    try:
        source = dedent("".join(get_source_lines_and_file(fn)[0]))
        type_line = get_type_line(source)
    except TypeError:
        # Only TypeError is swallowed here; other errors propagate.
        pass

    # `type_line` may be missing both because we failed to get the source of
    # fn, or because it didn't have any annotations.
    if type_line is None:
        return None
    return parse_type_line(type_line, rcb, loc)
139
140
def is_function_or_method(the_callable):
    """Return True for Python functions and bound methods only.

    This is a stricter version of `inspect.isroutine` that does not pass for
    built-in functions.
    """
    if inspect.isfunction(the_callable):
        return True
    return inspect.ismethod(the_callable)
145
146
def is_vararg(the_callable):
    """Return True if ``the_callable`` accepts ``*args``.

    Objects that are callable but are not plain functions/methods are
    inspected through their ``__call__`` attribute. Anything that still is
    not a function or method yields False.
    """
    target = the_callable
    if callable(target) and not is_function_or_method(target):  # noqa: B004
        # If `the_callable` is a class, de-sugar the call so we can still get
        # the signature
        target = target.__call__

    if not is_function_or_method(target):
        return False
    return inspect.getfullargspec(target).varargs is not None
157
158
def get_param_names(fn, n_args):
    """Return the positional parameter names of ``fn``.

    ``OpOverloadPacket`` objects are inspected through their ``op``; callable
    objects through ``__call__`` when that is a function or method. When no
    function/method can be found, the names default to ``"0" .. str(n_args-1)``.
    """
    if isinstance(fn, OpOverloadPacket):
        fn = fn.op

    if not is_function_or_method(fn):
        if callable(fn) and is_function_or_method(fn.__call__):  # noqa: B004
            # De-sugar calls to classes
            fn = fn.__call__
        else:
            # The `fn` was not a method or function (maybe a class with a
            # __call__ method, so use a default param name list)
            return [str(i) for i in range(n_args)]

    # Ignored functions are wrapped; inspect the function underneath.
    inspected = inspect.unwrap(fn) if is_ignored_fn(fn) else fn
    return inspect.getfullargspec(inspected).args
179
180
def check_fn(fn, loc):
    """Verify that the source of ``fn`` is a single top-level function.

    Raises ``torch.jit.frontend.FrontendError`` at ``loc`` when the source is
    a class definition (i.e. a class instantiation inside a script function)
    or anything other than exactly one function definition. Returns silently
    when the source cannot be retrieved.
    """
    try:
        source_lines = get_source_lines_and_file(fn)[0]
        source = dedent("".join(source_lines))
    except (OSError, TypeError):
        return

    top_level = ast.parse(source).body
    # Only a body with exactly one statement can be valid.
    only_stmt = top_level[0] if len(top_level) == 1 else None

    if isinstance(only_stmt, ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc,
            f"Cannot instantiate class '{only_stmt.name}' in a script function",
        )
    if not isinstance(only_stmt, ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Expected a single top-level function"
        )
200
201
202def _eval_no_call(stmt, glob, loc):
203    """Evaluate statement as long as it does not contain any method/function calls."""
204    bytecode = compile(stmt, "", mode="eval")
205    for insn in dis.get_instructions(bytecode):
206        if "CALL" in insn.opname:
207            raise RuntimeError(
208                f"Type annotation should not contain calls, but '{stmt}' does"
209            )
210    return eval(bytecode, glob, loc)  # type: ignore[arg-type] # noqa: P204
211
212
def parse_type_line(type_line, rcb, loc):
    """Convert a signature type comment into argument and return types.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    Returns a pair ``(arg_types, return_type)``; ``rcb`` resolves names that
    are not part of the built-in evaluation environment.
    """
    arg_ann_str, ret_ann_str = split_type_line(type_line)

    def evaluate(expression, failure_message):
        # Evaluate one half of the comment in a fresh environment and turn
        # name/syntax problems into a RuntimeError chained to the cause.
        try:
            return _eval_no_call(expression, {}, EvalEnv(rcb))
        except (NameError, SyntaxError) as e:
            raise RuntimeError(failure_message) from e

    arg_ann = evaluate(
        arg_ann_str, "Failed to parse the argument list of a type annotation"
    )
    ret_ann = evaluate(
        ret_ann_str, "Failed to parse the return type of a type annotation"
    )

    # A single parenthesised argument evaluates to a bare value, not a tuple.
    if not isinstance(arg_ann, tuple):
        arg_ann = (arg_ann,)

    arg_types = [ann_to_type(ann, loc) for ann in arg_ann]
    return arg_types, ann_to_type(ret_ann, loc)
241
242
def get_type_line(source):
    """Try to find the line containing a comment with the type annotation.

    Args:
        source: Source text of a function as a single string.

    Returns:
        ``None`` if ``source`` contains no ``# type:`` comment other than
        ``# type: ignore`` markers. Otherwise the annotation line with
        surrounding whitespace stripped. For the multi-line PEP 484 form (one
        ``# type:`` comment per parameter followed by a
        ``# type: (...) -> ret`` line) the per-parameter types are joined
        with ", " and substituted for the ``...`` placeholder of that line.

    Raises:
        RuntimeError: If there is no valid type line but some line looks like
            a misspelled ``# type:`` prefix, or if several type lines exist
            and none of them is the ``# type: (...) -> `` return line.
    """
    type_comment = "# type:"
    return_marker = "# type: (...) -> "

    lines = list(enumerate(source.split("\n")))

    # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
    # to the hack in torch/_VF.py.

    # An ignore type comment can be of following format:
    #   1) type: ignore
    #   2) type: ignore[rule-code]
    # This ignore statement must be at the end of the line

    # adding an extra backslash before the space, to avoid triggering
    # one of the checks in .github/workflows/lint.yml
    type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$")
    type_lines = [
        (line_num, line)
        for line_num, line in lines
        if type_comment in line and not type_pattern.search(line)
    ]

    if not type_lines:
        # Catch common typo patterns like extra spaces, typo in 'ignore', etc.
        wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
        wrong_type_lines = [
            (line_num, line)
            for line_num, line in lines
            if wrong_type_pattern.search(line)
        ]
        if wrong_type_lines:
            # The reported number is the zero-based index into `source`.
            raise RuntimeError(
                "The annotation prefix in line "
                + str(wrong_type_lines[0][0])
                + " is probably invalid.\nIt must be '# type:'"
                + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"  # noqa: B950
                + "\nfor examples"
            )
        return None
    if len(type_lines) == 1:
        # Only 1 type line, quit now
        return type_lines[0][1].strip()

    # Parse split up argument types according to PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
    return_line = None
    parameter_type_lines = []
    for _, line in type_lines:
        if return_marker in line:
            return_line = line
            break
        # Every type line before the return line annotates one parameter.
        parameter_type_lines.append(line)
    if return_line is None:
        raise RuntimeError(
            "Return type line '# type: (...) -> ...' not found on multiline "
            "type annotation\nfor type lines:\n"
            + "\n".join([line for _, line in type_lines])
            + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"
        )

    def get_parameter_type(line):
        # Everything after the first "# type:" on the line is the type.
        item_type = line[line.find(type_comment) + len(type_comment) :]
        return item_type.strip()

    parameter_types = ", ".join(
        get_parameter_type(line) for line in parameter_type_lines
    )

    # Substitute only the `(...)` placeholder of the marker itself: a blanket
    # replace of "..." would also rewrite an ellipsis inside the return type
    # (e.g. `Tuple[int, ...]`). Strip the result so that it has the same shape
    # as the single-line case above.
    expanded = return_line.replace(
        return_marker, f"# type: ({parameter_types}) -> ", 1
    )
    return expanded.strip()
308
309
def split_type_line(type_line):
    """Split the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]

    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    The ``# type:`` prefix is located inside ``type_line``, so leading
    indentation before the comment does not end up in the argument part.

    Raises:
        RuntimeError: If ``type_line`` contains no ``->``.
    """
    prefix = "# type:"
    # When the prefix is absent, fall back to treating the line as if the
    # comment started at column 0.
    comment_start = max(type_line.find(prefix), 0)
    start_offset = comment_start + len(prefix)
    try:
        arrow_pos = type_line.index("->", comment_start)
    except ValueError:
        raise RuntimeError(
            "Syntax error in type annotation (couldn't find `->`)"
        ) from None
    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2 :].strip()
328
329
def try_real_annotations(fn, loc):
    """Derive ``(arg_types, return_type)`` from the Py3.5+ annotations of ``fn``.

    Returns ``None`` when no signature can be obtained for ``fn`` or when
    neither the parameters nor the return value carry an annotation.
    """
    try:
        # Note: anything annotated as `Optional[T]` will automatically
        # be returned as `Union[T, None]` per
        # https://github.com/python/typing/blob/master/src/typing.py#L850
        sig = inspect.signature(fn)
    except ValueError:
        return None

    # Index 0 holds the return annotation, the rest are the parameters.
    annotations = [sig.return_annotation]
    annotations.extend(p.annotation for p in sig.parameters.values())
    if not any(ann is not sig.empty for ann in annotations):
        return None

    arg_types = [ann_to_type(ann, loc) for ann in annotations[1:]]
    return_type = ann_to_type(annotations[0], loc)
    return arg_types, return_type
349
350
def get_enum_value_type(e: Type[enum.Enum], loc):
    """Find the common type of the values of the Enum class ``e``.

    Args:
        e: The Enum class whose members are inspected.
        loc: Source location forwarded to ``try_ann_to_type``.

    Returns:
        The result of ``torch._C.unify_type_list`` over the types of the
        member values, or ``AnyType`` when that result is falsy (i.e. not all
        values have the same type).

    Raises:
        ValueError: If ``e`` defines no members.
    """
    enum_values: List[enum.Enum] = list(e)
    if not enum_values:
        # Name the Enum class itself; `e.__class__` is only its metaclass and
        # would not tell the user which enum is empty.
        raise ValueError(f"No enum values defined for: '{e}'")

    types = {type(v.value) for v in enum_values}
    ir_types = [try_ann_to_type(t, loc) for t in types]

    # When the value types cannot be unified the result below is falsy and
    # AnyType is returned instead.
    # NOTE(review): an earlier comment here stated that an exception is raised
    # for enums with values of different types; nothing in this function
    # raises for that case -- confirm where (if anywhere) such enums are
    # rejected.
    res = torch._C.unify_type_list(ir_types)
    if not res:
        return AnyType.get()
    return res
369
370
def is_tensor(ann):
    """Return True if the class ``ann`` should be treated as a Tensor.

    Subclasses of ``torch.Tensor`` qualify directly. The dtype-specific
    tensor classes (``torch.LongTensor`` and friends) also qualify, but a
    warning is issued because their dtype is not enforced.
    """
    if issubclass(ann, torch.Tensor):
        return True

    dtype_specific_tensors = (
        torch.LongTensor,
        torch.DoubleTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.CharTensor,
        torch.ByteTensor,
        torch.BoolTensor,
    )
    if not issubclass(ann, dtype_specific_tensors):
        return False

    warnings.warn(
        "TorchScript will treat type annotations of Tensor "
        "dtype-specific subtypes as if they are normal Tensors. "
        "dtype constraints are not enforced in compilation either."
    )
    return True
397
398
def _fake_rcb(inp):
    """Resolution callback that resolves nothing: it returns None for every input.

    Used by ``try_ann_to_type`` as the default when no callback is supplied.
    """
    return None
401
402
def try_ann_to_type(ann, loc, rcb=None):
    """Try to convert the Python annotation ``ann`` into a TorchScript type.

    Args:
        ann: The annotation object (a class, a ``typing`` construct, ``None``
            or ``inspect.Signature.empty``).
        loc: Source location; forwarded to class compilation and type
            resolution and used in error messages.
        rcb: Optional resolution callback handed to
            ``torch._C._resolve_type_from_object`` for annotations that none
            of the explicit cases handles. ``_fake_rcb`` is used when it is
            ``None``. It is not forwarded to the recursive calls made for
            contained types.

    Returns:
        The matching type object. For annotations not handled by an explicit
        case, whatever ``torch._C._resolve_type_from_object`` returns
        (``ann_to_type`` treats ``None`` as "unknown annotation").

    Raises:
        ValueError: If the key or value type of a ``Dict`` annotation cannot
            be resolved.
        AssertionError: If the contained type of an ``Optional`` or a member
            of a ``Union`` cannot be resolved.
    """
    ann_args = typing.get_args(ann)  # always returns a tuple!

    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`
        if len(ann_args) == 1 and ann_args[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann_args])
    if is_list(ann):
        elem_type = try_ann_to_type(ann_args[0], loc)
        # An unresolved element type falls through to the remaining checks.
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann_args[0], loc)
        value = try_ann_to_type(ann_args[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[0]}' at {loc.highlight()}"
            )
        if value is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[1]}' at {loc.highlight()}"
            )
        return DictType(key, value)
    if is_optional(ann):
        # Identity check instead of `issubclass`: the other member may be a
        # typing construct such as `List[int]` (e.g. `Union[None, List[int]]`),
        # which is not a class and makes `issubclass` raise TypeError.
        if ann_args[1] is type(None):
            contained = ann_args[0]
        else:
            contained = ann_args[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
        assert valid_type, msg.format(repr(ann), repr(contained), repr(loc))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is hack to recognize NumberType
        if set(ann_args) == {int, float, complex}:
            return NumberType.get()
        inner: List = []
        # We need these extra checks because both `None` and invalid
        # values will return `None`
        # TODO: Determine if the other cases need to be fixed as well
        for a in ann_args:
            if a is None:
                inner.append(NoneType.get())
                # Already handled; do not resolve (and append) it a second time.
                continue
            maybe_type = try_ann_to_type(a, loc)
            msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
            # Report the member that failed (`a`), not the failed result.
            assert maybe_type, msg.format(repr(ann), repr(a), repr(loc))
            inner.append(maybe_type)
        return UnionType(inner)  # type: ignore[arg-type]
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann_args[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann_args[0], loc))
    if is_await(ann):
        # An unparameterized Await defaults to an element type of Any.
        elementType = try_ann_to_type(ann_args[0], loc) if ann_args else AnyType.get()
        return AwaitType(elementType)
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int or ann is torch.SymInt:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Generator:
        return _GeneratorType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        # Compile the Enum class on first use; otherwise reuse its name.
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type
    if rcb is None:
        rcb = _fake_rcb
    return torch._C._resolve_type_from_object(ann, loc, rcb)
508
509
def ann_to_type(ann, loc, rcb=None):
    """Convert ``ann`` into a type, raising ``ValueError`` if it is unknown.

    Strict counterpart of ``try_ann_to_type``: a ``None`` result from that
    function is reported as an error that includes the highlighted ``loc``.
    """
    resolved = try_ann_to_type(ann, loc, rcb)
    if resolved is None:
        raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
    return resolved
515
516
# Names exported by a wildcard import of this module: the typing helpers and
# type classes re-exported from torch._jit_internal / torch._C, followed by the
# annotation-parsing functions defined in this file.
__all__ = [
    "Any",
    "List",
    "BroadcastingList1",
    "BroadcastingList2",
    "BroadcastingList3",
    "Tuple",
    "is_tuple",
    "is_list",
    "Dict",
    "is_dict",
    "is_optional",
    "is_union",
    "TensorType",
    "TupleType",
    "FloatType",
    "ComplexType",
    "IntType",
    "ListType",
    "StringType",
    "DictType",
    "AnyType",
    "Module",
    # TODO: Consider not exporting these during wildcard import (reserve
    # that for the types; for idiomatic typing code.)
    "get_signature",
    "check_fn",
    "get_param_names",
    "parse_type_line",
    "get_type_line",
    "split_type_line",
    "try_real_annotations",
    "try_ann_to_type",
    "ann_to_type",
]
552