xref: /aosp_15_r20/external/pytorch/torch/_sources.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport ast
3*da0073e9SAndroid Build Coastguard Workerimport functools
4*da0073e9SAndroid Build Coastguard Workerimport inspect
5*da0073e9SAndroid Build Coastguard Workerfrom textwrap import dedent
6*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, List, NamedTuple, Optional, Tuple
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerfrom torch._C import ErrorReport
9*da0073e9SAndroid Build Coastguard Workerfrom torch._C._jit_tree_views import SourceRangeFactory
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
    """
    Wrapper around inspect.getsourcelines and inspect.getsourcefile.

    Args:
        obj: the object whose source is requested (anything accepted by
            ``inspect.getsourcelines``).
        error_msg: optional extra text appended, on its own line, to the
            error message raised when the source cannot be retrieved.

    Returns: (sourcelines, file_lineno, filename)
        sourcelines: the source lines of ``obj`` from inspect.getsourcelines.
        file_lineno: the line number at which ``sourcelines`` starts.
        filename: the result of inspect.getsourcefile, which may be None.

    Raises:
        OSError: if inspect reports that the source is unavailable. The
            original OSError is chained as ``__cause__``. Other exceptions
            from ``inspect`` (e.g. TypeError for objects without Python
            source, per the inspect docs) are not caught here.
    """
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as e:
        msg = (
            f"Can't get source for {obj}. TorchScript requires source access in "
            "order to carry out compilation, make sure original .py files are "
            "available."
        )
        if error_msg:
            msg += "\n" + error_msg
        raise OSError(msg) from e

    return sourcelines, file_lineno, filename
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
    """
    Re-indent ``sourcelines`` so that no line sits left of the ``def`` line.

    Finds the first line whose stripped text begins with the ``def`` keyword
    and takes that line's leading whitespace as the reference indentation.
    Lines that already start with that whitespace are left unchanged; any
    other line gets it prepended. This lets comments and continued string
    literals that sit at a lower indentation than the function body survive
    a later ``textwrap.dedent``.

    Args:
        sourcelines: function source code split into lines (as returned by
            ``inspect.getsourcelines``).

    Returns:
        The re-aligned list of lines. If no ``def`` line is found (e.g. for a
        lambda), ``sourcelines`` itself is returned unchanged.
    """

    def remove_prefix(text, prefix):
        # Same as str.removeprefix (Python 3.9+).
        return text[len(prefix) :] if text.startswith(prefix) else text

    def is_def_line(line):
        # Match the `def` keyword itself, not identifiers that merely start
        # with "def" (e.g. a `default=...` continuation line of a multi-line
        # decorator that precedes the real `def`).
        stripped = line.lstrip()
        return stripped.startswith("def") and stripped[3:4].isspace()

    # Find the index of the line containing the function definition.
    idx = next((i for i, line in enumerate(sourcelines) if is_def_line(line)), None)

    # This will happen when the function is a lambda: we won't find "def"
    # anywhere in the source lines. Currently trying to JIT compile a lambda
    # will throw an error up in `parse_def()`, but we might want to handle
    # this case in the future.
    if idx is None:
        return sourcelines

    # The leading whitespace of the `def` line is the reference indentation.
    fn_def = sourcelines[idx]
    whitespace = fn_def[: len(fn_def) - len(fn_def.lstrip())]

    def align(line):
        return whitespace + remove_prefix(line, whitespace)

    # Align every line before and after the `def`; keep the `def` line as is.
    return (
        [align(s) for s in sourcelines[:idx]]
        + [fn_def]
        + [align(s) for s in sourcelines[idx + 1 :]]
    )
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker
class SourceContext(SourceRangeFactory):
    """
    Thin wrapper around SourceRangeFactory that also stores extra metadata
    about the function-to-be-compiled.

    Args:
        source: the source text handed to SourceRangeFactory.
        filename: file name handed to SourceRangeFactory (may be None) and
            also kept as the Python attribute ``filename``.
        file_lineno: line number handed to SourceRangeFactory.
        leading_whitespace_len: whitespace length handed to SourceRangeFactory.
        uses_true_division: stored as ``uses_true_division`` (default True).
        funcname: stored as ``funcname`` (default None).
    """

    def __init__(
        self,
        source,
        filename,
        file_lineno,
        leading_whitespace_len,
        uses_true_division=True,
        funcname=None,
    ):
        # Only the first four arguments go to the base-class constructor;
        # the remaining values are plain Python-side attributes.
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        self.funcname = funcname
        self.filename = filename
        self.uses_true_division = uses_true_division
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker
@functools.lru_cache(maxsize=None)
def make_source_context(*ctor_args):
    """
    Construct a SourceContext from positional arguments, memoized on them.

    Calls with equal (hashable) argument tuples return the same SourceContext
    instance. The cache is unbounded (maxsize=None), so every distinct
    argument tuple, including its source text, stays referenced for the
    lifetime of the cache.
    """
    context = SourceContext(*ctor_args)
    return context
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker
def fake_range():
    """Return a placeholder raw range (offsets 0 to 1) over an empty, file-less source."""
    empty_context = SourceContext("", None, 0, 0)
    return empty_context.make_raw_range(0, 1)
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker
class ParsedDef(NamedTuple):
    """Parsed function AST together with its source text and location (as built by parse_def)."""

    # Module from ast.parse on the dedented source.
    ast: ast.Module
    # Source context for the (un-dedented) source text.
    ctx: SourceContext
    # Source text of the function after indentation normalization.
    source: str
    # Path of the source file, or None if it could not be determined.
    filename: Optional[str]
    # Line number in the file at which `source` begins.
    file_lineno: int
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker
def parse_def(fn):
    """
    Parse the source of ``fn`` into a Python AST plus source-location metadata.

    The source lines are fetched with get_source_lines_and_file (which raises
    OSError when they are unavailable), re-aligned by normalize_source_lines,
    joined, dedented, and parsed with ast.parse. Errors from ast.parse (e.g.
    SyntaxError) propagate unchanged.

    Returns:
        ParsedDef(py_ast, ctx, source, filename, file_lineno).

    Raises:
        RuntimeError: if the parsed module does not consist of exactly one
            plain ``def`` statement.
    """
    raw_lines, file_lineno, filename = get_source_lines_and_file(
        fn, ErrorReport.call_stack()
    )
    source = "".join(normalize_source_lines(raw_lines))
    dedented = dedent(source)
    py_ast = ast.parse(dedented)

    body = py_ast.body
    is_single_function = len(body) == 1 and isinstance(body[0], ast.FunctionDef)
    if not is_single_function:
        raise RuntimeError(
            f"Expected a single top-level function: {filename}:{file_lineno}"
        )

    # Amount of indentation dedent() removed, measured on the first line
    # (assumes that line is not whitespace-only, since dedent normalizes those).
    first_line = source.split("\n", 1)[0]
    first_dedented = dedented.split("\n", 1)[0]
    leading_whitespace_len = len(first_line) - len(first_dedented)

    ctx = make_source_context(
        source, filename, file_lineno, leading_whitespace_len, True, fn.__name__
    )
    return ParsedDef(py_ast, ctx, source, filename, file_lineno)
139