xref: /aosp_15_r20/external/pytorch/torch/export/graph_signature.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 # mypy: allow-untyped-defs
2 import dataclasses
3 from enum import auto, Enum
4 from typing import Collection, Dict, List, Mapping, Optional, Set, TYPE_CHECKING, Union
5 
6 from torch._library.fake_class_registry import FakeScriptObject
7 
8 
9 if TYPE_CHECKING:
10     import torch
11     from torch._functorch._aot_autograd.schemas import GraphSignature
12 
# Public API of this module (what ``from torch.export.graph_signature import *``
# exposes).
# NOTE(review): TokenArgument and ArgumentSpec are defined below but not listed
# here — presumably intentional (internal-only); confirm before adding them.
__all__ = [
    "ConstantArgument",
    "CustomObjArgument",
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "InputKind",
    "InputSpec",
    "OutputKind",
    "OutputSpec",
    "SymIntArgument",
    "TensorArgument",
]
25 
26 
@dataclasses.dataclass
class TensorArgument:
    """Argument spec referring to a graph node that carries a tensor value."""

    # Graph node name (see ``_make_argument_spec``, which builds one of these
    # for nodes whose ``meta["val"]`` is a FakeTensor).
    name: str
30 
31 
@dataclasses.dataclass
class TokenArgument:
    """Argument spec referring to a token node (see InputKind.TOKEN / OutputKind.TOKEN)."""

    # Graph node name; ``_make_argument_spec`` builds one of these for nodes
    # whose name appears in the supplied ``token_names`` collection.
    name: str
35 
36 
@dataclasses.dataclass
class SymIntArgument:
    """Argument spec referring to a graph node that carries a SymInt value."""

    # Graph node name (built for nodes whose ``meta["val"]`` is a SymInt).
    name: str
40 
41 
@dataclasses.dataclass
class CustomObjArgument:
    """Argument spec referring to a graph node that carries a (possibly fake) script object."""

    # Graph node name.
    name: str
    # Fully qualified class name of the script object
    # (``val._type().qualified_name()`` for real ScriptObjects, or
    # ``val.script_class_name`` for FakeScriptObjects).
    class_fqn: str
    # Populated only when the value was a FakeScriptObject at export time.
    fake_val: Optional[FakeScriptObject] = None
47 
48 
@dataclasses.dataclass
class ConstantArgument:
    """Argument spec for a constant Python primitive baked into the signature."""

    # Graph node name; empty string when the constant was passed directly
    # (not as a node) to ``_make_argument_spec``.
    name: str
    # The constant itself.
    value: Union[int, float, bool, str, None]
53 
54 
# Union of every argument descriptor that may appear in the ``arg`` field of
# an InputSpec or OutputSpec.
ArgumentSpec = Union[
    TensorArgument,
    SymIntArgument,
    ConstantArgument,
    CustomObjArgument,
    TokenArgument,
]
62 
63 
class InputKind(Enum):
    """Classifies where a lifted graph input comes from."""

    USER_INPUT = auto()  # pytree-flattened input supplied by the caller
    PARAMETER = auto()  # lifted nn parameter (see ExportGraphSignature.parameters)
    BUFFER = auto()  # lifted buffer; must carry a ``persistent`` flag
    CONSTANT_TENSOR = auto()  # lifted constant tensor
    CUSTOM_OBJ = auto()  # lifted custom (script) object
    TOKEN = auto()  # token input (see ExportGraphSignature.input_tokens)
71 
72 
@dataclasses.dataclass
class InputSpec:
    """Describes a single lifted input of the exported graph.

    ``target`` names the parameter/buffer/constant the input was lifted from
    (None for user inputs); ``persistent`` is mandatory for BUFFER inputs.
    """

    kind: InputKind
    arg: ArgumentSpec
    target: Optional[str]
    persistent: Optional[bool] = None

    def __post_init__(self):
        # Buffer inputs must state explicitly whether they are persistent.
        if self.kind == InputKind.BUFFER:
            assert self.persistent is not None, "Failed to specify persistent flag on BUFFER."
        # ``arg`` must be one of the concrete ArgumentSpec variants.
        accepted = (
            TensorArgument,
            SymIntArgument,
            ConstantArgument,
            CustomObjArgument,
            TokenArgument,
        )
        assert isinstance(self.arg, accepted), f"got {type(self.arg)}"
95 
96 
class OutputKind(Enum):
    """Classifies what each graph output represents."""

    USER_OUTPUT = auto()  # pytree-flattened output of the original program
    LOSS_OUTPUT = auto()  # loss value in a joint (training) graph
    BUFFER_MUTATION = auto()  # updated value of a mutated buffer
    GRADIENT_TO_PARAMETER = auto()  # gradient w.r.t. a parameter (joint graph)
    GRADIENT_TO_USER_INPUT = auto()  # gradient w.r.t. a user input (joint graph)
    USER_INPUT_MUTATION = auto()  # updated value of a mutated user input
    TOKEN = auto()  # token output (see ExportGraphSignature.output_tokens)
105 
106 
@dataclasses.dataclass
class OutputSpec:
    """Describes a single output of the exported graph.

    ``target`` names the buffer/parameter/input the output corresponds to
    (e.g. for mutations or gradients); None for plain user outputs.
    """

    kind: OutputKind
    arg: ArgumentSpec
    target: Optional[str]

    def __post_init__(self):
        # ``arg`` must be one of the concrete ArgumentSpec variants.
        accepted = (
            TensorArgument,
            SymIntArgument,
            ConstantArgument,
            TokenArgument,
            CustomObjArgument,
        )
        assert isinstance(self.arg, accepted), self.arg
124 
125 
@dataclasses.dataclass
class ExportBackwardSignature:
    """Backward-pass portion of a joint graph's signature."""

    # Maps gradient output node name -> parameter fully qualified name.
    gradients_to_parameters: Dict[str, str]
    # Maps gradient output node name -> user input node name.
    gradients_to_user_inputs: Dict[str, str]
    # Node name of the loss output the gradients are taken against.
    loss_output: str
131 
132 
@dataclasses.dataclass
class ExportGraphSignature:
    """
    :class:`ExportGraphSignature` models the input/output signature of Export Graph,
    which is a fx.Graph with stronger invariants guarantees.

    Export Graph is functional and does not access "states" like parameters
    or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
    guarantees that parameters, buffers, and constant tensors are lifted out of
    the graph as inputs.  Similarly, any mutations to buffers are not included
    in the graph either, instead the updated values of mutated buffers are
    modeled as additional outputs of Export Graph.

    The ordering of all inputs and outputs are::

        Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
        Outputs = [*mutated_inputs, *flattened_user_outputs]

    e.g. If following module is exported::

        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super(CustomModule, self).__init__()

                # Define a parameter
                self.my_parameter = nn.Parameter(torch.tensor(2.0))

                # Define two buffers
                self.register_buffer('my_buffer1', torch.tensor(3.0))
                self.register_buffer('my_buffer2', torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0) # In-place addition

                return output

    Resulting Graph would be::

        graph():
            %arg0_1 := placeholder[target=arg0_1]
            %arg1_1 := placeholder[target=arg1_1]
            %arg2_1 := placeholder[target=arg2_1]
            %arg3_1 := placeholder[target=arg3_1]
            %arg4_1 := placeholder[target=arg4_1]
            %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
            %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
            %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
            %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
            %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
            return (add_tensor_2, add_tensor_1)

    Resulting ExportGraphSignature would be::

        ExportGraphSignature(
            input_specs=[
                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
            ],
            output_specs=[
                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
                OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
            ]
        )
    """

    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]

    # A list of parameters uniquely identified by mangled fully qualified name
    @property
    def parameters(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.PARAMETER
            if isinstance(s.target, str)
        )

    # A list of buffers uniquely identified by mangled fully qualified name
    @property
    def buffers(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            if isinstance(s.target, str)
        )

    # A list of buffers whose InputSpec was explicitly marked non-persistent
    @property
    def non_persistent_buffers(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            if s.persistent is False
            if isinstance(s.target, str)
        )

    # A list of lifted constant tensors
    @property
    def lifted_tensor_constants(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.CONSTANT_TENSOR
            if isinstance(s.target, str)
        )

    # A list of lifted custom (script) objects
    @property
    def lifted_custom_objs(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.CUSTOM_OBJ
            if isinstance(s.target, str)
        )

    # Graph node names of pytree-flattened inputs of original program
    @property
    def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]:
        user_inputs: List[Union[int, float, bool, None, str]] = []
        for s in self.input_specs:
            if s.kind != InputKind.USER_INPUT:
                continue

            # Node-backed args contribute their node name; constants
            # contribute their literal value.
            if isinstance(s.arg, (TensorArgument, SymIntArgument, CustomObjArgument)):
                user_inputs.append(s.arg.name)
            elif isinstance(s.arg, ConstantArgument):
                user_inputs.append(s.arg.value)
            else:
                raise RuntimeError(f"{s.arg} is not a valid user inputs")
        return tuple(user_inputs)

    # Graph node names of pytree-flattened outputs of original program
    @property
    def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]:
        user_outputs: List[Union[int, float, bool, None, str]] = []
        for s in self.output_specs:
            if s.kind != OutputKind.USER_OUTPUT:
                continue

            if isinstance(s.arg, (TensorArgument, SymIntArgument)):
                user_outputs.append(s.arg.name)
            elif isinstance(s.arg, ConstantArgument):
                user_outputs.append(s.arg.value)
            elif isinstance(s.arg, CustomObjArgument):
                user_outputs.append(s.arg.name)
            else:
                raise RuntimeError(f"{s.arg} is not a valid user output")
        return tuple(user_outputs)

    # A dictionary mapping graph input node names to parameters. If a graph input
    # name is found in this dictionary, it is guaranteed to be a lifted parameter.
    @property
    def inputs_to_parameters(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.PARAMETER
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph input node names to buffers. If a graph input
    # name is found in this dictionary, it is guaranteed to be a lifted buffer.
    @property
    def inputs_to_buffers(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)  # type: ignore[union-attr, misc]
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph output node names to buffers that are mutated in the
    # original program. Buffers that are not mutated will not be found in this dictionary.
    @property
    def buffers_to_mutate(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.output_specs
            if s.kind == OutputKind.BUFFER_MUTATION
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph output node names to the user inputs that are
    # mutated in the original program.
    @property
    def user_inputs_to_mutate(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.output_specs
            if s.kind == OutputKind.USER_INPUT_MUTATION
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph input node names to lifted tensor constants.
    @property
    def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.CONSTANT_TENSOR
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph input node names to lifted custom objects.
    @property
    def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.CUSTOM_OBJ
            and isinstance(s.arg, CustomObjArgument)
            and isinstance(s.target, str)
        )

    # Reassembles the backward part of a joint graph's signature from the
    # LOSS_OUTPUT / GRADIENT_TO_* output specs; None for inference graphs.
    @property
    def backward_signature(self) -> Optional[ExportBackwardSignature]:
        loss_output = None
        gradients_to_parameters: Dict[str, str] = {}
        gradients_to_user_inputs: Dict[str, str] = {}
        for spec in self.output_specs:
            if spec.kind == OutputKind.LOSS_OUTPUT:
                # A joint graph has at most one loss output.
                assert loss_output is None
                assert isinstance(spec.arg, TensorArgument)
                loss_output = spec.arg.name
            elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
                assert isinstance(spec.target, str)
                assert isinstance(spec.arg, TensorArgument)
                gradients_to_parameters[spec.arg.name] = spec.target
            elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT:
                assert isinstance(spec.target, str)
                assert isinstance(spec.arg, TensorArgument)
                gradients_to_user_inputs[spec.arg.name] = spec.target

        if loss_output is None:
            return None

        return ExportBackwardSignature(
            loss_output=loss_output,
            gradients_to_parameters=gradients_to_parameters,
            gradients_to_user_inputs=gradients_to_user_inputs,
        )

    # Map from assertion dependency token index to assertion dep token output
    # name in output. The shape of output after aot_autograd will be like:
    # (updated_inputs, user_outputs, dep_token).
    @property
    def assertion_dep_token(self) -> Optional[Mapping[int, str]]:
        return None

    # Graph node names of token inputs (InputKind.TOKEN), in input order.
    @property
    def input_tokens(self) -> Collection[str]:
        input_tokens = []
        for s in self.input_specs:
            if s.kind == InputKind.TOKEN:
                assert isinstance(s.arg, TokenArgument)
                input_tokens.append(s.arg.name)
        return tuple(input_tokens)

    # Graph node names of token outputs (OutputKind.TOKEN), in output order.
    @property
    def output_tokens(self) -> Collection[str]:
        output_tokens = []
        for s in self.output_specs:
            if s.kind == OutputKind.TOKEN:
                assert isinstance(s.arg, TokenArgument)
                output_tokens.append(s.arg.name)
        return tuple(output_tokens)

    def __post_init__(self) -> None:
        # If an assertion dep token exists, it must be the single output that
        # follows all user outputs and buffer mutations.
        assertion_dep_token = self.assertion_dep_token
        if assertion_dep_token is None:
            return
        assert len(assertion_dep_token) == 1
        assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
        assert (
            len(self.user_outputs) + len(self.buffers_to_mutate)
            == assertion_dep_token_index
        )

    def replace_all_uses(self, old: str, new: str):
        """
        Replace all uses of the old name with new name in the signature.
        """
        assert isinstance(old, str)
        assert isinstance(new, str)
        # Only node-backed argument kinds carry a rewritable ``name``;
        # ConstantArgument values are left untouched.
        arg_types = (TensorArgument, SymIntArgument, CustomObjArgument, TokenArgument)
        for o in self.output_specs:
            if isinstance(o.arg, arg_types):
                if o.arg.name == old:
                    o.arg.name = new
        for i in self.input_specs:
            if isinstance(i.arg, arg_types):
                if i.arg.name == old:
                    i.arg.name = new

    def get_replace_hook(self):
        """Return a callback suitable for fx node replacement that keeps this
        signature's names in sync when an output/input node is replaced."""

        def _(old, new, user):
            # NOTE(review): fx placeholder nodes use op "placeholder" (see
            # _convert_to_export_graph_signature); the "input" entry here looks
            # like it never matches — confirm against the replace-hook callers.
            if user.op in ("output", "input"):
                self.replace_all_uses(old.name, new)

        return _
444 
445 
446 def _immutable_dict(items):
447     """
448     Creates a mapping where items cannot be added, deleted, or updated.
449     NOTE: The immutability is shallow (like tuple is an immutable collection).
450     """
451     from types import MappingProxyType
452 
453     return MappingProxyType(dict(items))
454 
455 
def _make_argument_spec(node, token_names) -> ArgumentSpec:
    """Build the ArgumentSpec describing ``node``.

    ``node`` is either a Python constant (returned directly as a nameless
    ConstantArgument) or an fx node carrying a ``meta["val"]`` entry; the spec
    kind is chosen from that value's type, with token names taking priority.
    """
    from torch import ScriptObject, SymInt
    from torch._library.fake_class_registry import FakeScriptObject
    from torch._subclasses.fake_tensor import FakeTensor

    if isinstance(node, (int, bool, float, type(None), str)):
        # For const outputs we just directly return this
        return ConstantArgument(name="", value=node)

    assert (
        "val" in node.meta
    ), f"{node} is not a constant or a node with a 'val' metadata field"
    meta_val = node.meta["val"]

    # Token nodes are identified by name, regardless of their meta value.
    if node.name in token_names:
        return TokenArgument(name=node.name)
    if isinstance(meta_val, FakeTensor):
        return TensorArgument(name=node.name)
    if isinstance(meta_val, SymInt):
        return SymIntArgument(name=node.name)
    if isinstance(meta_val, ScriptObject):
        return CustomObjArgument(
            name=node.name, class_fqn=meta_val._type().qualified_name()  # type: ignore[attr-defined]
        )
    if isinstance(meta_val, FakeScriptObject):
        return CustomObjArgument(
            name=node.name, class_fqn=meta_val.script_class_name, fake_val=meta_val
        )
    if isinstance(meta_val, (int, bool, str, float, type(None))):
        return ConstantArgument(name=node.name, value=meta_val)
    raise AssertionError(
        f"Encountered an unsupported object of type {type(meta_val)} "
        f"while writing the metadata for exported program"
    )
488 
489 
def _convert_to_export_graph_signature(
    graph_signature: "GraphSignature",
    gm: "torch.fx.GraphModule",
    non_persistent_buffers: Set[str],
) -> "ExportGraphSignature":
    """Convert an aot_autograd GraphSignature into an ExportGraphSignature.

    Walks ``gm``'s placeholders and the args of its output node, builds an
    ArgumentSpec for each, and classifies them by looking the node names up in
    the various mappings carried by ``graph_signature``.
    """
    from torch.utils import _pytree as pytree

    # Presence of a backward signature marks a joint (training) graph.
    is_joint = graph_signature.backward_signature is not None

    # unpack objects
    user_inputs = set(graph_signature.user_inputs)
    inputs_to_parameters = graph_signature.inputs_to_parameters
    inputs_to_buffers = graph_signature.inputs_to_buffers
    user_outputs = set(graph_signature.user_outputs)
    buffer_mutations = graph_signature.buffers_to_mutate
    user_input_mutations = graph_signature.user_inputs_to_mutate
    # NOTE(review): attribute is spelled ``gradients_to_parameter`` (singular)
    # here but ``gradients_to_parameters`` elsewhere in this module — confirm
    # against torch._functorch._aot_autograd.schemas.BackwardSignature.
    grad_params = graph_signature.backward_signature.gradients_to_parameter if is_joint else {}  # type: ignore[union-attr]
    grad_user_inputs = graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {}  # type: ignore[union-attr]
    loss_output = graph_signature.backward_signature.loss_output if is_joint else None  # type: ignore[union-attr]
    input_tokens = graph_signature.input_tokens
    output_tokens = graph_signature.output_tokens

    # One ArgumentSpec per placeholder, in graph order.
    inputs = [
        _make_argument_spec(node, input_tokens)
        for node in gm.graph.nodes
        if node.op == "placeholder"
    ]
    # The output node is the last node in the graph; flatten its args.
    outputs = [
        _make_argument_spec(node, output_tokens)
        for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
    ]

    # Classify one input spec by looking its node name up in the signature maps.
    def to_input_spec(inp: ArgumentSpec) -> InputSpec:
        if isinstance(inp, TokenArgument):
            return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None)

        # Non-tensor args (SymInt, constants, custom objects) can only be
        # user inputs — parameters/buffers are always tensors.
        if not isinstance(inp, TensorArgument):
            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
        name = inp.name
        if name in user_inputs:
            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
        elif name in inputs_to_parameters:
            return InputSpec(
                kind=InputKind.PARAMETER,
                arg=inp,
                target=inputs_to_parameters[name],  # type: ignore[index]
            )
        elif name in inputs_to_buffers:
            return InputSpec(
                kind=InputKind.BUFFER,
                arg=inp,
                target=inputs_to_buffers[name],  # type: ignore[index]
                persistent=(inputs_to_buffers[name] not in non_persistent_buffers),  # type: ignore[index]
            )
        else:
            raise AssertionError(f"Unknown tensor input kind: {name}")

    # Classify one output spec; ``idx`` matters because mutations and tokens
    # occupy the leading positions of the output tuple.
    def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec:
        if isinstance(o, TokenArgument):
            return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None)

        if not isinstance(o, TensorArgument):
            return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
        name = o.name
        # Leading outputs are mutations/tokens; the rest are user outputs,
        # gradients, and the loss.
        if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
            if name in buffer_mutations:
                return OutputSpec(
                    kind=OutputKind.BUFFER_MUTATION,
                    arg=o,
                    target=buffer_mutations[name],  # type: ignore[index]
                )
            elif name in user_input_mutations:
                return OutputSpec(
                    kind=OutputKind.USER_INPUT_MUTATION,
                    arg=o,
                    target=user_input_mutations[name],  # type: ignore[index]
                )
            else:
                raise AssertionError(f"Unknown tensor mutation kind: {name}")
        else:
            if name in user_outputs:
                return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)

            elif name in grad_params:
                return OutputSpec(
                    kind=OutputKind.GRADIENT_TO_PARAMETER,
                    arg=o,
                    target=grad_params[name],
                )
            elif name in grad_user_inputs:
                return OutputSpec(
                    kind=OutputKind.GRADIENT_TO_USER_INPUT,
                    arg=o,
                    target=grad_user_inputs[name],
                )
            elif name == loss_output:
                return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)

            else:
                raise AssertionError(f"Unknown tensor output kind: {name}")

    input_specs = [to_input_spec(inp) for inp in inputs]
    output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)]
    return ExportGraphSignature(input_specs=input_specs, output_specs=output_specs)
594