# mypy: ignore-errors

import enum
import dis
import copy
import sys
import torch
import inspect
import operator
import collections
import logging

from dataclasses import is_dataclass, fields


from .graph import magic_methods, reflectable_magic_methods, Graph
from torch.utils._traceback import CapturedTraceback
from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable
from .node import Target, Node, Argument, base_types, map_aggregate
from ._compatibility import compatibility
from .operator_schemas import check_for_mutable_operation
import torch.fx.traceback as fx_traceback

__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
           'Proxy', 'Attribute', 'ParameterProxy', 'Scope',
           'ScopeContextManager']


log = logging.getLogger(__name__)


@compatibility(is_backward_compatible=False)
class Scope:
    """ Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in a Graph of GraphModule. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # This will be a call_method Node in GraphModule,
                # scope for this would be (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    Args:
        module_path: Qualified path of the module (``""`` for the root).
        module_type: The module's type, or ``None``.
    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type


@compatibility(is_backward_compatible=False)
class ScopeContextManager:
    """ A context manager to track the Scope of Node during symbolic tracing.
    When entering a forward function of a Module, we'll update the scope information of
    the current module, and when we exit, we'll restore the previous scope information.

    Note that ``scope`` is mutated in place as soon as this object is
    *constructed* (not in ``__enter__``); ``__exit__`` copies the saved
    path/type back. Construct it directly in the ``with`` statement so the
    restore is guaranteed to run.

    Args:
        scope: The tracer's live Scope object; mutated in place.
        current_scope: Scope whose path/type are copied into ``scope``.
    """

    def __init__(
        self,
        scope: Scope,
        current_scope: Scope,
    ):
        super().__init__()
        # Keep a shallow copy of the previous scope to restore on exit
        self._prev_scope = copy.copy(scope)
        # Update scope to current scope (in place, so every holder of the
        # same Scope object observes the change)
        scope.module_path = current_scope.module_path
        scope.module_type = current_scope.module_type
        # Save a reference so we can restore it
        self._scope = scope

    def __enter__(self):
        return self._scope

    def __exit__(self, *args):
        # Restore the saved path/type; returning None lets exceptions propagate.
        self._scope.module_path = self._prev_scope.module_path
        self._scope.module_type = self._prev_scope.module_type
        return


# Keys of the preserved node meta (fx_traceback.get_current_meta()) that
# TracerBase.create_node shallow-copies onto every newly created node.
_COPY_META_FIELDS = [
    "nn_module_stack",
    "torch_fn",
    "source_fn_stack",
    "original_aten",
    "recompute",
    "ac_graph_id",
    "from_node",
    "quantization_tag",  # TODO deprecated
    "_numeric_debug_handle",  # TODO deprecated
    "custom",
    "partitioner_tag"
]


@compatibility(is_backward_compatible=True)
class TracerBase:
    graph: Graph
    record_stack_traces : bool = False
    # Feature flag for mutable schema checking (currently off by default;
    # an older note planned to enable it by default in 1.12)
    check_mutable_operations : bool = False
    # Feature flag for assert tracing
    trace_asserts : bool = False
    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes : bool = False

    # Name of the function to be traced. It will only be used when
    # ``root`` is an instance of ``nn.Module``
    traced_func_name: str = "forward"

    # Scope (module path and module type) of the module currently being
    # traced; create_node records it for every node it creates.
    scope : Scope

    # Records the module call stack
    module_stack: OrderedDict[str, Tuple[str, Any]]

    # Mapping of node name to module scope
    node_name_to_scope: Dict[str, Tuple[str, type]]

    @compatibility(is_backward_compatible=True)
    def create_node(self, kind : str, target : Target,
                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
                    type_expr : Optional[Any] = None) -> Node:
        """
        Inserts a graph node given target, args, kwargs, and name.

        This method can be overridden to do extra checking, validation, or
        modification of values used in node creation. For example, one might
        want to disallow in-place operations from being recorded.

        Besides creating the node in ``self.graph``, this records the current
        scope in ``node_name_to_scope`` and copies selected preserved meta
        fields (or, failing that, ``module_stack``) onto ``node.meta``.

        Returns:
            The newly created Node.
        """

        if kind == 'call_function' and self.check_mutable_operations:
            check_for_mutable_operation(target, args, kwargs)

        node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
        # TODO node_name_to_scope will be deprecated in favor of
        # node.meta['nn_module_stack']
        self.node_name_to_scope[node.name] = (
            self.scope.module_path,
            self.scope.module_type,
        )
        # Optionally set stack trace on the created Node for debugging purposes
        if fx_traceback.has_preserved_node_meta():
            current_meta: Dict[str, Any] = fx_traceback.get_current_meta()

            stack_trace = current_meta.get("stack_trace")
            if stack_trace:
                node.stack_trace = stack_trace
            # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta
            # If other meta fields are needed, they can be added here
            for field in _COPY_META_FIELDS:
                if field in current_meta:
                    node.meta[field] = copy.copy(current_meta[field])

            # Here we decrement to account for the sequence_nr having
            # just been incremented while tracing this lowered aten op.
            new_seq_nr = torch.autograd._get_sequence_nr() - 1
            # The sequence_nr increments every time a new autograd Node
            # is created. During the FWD pass we store the sequence_nr
            # corresponding to the last autograd Node created on this fx
            # node's meta.  A single aten op can create multiple autograd
            # nodes as is the case with in-place foreach ops. During the
            # BWD pass we retrieve the sequence_nr stored on the current
            # executing autograd Node. See NOTE [ Sequence Number ].
            if current_meta.get("in_grad_fn", 0) > 0:
                new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
            node.meta["seq_nr"] = new_seq_nr

        elif self.module_stack:
            node.meta['nn_module_stack'] = copy.copy(self.module_stack)

        log.debug("create_node %s", node)
        return node

    @compatibility(is_backward_compatible=True)
    def proxy(self, node: Node) -> 'Proxy':
        """Wrap ``node`` in a ``Proxy`` bound to this tracer."""
        return Proxy(node, self)

    @compatibility(is_backward_compatible=True)
    def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
                     name: Optional[str] = None, type_expr : Optional[Any] = None,
                     proxy_factory_fn: Optional[Callable[[Node], 'Proxy']] = None):
        '''
        Create a Node from the given arguments, then return the Node
        wrapped in a Proxy object.

        If kind = 'placeholder', then we're creating a Node that
        represents the parameter of a function. If we need to encode
        a default parameter, we use the ``args`` tuple. ``args`` is
        otherwise empty for ``placeholder`` Nodes.

        Args:
            proxy_factory_fn: Optional callable that wraps the new Node;
                when not given, ``self.proxy`` is used.
        '''

        # Lower user-level values (Proxies, containers, ...) into IR arguments.
        args_ = self.create_arg(args)
        kwargs_ = self.create_arg(kwargs)
        assert isinstance(args_, tuple)
        assert isinstance(kwargs_, dict)

        node = self.create_node(kind, target, args_, kwargs_, name, type_expr)

        if not proxy_factory_fn:
            proxy = self.proxy(node)
        else:
            proxy = proxy_factory_fn(node)

        # Only capture a Python traceback if create_node did not already set one.
        if self.record_stack_traces and not proxy.node.stack_trace:
            proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format())

        return proxy

    def _find_user_frame(self):
        """
        Find the Python stack frame executing the user code during
        symbolic tracing.

        Returns:
            The first caller frame whose file is not one of the listed
            PyTorch internals, or ``None`` if the stack is exhausted.
        """
        # We have to do a little dance here. Basically, walk up the callstack and
        # record the first frame not in the pytorch source. This is the frame executing
        # the user code during tracing.
        frame = inspect.currentframe()

        pt_files = ['torch/fx/proxy.py',
                    'torch/fx/_symbolic_trace.py',
                    'torch/fx/experimental/proxy_tensor.py',
                    'torch/_ops.py',
                    'torch/_tensor.py',
                    'torch/utils/_python_dispatch.py',
                    'torch/_prims_common/wrappers.py',
                    'torch/_refs/__init__.py',
                    'torch/_refs/nn/functional/__init__.py',
                    'torch/utils/_stats.py',
                    ]
        while frame:
            frame = frame.f_back
            if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files):
                break

        # Either a matching frame or None when we walked off the top of the stack.
        return frame

    @compatibility(is_backward_compatible=True)
    def create_arg(self, a: Any) -> Argument:
        """
        A method that lowers the objects seen as arguments during symbolic evaluation
        into Argument types that can be stored in IR.

        Can be overridden to support more trace-specific types.

        Raises:
            RuntimeError: if a dict key contains a Node.
            NotImplementedError: if ``a`` has an unsupported type.
        """
        if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
            return a.__fx_create_arg__(self)
        # aggregates
        elif isinstance(a, tuple) and hasattr(a, '_fields'):
            # NamedTuple constructors don't seem to like getting a generator
            # expression as an argument to their constructor, so build this
            # intermediate tuple and unpack it into the NamedTuple constructor
            args = tuple(self.create_arg(elem) for elem in a)
            return type(a)(*args)  # type: ignore[arg-type]
        elif isinstance(a, (tuple, list)):
            return type(a)(self.create_arg(elem) for elem in a)
        elif isinstance(a, dict):
            r = {}
            for k, v in a.items():
                # Check for invalid dict keys. We do not want a Proxy to appear
                # anywhere within the key. Since keys can be collection types,
                # we iterate through the key with map_aggregate
                k = self.create_arg(k)

                def no_node(arg):
                    if isinstance(arg, Node):
                        raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
                                           f"Node. Got key: {k}")
                map_aggregate(k, no_node)

                r[k] = self.create_arg(v)
            return r
        elif isinstance(a, slice):
            return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        elif isinstance(a, range):
            return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
            return a

        if isinstance(a, Proxy):
            # base case: we unwrap the Proxy object
            return a.node

        # is_dataclass() is also True for dataclass *types*; only instances can
        # be rebuilt as a constructor call, so types fall through to the error.
        if is_dataclass(a) and not isinstance(a, type):
            kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)}
            return self.create_node("call_function", a.__class__, (), kwargs)

        elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
            return a
        raise NotImplementedError(f"argument of type: {type(a)}")

    @compatibility(is_backward_compatible=True)
    def to_bool(self, obj: 'Proxy') -> bool:
        """Called when a proxy object is being converted to a boolean, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return a value.
        """
        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')

    @compatibility(is_backward_compatible=True)
    def iter(self, obj: 'Proxy') -> Iterator:
        """Called when a proxy object is being iterated over, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return an iterator.
        """
        raise TraceError('Proxy object cannot be iterated. This can be '
                         'attempted when the Proxy is used in a loop or'
                         ' as a *args or **kwargs function argument. '
                         'See the torch.fx docs on pytorch.org for a '
                         'more detailed explanation of what types of '
                         'control flow can be traced, and check out the'
                         ' Proxy docstring for help troubleshooting '
                         'Proxy iteration errors')

    @compatibility(is_backward_compatible=True)
    def keys(self, obj: 'Proxy') -> Any:
        """Called when the keys() method is called on a proxy object.
        This is what happens when ``**`` is applied to a proxy. Override it to
        return an iterator if ``**`` is supposed to work in your custom tracer;
        by default it records a ``keys`` method call on the proxy.
        """
        return Attribute(obj, 'keys')()


# used in Proxy object when just appending to the graph while not tracing.
346*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True) 347*da0073e9SAndroid Build Coastguard Workerclass GraphAppendingTracer(TracerBase): 348*da0073e9SAndroid Build Coastguard Worker def __init__(self, graph: Graph): 349*da0073e9SAndroid Build Coastguard Worker super().__init__() 350*da0073e9SAndroid Build Coastguard Worker self.graph = graph 351*da0073e9SAndroid Build Coastguard Worker self.scope = Scope("", None) 352*da0073e9SAndroid Build Coastguard Worker self.module_stack = collections.OrderedDict() 353*da0073e9SAndroid Build Coastguard Worker self.node_name_to_scope = {} 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=False) 356*da0073e9SAndroid Build Coastguard Workerdef assert_fn(x): 357*da0073e9SAndroid Build Coastguard Worker assert x 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True) 360*da0073e9SAndroid Build Coastguard Workerclass TraceError(ValueError): 361*da0073e9SAndroid Build Coastguard Worker pass 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True) 364*da0073e9SAndroid Build Coastguard Workerclass Proxy: 365*da0073e9SAndroid Build Coastguard Worker """ 366*da0073e9SAndroid Build Coastguard Worker ``Proxy`` objects are ``Node`` wrappers that flow through the 367*da0073e9SAndroid Build Coastguard Worker program during symbolic tracing and record all the operations 368*da0073e9SAndroid Build Coastguard Worker (``torch`` function calls, method calls, operators) that they touch 369*da0073e9SAndroid Build Coastguard Worker into the growing FX Graph. 
370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker If you're doing graph transforms, you can wrap your own ``Proxy`` 372*da0073e9SAndroid Build Coastguard Worker method around a raw ``Node`` so that you can use the overloaded 373*da0073e9SAndroid Build Coastguard Worker operators to add additional things to a ``Graph``. 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker ``Proxy`` objects cannot be iterated. In other words, the symbolic 376*da0073e9SAndroid Build Coastguard Worker tracer will throw an error if a ``Proxy`` is used in a loop or as 377*da0073e9SAndroid Build Coastguard Worker an ``*args``/``**kwargs`` function argument. 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker There are two main ways around this: 380*da0073e9SAndroid Build Coastguard Worker 1. Factor out the untraceable logic into a top-level function and 381*da0073e9SAndroid Build Coastguard Worker use ``fx.wrap`` on it. 382*da0073e9SAndroid Build Coastguard Worker 2. If the control flow is static (i.e. 
the loop trip count is 383*da0073e9SAndroid Build Coastguard Worker based on some hyperparameter), the code can be kept in its original 384*da0073e9SAndroid Build Coastguard Worker position and refactored into something like:: 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker for i in range(self.some_hyperparameter): 387*da0073e9SAndroid Build Coastguard Worker indexed_item = proxied_value[i] 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker For a more detailed description into the Proxy internals, check out 390*da0073e9SAndroid Build Coastguard Worker the "Proxy" section in `torch/fx/README.md` 391*da0073e9SAndroid Build Coastguard Worker """ 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 394*da0073e9SAndroid Build Coastguard Worker def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): 395*da0073e9SAndroid Build Coastguard Worker if tracer is None: 396*da0073e9SAndroid Build Coastguard Worker # This allows you to create a Proxy object around a raw Node 397*da0073e9SAndroid Build Coastguard Worker tracer = GraphAppendingTracer(node.graph) 398*da0073e9SAndroid Build Coastguard Worker self.tracer = tracer 399*da0073e9SAndroid Build Coastguard Worker self.node = node 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker def __repr__(self) -> str: 402*da0073e9SAndroid Build Coastguard Worker return f'Proxy({self.node.name})' 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker def __getattr__(self, k) -> 'Attribute': 405*da0073e9SAndroid Build Coastguard Worker # note: not added to the graph yet, if this is a method call 406*da0073e9SAndroid Build Coastguard Worker # we peephole optimize to the method invocation 407*da0073e9SAndroid Build Coastguard Worker return Attribute(self, k) 408*da0073e9SAndroid Build 
Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker def __getstate__(self) -> Dict: 410*da0073e9SAndroid Build Coastguard Worker return self.__dict__ 411*da0073e9SAndroid Build Coastguard Worker 412*da0073e9SAndroid Build Coastguard Worker def __deepcopy__(self, memo) -> Dict: 413*da0073e9SAndroid Build Coastguard Worker # We have to explicitly override this method, because otherwise deepcopy 414*da0073e9SAndroid Build Coastguard Worker # will go to __getattr__(self, "__deepcopy__") and return a 415*da0073e9SAndroid Build Coastguard Worker # Attribute(__deepcopy__), and may go into an infinite loop in some cases. 416*da0073e9SAndroid Build Coastguard Worker import copy 417*da0073e9SAndroid Build Coastguard Worker new_dict = {} 418*da0073e9SAndroid Build Coastguard Worker for k, v in self.__dict__.items(): 419*da0073e9SAndroid Build Coastguard Worker try: 420*da0073e9SAndroid Build Coastguard Worker new_obj = copy.deepcopy(v, memo) 421*da0073e9SAndroid Build Coastguard Worker except Exception: 422*da0073e9SAndroid Build Coastguard Worker log.warning( 423*da0073e9SAndroid Build Coastguard Worker "Shallow copy %s of Proxy because it cannot be deepcopied. 
" 424*da0073e9SAndroid Build Coastguard Worker "Proxy is created for node %s", k, self.node.name) 425*da0073e9SAndroid Build Coastguard Worker new_obj = copy.copy(v) 426*da0073e9SAndroid Build Coastguard Worker new_dict[k] = new_obj 427*da0073e9SAndroid Build Coastguard Worker assert "node" in new_dict 428*da0073e9SAndroid Build Coastguard Worker assert "tracer" in new_dict 429*da0073e9SAndroid Build Coastguard Worker new_proxy = Proxy(new_dict["node"], new_dict["tracer"]) 430*da0073e9SAndroid Build Coastguard Worker for k, v in new_dict.items(): 431*da0073e9SAndroid Build Coastguard Worker new_proxy.__dict__[k] = v 432*da0073e9SAndroid Build Coastguard Worker return new_proxy 433*da0073e9SAndroid Build Coastguard Worker 434*da0073e9SAndroid Build Coastguard Worker def __setstate__(self, d): 435*da0073e9SAndroid Build Coastguard Worker # This is called when being unpickled/loaded. 436*da0073e9SAndroid Build Coastguard Worker self.__dict__ = d 437*da0073e9SAndroid Build Coastguard Worker 438*da0073e9SAndroid Build Coastguard Worker def __call__(self, *args, **kwargs) -> 'Proxy': 439*da0073e9SAndroid Build Coastguard Worker return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator['Proxy']: 442*da0073e9SAndroid Build Coastguard Worker frame = inspect.currentframe() 443*da0073e9SAndroid Build Coastguard Worker assert frame is not None 444*da0073e9SAndroid Build Coastguard Worker calling_frame = frame.f_back 445*da0073e9SAndroid Build Coastguard Worker assert calling_frame is not None 446*da0073e9SAndroid Build Coastguard Worker inst_list = list(dis.get_instructions(calling_frame.f_code)) 447*da0073e9SAndroid Build Coastguard Worker if sys.version_info >= (3, 11): 448*da0073e9SAndroid Build Coastguard Worker from bisect import bisect_left 449*da0073e9SAndroid Build Coastguard Worker inst_idx = bisect_left(inst_list, 
calling_frame.f_lasti, key=lambda x: x.offset) 450*da0073e9SAndroid Build Coastguard Worker else: 451*da0073e9SAndroid Build Coastguard Worker inst_idx = calling_frame.f_lasti // 2 452*da0073e9SAndroid Build Coastguard Worker inst = inst_list[inst_idx] 453*da0073e9SAndroid Build Coastguard Worker if inst.opname == 'UNPACK_SEQUENCE': 454*da0073e9SAndroid Build Coastguard Worker return (self[i] for i in range(inst.argval)) # type: ignore[index] 455*da0073e9SAndroid Build Coastguard Worker 456*da0073e9SAndroid Build Coastguard Worker return self.tracer.iter(self) 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker def __abs__(self): 459*da0073e9SAndroid Build Coastguard Worker return self.tracer.create_proxy('call_function', operator.abs, (self,), {}) 460*da0073e9SAndroid Build Coastguard Worker 461*da0073e9SAndroid Build Coastguard Worker def __bool__(self) -> bool: 462*da0073e9SAndroid Build Coastguard Worker if self.tracer.trace_asserts: 463*da0073e9SAndroid Build Coastguard Worker # check if this boolean is used in an assertion, bytecode pattern for assertions 464*da0073e9SAndroid Build Coastguard Worker # is pretty stable for Python 3.7--3.9 465*da0073e9SAndroid Build Coastguard Worker frame = inspect.currentframe() 466*da0073e9SAndroid Build Coastguard Worker assert frame is not None 467*da0073e9SAndroid Build Coastguard Worker calling_frame = frame.f_back 468*da0073e9SAndroid Build Coastguard Worker assert calling_frame is not None 469*da0073e9SAndroid Build Coastguard Worker insts = list(dis.get_instructions(calling_frame.f_code)) 470*da0073e9SAndroid Build Coastguard Worker if sys.version_info >= (3, 11): 471*da0073e9SAndroid Build Coastguard Worker from bisect import bisect_left 472*da0073e9SAndroid Build Coastguard Worker cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset) 473*da0073e9SAndroid Build Coastguard Worker else: 474*da0073e9SAndroid Build Coastguard Worker cur = calling_frame.f_lasti // 2 
475*da0073e9SAndroid Build Coastguard Worker inst = insts[cur] 476*da0073e9SAndroid Build Coastguard Worker 477*da0073e9SAndroid Build Coastguard Worker if inst.opname == 'POP_JUMP_IF_TRUE': 478*da0073e9SAndroid Build Coastguard Worker first = insts[cur + 1] 479*da0073e9SAndroid Build Coastguard Worker assert inst.arg is not None 480*da0073e9SAndroid Build Coastguard Worker last = insts[inst.arg // 2 - 1] 481*da0073e9SAndroid Build Coastguard Worker starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError' 482*da0073e9SAndroid Build Coastguard Worker or first.opname == 'LOAD_ASSERTION_ERROR') 483*da0073e9SAndroid Build Coastguard Worker if starts_with_assert and last.opname == 'RAISE_VARARGS': 484*da0073e9SAndroid Build Coastguard Worker self.tracer.create_proxy('call_function', assert_fn, (self,), {}) 485*da0073e9SAndroid Build Coastguard Worker return True 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker return self.tracer.to_bool(self) 488*da0073e9SAndroid Build Coastguard Worker 489*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 490*da0073e9SAndroid Build Coastguard Worker def keys(self): 491*da0073e9SAndroid Build Coastguard Worker return self.tracer.keys(self) 492*da0073e9SAndroid Build Coastguard Worker 493*da0073e9SAndroid Build Coastguard Worker def __len__(self): 494*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("'len' is not supported in symbolic tracing by default. 
If you want " 495*da0073e9SAndroid Build Coastguard Worker "this call to be recorded, please call torch.fx.wrap('len') at " 496*da0073e9SAndroid Build Coastguard Worker "module scope") 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Worker @classmethod 499*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, orig_method, types, args=None, kwargs=None): 500*da0073e9SAndroid Build Coastguard Worker args = args if args else () 501*da0073e9SAndroid Build Coastguard Worker kwargs = kwargs if kwargs else {} 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker tracers : Dict[Any, None] = {} 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Worker def find_tracer(a): 506*da0073e9SAndroid Build Coastguard Worker if isinstance(a, cls): 507*da0073e9SAndroid Build Coastguard Worker tracers[a.tracer] = None 508*da0073e9SAndroid Build Coastguard Worker torch.fx.node.map_aggregate(args, find_tracer) 509*da0073e9SAndroid Build Coastguard Worker torch.fx.node.map_aggregate(kwargs, find_tracer) 510*da0073e9SAndroid Build Coastguard Worker 511*da0073e9SAndroid Build Coastguard Worker if len(tracers) > 1: 512*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while ' 513*da0073e9SAndroid Build Coastguard Worker f'trying to trace operations {orig_method}') 514*da0073e9SAndroid Build Coastguard Worker tracer = next(iter(tracers.keys())) 515*da0073e9SAndroid Build Coastguard Worker 516*da0073e9SAndroid Build Coastguard Worker if isinstance(orig_method, torch._C.ScriptMethod): 517*da0073e9SAndroid Build Coastguard Worker args = (orig_method.owner,) + args 518*da0073e9SAndroid Build Coastguard Worker return tracer.create_proxy('call_method', orig_method.name, args, kwargs) 519*da0073e9SAndroid Build Coastguard Worker if torch.overrides.is_tensor_method_or_property(orig_method): 520*da0073e9SAndroid Build 
Coastguard Worker return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) 521*da0073e9SAndroid Build Coastguard Worker else: 522*da0073e9SAndroid Build Coastguard Worker if isinstance(orig_method, torch._ops.HigherOrderOperator): 523*da0073e9SAndroid Build Coastguard Worker # TODO: Define how to symbolically trace HigherOrderOperators 524*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Unable to symbolically trace HigherOrderOperators") 525*da0073e9SAndroid Build Coastguard Worker return tracer.create_proxy('call_function', orig_method, args, kwargs, 526*da0073e9SAndroid Build Coastguard Worker name=tracer.graph._target_to_str(orig_method.__name__)) 527*da0073e9SAndroid Build Coastguard Worker 528*da0073e9SAndroid Build Coastguard Worker 529*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True) 530*da0073e9SAndroid Build Coastguard Workerclass Attribute(Proxy): 531*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 532*da0073e9SAndroid Build Coastguard Worker def __init__(self, root: Proxy, attr: str): 533*da0073e9SAndroid Build Coastguard Worker self.root = root 534*da0073e9SAndroid Build Coastguard Worker self.attr = attr 535*da0073e9SAndroid Build Coastguard Worker self.tracer = root.tracer 536*da0073e9SAndroid Build Coastguard Worker self._node: Optional[Node] = None 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker @property 539*da0073e9SAndroid Build Coastguard Worker def node(self): 540*da0073e9SAndroid Build Coastguard Worker # the node for attributes is added lazily, since most will just be method calls 541*da0073e9SAndroid Build Coastguard Worker # which do not rely on the getitem call 542*da0073e9SAndroid Build Coastguard Worker if self._node is None: 543*da0073e9SAndroid Build Coastguard Worker self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node 544*da0073e9SAndroid Build 
Coastguard Worker return self._node 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker def __call__(self, *args, **kwargs): 547*da0073e9SAndroid Build Coastguard Worker return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=False) 551*da0073e9SAndroid Build Coastguard Workerclass ParameterProxy(Proxy): 552*da0073e9SAndroid Build Coastguard Worker """ 553*da0073e9SAndroid Build Coastguard Worker A special proxy which lets "shape", "size", "dim", and a few other 554*da0073e9SAndroid Build Coastguard Worker attribute accesses pass through to the underlying module parameter object, 555*da0073e9SAndroid Build Coastguard Worker so that conditional tests on these attributes will not throw exception during tracing 556*da0073e9SAndroid Build Coastguard Worker """ 557*da0073e9SAndroid Build Coastguard Worker def __init__(self, tracer: TracerBase, node: Node, name, param): 558*da0073e9SAndroid Build Coastguard Worker super().__init__(node, tracer) 559*da0073e9SAndroid Build Coastguard Worker assert isinstance(param, torch.nn.Parameter) 560*da0073e9SAndroid Build Coastguard Worker self.param = param 561*da0073e9SAndroid Build Coastguard Worker self.name = name 562*da0073e9SAndroid Build Coastguard Worker 563*da0073e9SAndroid Build Coastguard Worker def __repr__(self) -> str: 564*da0073e9SAndroid Build Coastguard Worker return f'ParameterProxy({self.name})' 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker @property 567*da0073e9SAndroid Build Coastguard Worker def shape(self): 568*da0073e9SAndroid Build Coastguard Worker return self.param.shape 569*da0073e9SAndroid Build Coastguard Worker 570*da0073e9SAndroid Build Coastguard Worker def size(self): 571*da0073e9SAndroid Build Coastguard Worker return 
self.param.size() 572*da0073e9SAndroid Build Coastguard Worker 573*da0073e9SAndroid Build Coastguard Worker def dim(self): 574*da0073e9SAndroid Build Coastguard Worker return self.param.dim() 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker @property 577*da0073e9SAndroid Build Coastguard Worker def ndim(self): 578*da0073e9SAndroid Build Coastguard Worker return self.param.ndim 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker def numel(self): 581*da0073e9SAndroid Build Coastguard Worker return self.param.numel() 582*da0073e9SAndroid Build Coastguard Worker 583*da0073e9SAndroid Build Coastguard Worker def nelement(self): 584*da0073e9SAndroid Build Coastguard Worker return self.param.nelement() 585*da0073e9SAndroid Build Coastguard Worker 586*da0073e9SAndroid Build Coastguard Worker 587*da0073e9SAndroid Build Coastguard Workerfor method in magic_methods: 588*da0073e9SAndroid Build Coastguard Worker def _scope(method): 589*da0073e9SAndroid Build Coastguard Worker def impl(*args, **kwargs): 590*da0073e9SAndroid Build Coastguard Worker tracer = args[0].tracer 591*da0073e9SAndroid Build Coastguard Worker target = getattr(operator, method) 592*da0073e9SAndroid Build Coastguard Worker return tracer.create_proxy('call_function', target, args, kwargs) 593*da0073e9SAndroid Build Coastguard Worker impl.__name__ = method 594*da0073e9SAndroid Build Coastguard Worker as_magic = f'__{method.strip("_")}__' 595*da0073e9SAndroid Build Coastguard Worker setattr(Proxy, as_magic, impl) 596*da0073e9SAndroid Build Coastguard Worker _scope(method) 597*da0073e9SAndroid Build Coastguard Worker 598*da0073e9SAndroid Build Coastguard Workerdef _define_reflectable(orig_method_name): 599*da0073e9SAndroid Build Coastguard Worker method_name = f'__r{orig_method_name.strip("_")}__' 600*da0073e9SAndroid Build Coastguard Worker 601*da0073e9SAndroid Build Coastguard Worker def impl(self, rhs): 602*da0073e9SAndroid Build 
Coastguard Worker target = getattr(operator, orig_method_name) 603*da0073e9SAndroid Build Coastguard Worker return self.tracer.create_proxy('call_function', target, (rhs, self), {}) 604*da0073e9SAndroid Build Coastguard Worker impl.__name__ = method_name 605*da0073e9SAndroid Build Coastguard Worker impl.__qualname__ = method_name 606*da0073e9SAndroid Build Coastguard Worker setattr(Proxy, method_name, impl) 607*da0073e9SAndroid Build Coastguard Worker 608*da0073e9SAndroid Build Coastguard Workerfor orig_method_name in reflectable_magic_methods: 609*da0073e9SAndroid Build Coastguard Worker _define_reflectable(orig_method_name) 610