# mypy: allow-untyped-defs
import ast
import functools
import inspect
from textwrap import dedent
from typing import Any, List, NamedTuple, Optional, Tuple

from torch._C import ErrorReport
from torch._C._jit_tree_views import SourceRangeFactory


def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
    """
    Wrapper around inspect.getsourcelines and inspect.getsourcefile.

    Args:
        obj: object whose source code is requested
        error_msg: optional extra text appended (on its own line) to the
            message of the error raised when the source is unavailable

    Returns: (sourcelines, file_lineno, filename)

    Raises:
        OSError: if inspect raises OSError while locating the source; the
            new error is chained from the original one
    """
    filename = None  # in case getsourcefile throws
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as e:
        msg = (
            f"Can't get source for {obj}. TorchScript requires source access in "
            "order to carry out compilation, make sure original .py files are "
            "available."
        )
        if error_msg:
            msg += "\n" + error_msg
        raise OSError(msg) from e

    return sourcelines, file_lineno, filename


def normalize_source_lines(sourcelines: List[str]) -> List[str]:
    """
    This helper function accepts a list of source lines. It finds the
    indentation level of the function definition (`def`), then it indents
    all lines in the function body to a point at or greater than that
    level. This allows for comments and continued string literals that
    are at a lower indentation than the rest of the code.

    Args:
        sourcelines: function source code, separated into lines by
                        the newline character
    Returns:
        A list of source lines that have been correctly aligned
    """

    def remove_prefix(text, prefix):
        # Strip `prefix` only when `text` actually starts with it
        return text[len(prefix) :] if text.startswith(prefix) else text

    def is_def_line(line):
        # `def` must be a whole token: a bare startswith("def") check would
        # also match lines such as `default=1,` inside a multi-line decorator
        # call and pick the wrong line as the function definition.
        tokens = line.split(None, 1)
        return bool(tokens) and tokens[0] == "def"

    # Find the line number of the line containing the function definition
    idx = next((i for i, l in enumerate(sourcelines) if is_def_line(l)), None)

    # This will happen when the function is a lambda- we won't find "def" anywhere in the source
    # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
    # `parse_def()`, but we might want to handle this case in the future.
    if idx is None:
        return sourcelines

    # Get a string representing the amount of leading whitespace
    fn_def = sourcelines[idx]
    whitespace = fn_def[: len(fn_def) - len(fn_def.lstrip())]

    # Add this leading whitespace to all lines before and after the `def`
    aligned_prefix = [
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]
    ]
    aligned_suffix = [
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :]
    ]

    # Put it together again
    return aligned_prefix + [fn_def] + aligned_suffix


# Thin wrapper around SourceRangeFactory to store extra metadata
# about the function-to-be-compiled.
class SourceContext(SourceRangeFactory):
    """
    SourceRangeFactory subclass that also records metadata about the
    function being compiled.

    Attributes set here: uses_true_division, filename, funcname.
    """

    def __init__(
        self,
        source,
        filename,
        file_lineno,
        leading_whitespace_len,
        uses_true_division=True,
        funcname=None,
    ):
        # Only the first four arguments are forwarded to SourceRangeFactory;
        # the remaining ones are kept as plain Python attributes.
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        self.uses_true_division = uses_true_division
        self.filename = filename
        self.funcname = funcname


# Memoized constructor: repeated calls with equal (hashable) positional
# arguments return the same SourceContext instance. The cache is unbounded.
@functools.lru_cache(maxsize=None)
def make_source_context(*args):
    return SourceContext(*args)


def fake_range():
    """
    Return the result of make_raw_range(0, 1) on a SourceContext built from
    an empty source string with no filename.
    """
    return SourceContext("", None, 0, 0).make_raw_range(0, 1)


class ParsedDef(NamedTuple):
    """Bundle of values returned by parse_def()."""

    # Result of ast.parse on the dedented source
    ast: ast.Module
    # SourceContext built for the same source
    ctx: SourceContext
    # Source text before dedenting
    source: str
    # File the source came from, if known
    filename: Optional[str]
    # Line number reported by inspect.getsourcelines for the source
    file_lineno: int
def parse_def(fn):
    """
    Parse the source of `fn` into a Python AST and build its SourceContext.

    Args:
        fn: function whose source is retrieved via get_source_lines_and_file

    Returns:
        ParsedDef(ast, ctx, source, filename, file_lineno), where `source`
        is the normalized (not dedented) source text.

    Raises:
        RuntimeError: if the parsed source is not exactly one top-level
            `ast.FunctionDef`
    """
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        fn, ErrorReport.call_stack()
    )
    sourcelines = normalize_source_lines(sourcelines)
    source = "".join(sourcelines)
    # ast.parse needs the definition to start at column 0
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise RuntimeError(
            f"Expected a single top-level function: {filename}:{file_lineno}"
        )
    # Number of characters dedent() removed from the first line
    leading_whitespace_len = len(source.split("\n", 1)[0]) - len(
        dedent_src.split("\n", 1)[0]
    )
    ctx = make_source_context(
        source, filename, file_lineno, leading_whitespace_len, True, fn.__name__
    )
    return ParsedDef(py_ast, ctx, source, filename, file_lineno)