xref: /aosp_15_r20/external/pytorch/torch/fx/proxy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: ignore-errors
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport enum
4*da0073e9SAndroid Build Coastguard Workerimport dis
5*da0073e9SAndroid Build Coastguard Workerimport copy
6*da0073e9SAndroid Build Coastguard Workerimport sys
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerimport inspect
9*da0073e9SAndroid Build Coastguard Workerimport operator
10*da0073e9SAndroid Build Coastguard Workerimport collections
11*da0073e9SAndroid Build Coastguard Workerimport logging
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import is_dataclass, fields
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Workerfrom .graph import magic_methods, reflectable_magic_methods, Graph
17*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._traceback import CapturedTraceback
18*da0073e9SAndroid Build Coastguard Workerfrom typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable
19*da0073e9SAndroid Build Coastguard Workerfrom .node import Target, Node, Argument, base_types, map_aggregate
20*da0073e9SAndroid Build Coastguard Workerfrom ._compatibility import compatibility
21*da0073e9SAndroid Build Coastguard Workerfrom .operator_schemas import check_for_mutable_operation
22*da0073e9SAndroid Build Coastguard Workerimport torch.fx.traceback as fx_traceback
23*da0073e9SAndroid Build Coastguard Worker
__all__ = [
    'TracerBase',
    'GraphAppendingTracer',
    'TraceError',
    'Proxy',
    'Attribute',
    'ParameterProxy',
    'Scope',
    'ScopeContextManager',
]


# Module-level logger; TracerBase.create_node emits a debug record per node.
log = logging.getLogger(__name__)
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
class Scope:
    """Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in a Graph of GraphModule. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # This will be a call_method Node in GraphModule,
                # scope for this would be (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", module_type=None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    Args:
        module_path: Qualified path of the module (``""`` for the root).
        module_type: The module's class; the root scope uses ``None``.
    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        # Plain mutable fields: ScopeContextManager overwrites them in place
        # when tracing enters or leaves a submodule.
        self.module_path, self.module_type = module_path, module_type
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
class ScopeContextManager:
    """Context manager that tracks the Scope of Nodes during symbolic tracing.

    When the forward function of a Module is entered, the shared ``scope``
    object is overwritten with ``current_scope``'s module path and type; on
    exit, the previous path and type are written back into that same object.

    Note: the overwrite happens as soon as the manager is constructed, not in
    ``__enter__``. ``__enter__`` returns the (already updated) ``scope``.
    """

    def __init__(
        self,
        scope: Scope,
        current_scope: Scope,
    ):
        super().__init__()
        # Shallow snapshot of the outgoing scope, written back by __exit__.
        self._prev_scope = copy.copy(scope)
        # Hold on to the caller's Scope object; it is mutated in place.
        self._scope = scope
        self._assign_from(current_scope)

    def _assign_from(self, source: Scope) -> None:
        # Overwrite the tracked scope's fields with those of ``source``.
        self._scope.module_path = source.module_path
        self._scope.module_type = source.module_type

    def __enter__(self):
        return self._scope

    def __exit__(self, *args):
        self._assign_from(self._prev_scope)
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker
# Keys that TracerBase.create_node copies (shallowly, in this order) from the
# preserved fx_traceback meta onto each newly created node's ``meta``.
_COPY_META_FIELDS = [
    "nn_module_stack",
    "torch_fn",
    "source_fn_stack",
    "original_aten",
    "recompute",
    "ac_graph_id",
    "from_node",
    "quantization_tag",  # TODO deprecated
    "_numeric_debug_handle",  # TODO deprecated
    "custom",
    "partitioner_tag",
]
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=True)
class TracerBase:
    """Base class for tracers that record operations into an FX ``Graph``.

    This class only declares the tracer state (``graph``, ``scope``,
    ``module_stack`` and ``node_name_to_scope``); subclasses must assign
    those attributes before :meth:`create_node` or :meth:`create_proxy`
    is called.
    """

    graph: Graph
    # When True, create_proxy attaches a captured Python stack trace to each
    # new node that does not already have one
    record_stack_traces : bool = False
    # Feature flag for mutable schema checking
    # Enable by default in 1.12
    check_mutable_operations : bool = False
    # Feature flag for assert tracing
    trace_asserts : bool = False
    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes : bool = False

    # Name of the function to be traced. It will only be used when
    # ``root`` is an instance of ``nn.Module``
    traced_func_name: str = "forward"

    # Scope (module path and module type) of the module currently being
    # traced; create_node records it for every node it creates
    scope : Scope

    # Records the module call stack
    module_stack: OrderedDict[str, Tuple[str, Any]]

    # Mapping of node name to module scope
    node_name_to_scope: Dict[str, Tuple[str, type]]

    @compatibility(is_backward_compatible=True)
    def create_node(self, kind : str, target : Target,
                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
                    type_expr : Optional[Any] = None) -> Node:
        """
        Inserts a graph node given target, args, kwargs, and name.

        This method can be overridden to do extra checking, validation, or
        modification of values used in node creation. For example, one might
        want to disallow in-place operations from being recorded.

        Besides delegating to ``self.graph.create_node``, this records the
        current scope in ``node_name_to_scope`` and fills ``node.meta``:
        from the preserved fx_traceback meta when it is active, otherwise
        from ``self.module_stack``.
        """

        if kind == 'call_function' and self.check_mutable_operations:
            check_for_mutable_operation(target, args, kwargs)

        node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
        # TODO node_name_to_scope will be deprecated in favor of
        # node.meta['nn_module_stack']
        self.node_name_to_scope[node.name] = (
            self.scope.module_path,
            self.scope.module_type,
        )
        # Optionally set stack trace on the created Node for debugging purposes
        if fx_traceback.has_preserved_node_meta():
            current_meta: Dict[str, Any] = fx_traceback.get_current_meta()

            stack_trace = current_meta.get("stack_trace")
            if stack_trace:
                node.stack_trace = stack_trace
            # Explicitly copy the listed meta fields onto node.meta.
            # If other meta fields are needed, add them to _COPY_META_FIELDS.
            for field in _COPY_META_FIELDS:
                if field in current_meta:
                    node.meta[field] = copy.copy(current_meta[field])

            # Here we decrement to account for the sequence_nr having
            # just been incremented while tracing this lowered aten op.
            new_seq_nr = torch.autograd._get_sequence_nr() - 1
            # The sequence_nr increments every time a new autograd Node
            # is created. During the FWD pass we store the sequence_nr
            # corresponding to the last autograd Node created on this fx
            # node's meta.  A single aten op can create multiple autograd
            # nodes as is the case with in-place foreach ops. During the
            # BWD pass we retrieve the sequence_nr stored on the current
            # executing autograd Node. See NOTE [ Sequence Number ].
            if current_meta.get("in_grad_fn", 0) > 0:
                new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
            node.meta["seq_nr"] = new_seq_nr

        elif self.module_stack:
            node.meta['nn_module_stack'] = copy.copy(self.module_stack)

        log.debug("create_node %s", node)
        return node

    @compatibility(is_backward_compatible=True)
    def proxy(self, node: Node) -> 'Proxy':
        """Wrap ``node`` in a :class:`Proxy` bound to this tracer."""
        return Proxy(node, self)

    @compatibility(is_backward_compatible=True)
    def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
                     name: Optional[str] = None, type_expr : Optional[Any] = None,
                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
        '''
        Create a Node from the given arguments, then return the Node
        wrapped in a Proxy object.

        If kind = 'placeholder', then we're creating a Node that
        represents the parameter of a function. If we need to encode
        a default parameter, we use the ``args`` tuple. ``args`` is
        otherwise empty for ``placeholder`` Nodes.

        ``proxy_factory_fn`` may be ``None``: when it is falsy the Node is
        wrapped with ``self.proxy``; otherwise it is called with the new Node
        and its result is returned.
        '''
        # NOTE: the proxy_factory_fn annotation is intentionally left as-is
        # (not Optional[...]) because this signature is declared
        # backward-compatible.

        args_ = self.create_arg(args)
        kwargs_ = self.create_arg(kwargs)
        assert isinstance(args_, tuple)
        assert isinstance(kwargs_, dict)

        node = self.create_node(kind, target, args_, kwargs_, name, type_expr)

        if not proxy_factory_fn:
            proxy = self.proxy(node)
        else:
            proxy = proxy_factory_fn(node)

        if self.record_stack_traces and not proxy.node.stack_trace:
            proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format())

        return proxy

    def _find_user_frame(self):
        """
        Find the Python stack frame executing the user code during
        symbolic tracing.

        Walks outward from the caller and returns the first frame whose
        source file is not one of the PyTorch tracing files listed below,
        or ``None`` if no such frame exists.
        """
        import os

        # We have to do a little dance here. Basically, walk up the callstack and
        # record the first frame not in the pytorch source. This is the frame executing
        # the user code during tracing.
        pt_files = ('torch/fx/proxy.py',
                    'torch/fx/_symbolic_trace.py',
                    'torch/fx/experimental/proxy_tensor.py',
                    'torch/_ops.py',
                    'torch/_tensor.py',
                    'torch/utils/_python_dispatch.py',
                    'torch/_prims_common/wrappers.py',
                    'torch/_refs/__init__.py',
                    'torch/_refs/nn/functional/__init__.py',
                    'torch/utils/_stats.py',
                    )
        frame = inspect.currentframe()
        while frame:
            frame = frame.f_back
            if frame is None:
                break
            # co_filename uses the platform separator (a backslash on
            # Windows) while pt_files uses '/'; normalize before matching so
            # PyTorch frames are skipped on every platform.
            filename = frame.f_code.co_filename.replace(os.sep, '/')
            if not filename.endswith(pt_files):
                break
        return frame

    @compatibility(is_backward_compatible=True)
    def create_arg(self, a: Any) -> Argument:
        """
        A method that lowers the objects seen as arguments during symbolic evaluation
        into Argument types that can be stored in IR.

        Can be overridden to support more trace-specific types.

        Raises:
            RuntimeError: if a dict key contains a Node.
            NotImplementedError: if ``a`` is of an unsupported type.
        """
        if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
            return a.__fx_create_arg__(self)
        # aggregates
        elif isinstance(a, tuple) and hasattr(a, '_fields'):
            # NamedTuple constructors don't seem to like getting a generator
            # expression as an argument to their constructor, so build this
            # intermediate tuple and unpack it into the NamedTuple constructor
            args = tuple(self.create_arg(elem) for elem in a)
            return type(a)(*args)  # type: ignore[arg-type]
        elif isinstance(a, (tuple, list)):
            return type(a)(self.create_arg(elem) for elem in a)
        elif isinstance(a, dict):
            r = {}
            for k, v in a.items():
                # Check for invalid dict keys. We do not want a Proxy to appear
                # anywhere within the key. Since keys can be collection types,
                # we iterate through the key with map_aggregate
                k = self.create_arg(k)

                def no_node(arg):
                    if isinstance(arg, Node):
                        raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
                                           f"Node. Got key: {k}")
                map_aggregate(k, no_node)

                r[k] = self.create_arg(v)
            return r
        elif isinstance(a, slice):
            return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        elif isinstance(a, range):
            return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
            return a

        if isinstance(a, Proxy):
            # base case: we unwrap the Proxy object
            return a.node

        if is_dataclass(a):
            # Dataclass instances are rebuilt in the graph as a call to their
            # class with each field passed as a keyword argument.
            kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)}
            return self.create_node("call_function", a.__class__, (), kwargs)

        elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
            return a
        raise NotImplementedError(f"argument of type: {type(a)}")

    @compatibility(is_backward_compatible=True)
    def to_bool(self, obj: 'Proxy') -> bool:
        """Called when a proxy object is being converted to a boolean, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return a value.

        The default implementation always raises :class:`TraceError`.
        """
        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')

    @compatibility(is_backward_compatible=True)
    def iter(self, obj: 'Proxy') -> Iterator:
        """Called when a proxy object is being iterated over, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return an iterator.

        The default implementation always raises :class:`TraceError`.
        """
        raise TraceError('Proxy object cannot be iterated. This can be '
                         'attempted when the Proxy is used in a loop or'
                         ' as a *args or **kwargs function argument. '
                         'See the torch.fx docs on pytorch.org for a '
                         'more detailed explanation of what types of '
                         'control flow can be traced, and check out the'
                         ' Proxy docstring for help troubleshooting '
                         'Proxy iteration errors')

    @compatibility(is_backward_compatible=True)
    def keys(self, obj: 'Proxy') -> Any:
        """Called when a proxy object has its keys() method called.
        This is what happens when ``**`` is applied to a proxy. Override this to
        return an iterator if ``**`` is supposed to work in your custom tracer.

        By default this returns the result of calling ``Attribute(obj, 'keys')``.
        """
        return Attribute(obj, 'keys')()
344*da0073e9SAndroid Build Coastguard Worker
# used in Proxy object when just appending to the graph while not tracing.
@compatibility(is_backward_compatible=True)
class GraphAppendingTracer(TracerBase):
    """Minimal tracer that appends nodes to an existing ``Graph``.

    ``Proxy`` falls back to this tracer when it is constructed around a raw
    ``Node`` without an explicit tracer. It starts in the root scope
    (``Scope("", None)``) with an empty module stack and an empty
    node-name-to-scope mapping.
    """

    def __init__(self, graph: Graph):
        super().__init__()
        # Root scope with no enclosing modules recorded yet.
        self.scope = Scope("", None)
        self.module_stack = collections.OrderedDict()
        self.node_name_to_scope = {}
        # Nodes created through this tracer are inserted into ``graph``.
        self.graph = graph
354*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
def assert_fn(x):
    """Assert that ``x`` is truthy.

    Implemented with a bare ``assert`` statement on purpose, so (like a user's
    own ``assert``) the check is skipped when Python runs with ``-O``.
    """
    assert x
358*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=True)
class TraceError(ValueError):
    """Raised when a traced value is used in a way symbolic tracing cannot
    record, e.g. by the default ``TracerBase.to_bool`` and ``TracerBase.iter``."""
362*da0073e9SAndroid Build Coastguard Worker
def _instruction_index(insts, offset):
    """Return the index in ``insts`` of the instruction that starts at byte ``offset``.

    ``insts`` must be the list produced by ``dis.get_instructions`` for a single
    code object. Before Python 3.11 every instruction (EXTENDED_ARG included)
    occupies two bytes, so the index is simply ``offset // 2``. From 3.11 on,
    ``dis.get_instructions`` omits inline CACHE entries by default, which leaves
    gaps in the offsets, so the index is found by binary search instead.
    """
    if sys.version_info >= (3, 11):
        from bisect import bisect_left
        return bisect_left(insts, offset, key=lambda x: x.offset)
    return offset // 2


@compatibility(is_backward_compatible=True)
class Proxy:
    """
    ``Proxy`` objects are ``Node`` wrappers that flow through the
    program during symbolic tracing and record all the operations
    (``torch`` function calls, method calls, operators) that they touch
    into the growing FX Graph.

    If you're doing graph transforms, you can wrap a ``Proxy`` around a
    raw ``Node`` so that you can use the overloaded operators to add
    additional things to a ``Graph``.

    ``Proxy`` objects cannot be iterated. In other words, the symbolic
    tracer will throw an error if a ``Proxy`` is used in a loop or as
    an ``*args``/``**kwargs`` function argument.

    There are two main ways around this:

    1. Factor out the untraceable logic into a top-level function and
       use ``fx.wrap`` on it.
    2. If the control flow is static (i.e. the loop trip count is
       based on some hyperparameter), the code can be kept in its original
       position and refactored into something like::

           for i in range(self.some_hyperparameter):
               indexed_item = proxied_value[i]

    For a more detailed description of the Proxy internals, check out
    the "Proxy" section in `torch/fx/README.md`
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
        if tracer is None:
            # This allows you to create a Proxy object around a raw Node
            tracer = GraphAppendingTracer(node.graph)
        self.tracer = tracer
        self.node = node

    def __repr__(self) -> str:
        return f'Proxy({self.node.name})'

    def __getattr__(self, k) -> 'Attribute':
        # Only reached for names not found normally. Nothing is added to the
        # graph yet: if this is a method call, Attribute.__call__ emits the
        # method invocation directly (peephole optimization).
        return Attribute(self, k)

    # __getstate__/__setstate__ are defined explicitly so that pickling finds
    # them on the class instead of falling through to __getattr__, which would
    # fabricate an Attribute for any name.
    def __getstate__(self) -> Dict:
        return self.__dict__

    def __deepcopy__(self, memo) -> 'Proxy':
        """Deep-copy every instance attribute into a new proxy of the same type.

        Attributes that cannot be deep-copied are shallow-copied instead, with
        a warning.
        """
        # We have to explicitly override this method, because otherwise deepcopy
        # will go to __getattr__(self, "__deepcopy__") and return a
        # Attribute(__deepcopy__), and may go into an infinite loop in some cases.
        #
        # The copy is created with type(self).__new__ so that subclasses (e.g.
        # ParameterProxy, Attribute) keep their type and overrides. It is
        # registered in ``memo`` before recursing, so any reference cycle back to
        # this proxy resolves to the copy instead of recursing forever.
        new_proxy = type(self).__new__(type(self))
        memo[id(self)] = new_proxy
        new_dict = {}
        for k, v in self.__dict__.items():
            try:
                new_obj = copy.deepcopy(v, memo)
            except Exception:
                log.warning(
                    "Shallow copy %s of Proxy because it cannot be deepcopied. "
                    "Proxy is created for node %s", k, self.node.name)
                new_obj = copy.copy(v)
            new_dict[k] = new_obj
        new_proxy.__dict__.update(new_dict)
        return new_proxy

    def __setstate__(self, d):
        # This is called when being unpickled/loaded.
        self.__dict__ = d

    def __call__(self, *args, **kwargs) -> 'Proxy':
        return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)

    def __iter__(self) -> Iterator['Proxy']:
        # Tuple unpacking (``a, b = proxy``) is special-cased: its arity is read
        # from the caller's UNPACK_SEQUENCE instruction and the proxy is indexed
        # element by element. Any other iteration is delegated to tracer.iter.
        frame = inspect.currentframe()
        assert frame is not None
        calling_frame = frame.f_back
        assert calling_frame is not None
        inst_list = list(dis.get_instructions(calling_frame.f_code))
        inst = inst_list[_instruction_index(inst_list, calling_frame.f_lasti)]
        if inst.opname == 'UNPACK_SEQUENCE':
            return (self[i] for i in range(inst.argval))  # type: ignore[index]

        return self.tracer.iter(self)

    def __abs__(self):
        return self.tracer.create_proxy('call_function', operator.abs, (self,), {})

    def __bool__(self) -> bool:
        if self.tracer.trace_asserts:
            # Check whether this truth test is the condition of an ``assert``:
            # the caller's bytecode is a conditional jump whose fall-through
            # starts by loading AssertionError and whose target is preceded by
            # RAISE_VARARGS. If so, record an ``assert_fn`` call and report True
            # instead of asking the tracer for a concrete bool.
            frame = inspect.currentframe()
            assert frame is not None
            calling_frame = frame.f_back
            assert calling_frame is not None
            insts = list(dis.get_instructions(calling_frame.f_code))
            cur = _instruction_index(insts, calling_frame.f_lasti)
            inst = insts[cur]

            # Python 3.11 spells the forward conditional jump POP_JUMP_FORWARD_IF_TRUE.
            if inst.opname in ('POP_JUMP_IF_TRUE', 'POP_JUMP_FORWARD_IF_TRUE'):
                first = insts[cur + 1]
                # ``argval`` is the resolved byte offset of the jump target on
                # every version. The raw ``arg`` is a byte offset only up to 3.9;
                # 3.10 uses instruction units and 3.11+ uses relative offsets.
                assert inst.argval is not None
                last = insts[_instruction_index(insts, inst.argval) - 1]
                starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError'
                                      or first.opname == 'LOAD_ASSERTION_ERROR')
                if starts_with_assert and last.opname == 'RAISE_VARARGS':
                    self.tracer.create_proxy('call_function', assert_fn, (self,), {})
                    return True

        return self.tracer.to_bool(self)

    @compatibility(is_backward_compatible=True)
    def keys(self):
        # Python's ``**`` unpacking calls keys(); the tracer decides how to handle it.
        return self.tracer.keys(self)

    def __len__(self):
        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
                           "this call to be recorded, please call torch.fx.wrap('len') at "
                           "module scope")

    @classmethod
    def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
        """Record a torch function / tensor method call on proxies as a graph node.

        Raises:
            RuntimeError: if the arguments hold proxies from more than one tracer,
                if no ``cls`` instance can be found among them, or if
                ``orig_method`` is a HigherOrderOperator.
        """
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        # Distinct tracers owning the proxy arguments (dict used as an ordered set).
        tracers : Dict[Any, None] = {}

        def find_tracer(a):
            if isinstance(a, cls):
                tracers[a.tracer] = None
        torch.fx.node.map_aggregate(args, find_tracer)
        torch.fx.node.map_aggregate(kwargs, find_tracer)

        if len(tracers) > 1:
            raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while '
                               f'trying to trace operations {orig_method}')
        if not tracers:
            # Previously surfaced as a bare StopIteration from next() below.
            raise RuntimeError(f'Could not find a {cls.__name__} among the arguments while '
                               f'trying to trace operation {orig_method}')
        tracer = next(iter(tracers))

        if isinstance(orig_method, torch._C.ScriptMethod):
            args = (orig_method.owner,) + args
            return tracer.create_proxy('call_method', orig_method.name, args, kwargs)
        if torch.overrides.is_tensor_method_or_property(orig_method):
            return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
        else:
            if isinstance(orig_method, torch._ops.HigherOrderOperator):
                # TODO: Define how to symbolically trace HigherOrderOperators
                raise RuntimeError("Unable to symbolically trace HigherOrderOperators")
            return tracer.create_proxy('call_function', orig_method, args, kwargs,
                                       name=tracer.graph._target_to_str(orig_method.__name__))
527*da0073e9SAndroid Build Coastguard Worker
528*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=True)
class Attribute(Proxy):
    """A ``Proxy`` standing for ``getattr(root, attr)``.

    No graph node is created at construction. A direct call (``proxy.foo(...)``)
    goes through ``__call__``, which records one ``call_method`` node on
    ``root``. A ``getattr`` ``call_function`` node is only materialised when
    ``node`` is accessed.
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, root: Proxy, attr: str):
        # Proxy.__init__ is intentionally not called: ``node`` is a lazy property here.
        self.root, self.attr = root, attr
        self.tracer = root.tracer
        self._node: Optional[Node] = None

    @property
    def node(self):
        # Create the getattr node on first access and cache it. Method calls
        # never need it, because __call__ passes ``root`` itself.
        if self._node is None:
            getattr_proxy = self.tracer.create_proxy(
                'call_function', getattr, (self.root, self.attr), {})
            self._node = getattr_proxy.node
        return self._node

    def __call__(self, *args, **kwargs):
        call_args = (self.root, *args)
        return self.tracer.create_proxy('call_method', self.attr, call_args, kwargs)
548*da0073e9SAndroid Build Coastguard Worker
549*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
class ParameterProxy(Proxy):
    """
    A special proxy which lets "shape", "size", "dim", and a few other
    attribute accesses pass through to the underlying module parameter object,
    so that conditional tests on these attributes will not throw an exception
    during tracing. These accessors return concrete values read from ``param``;
    they do not record nodes in the graph.
    """
    def __init__(self, tracer: TracerBase, node: Node, name, param):
        super().__init__(node, tracer)
        assert isinstance(param, torch.nn.Parameter)
        self.param = param
        self.name = name

    def __repr__(self) -> str:
        return f'ParameterProxy({self.name})'

    @property
    def shape(self):
        return self.param.shape

    def size(self, *args, **kwargs):
        # Forward the optional ``dim`` argument so that ``p.size(0)`` and
        # ``p.size(dim=0)`` behave exactly like ``Tensor.size``.
        return self.param.size(*args, **kwargs)

    def dim(self):
        return self.param.dim()

    @property
    def ndim(self):
        return self.param.ndim

    def numel(self):
        return self.param.numel()

    def nelement(self):
        return self.param.nelement()
585*da0073e9SAndroid Build Coastguard Worker
586*da0073e9SAndroid Build Coastguard Worker
def _scope(method):
    """Install ``Proxy.__<method>__`` as a recorder of ``operator.<method>`` calls.

    ``method`` names a function in the ``operator`` module. Any leading or
    trailing underscores are stripped before forming the dunder name, so the
    name ``f_`` installs ``__f__``.
    """
    def impl(*args, **kwargs):
        owner_tracer = args[0].tracer
        # Looked up at call time, not when the method is installed, so
        # installing it never requires ``operator`` to have the attribute.
        op_fn = getattr(operator, method)
        return owner_tracer.create_proxy('call_function', op_fn, args, kwargs)

    impl.__name__ = method
    dunder_name = '__' + method.strip('_') + '__'
    setattr(Proxy, dunder_name, impl)


for method in magic_methods:
    _scope(method)
597*da0073e9SAndroid Build Coastguard Worker
def _define_reflectable(orig_method_name):
    """Install ``Proxy.__r<name>__`` for the ``operator`` function ``orig_method_name``.

    The reflected method records ``operator.<name>(rhs, self)``, so operands
    keep their source order when the proxy is on the right-hand side.
    """
    reflected_name = '__r' + orig_method_name.strip('_') + '__'

    def impl(self, rhs):
        op_fn = getattr(operator, orig_method_name)
        return self.tracer.create_proxy('call_function', op_fn, (rhs, self), {})

    impl.__name__ = impl.__qualname__ = reflected_name
    setattr(Proxy, reflected_name, impl)

for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)
610