# -*- coding: utf-8 -*-
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Exporting to ExecuTorch Tutorial
================================

**Author:** `Angela Yi <https://github.com/angelayi>`__
"""

######################################################################
# ExecuTorch is a unified ML stack for lowering PyTorch models to edge devices.
# It introduces improved entry points to perform model, device, and/or use-case
# specific optimizations such as backend delegation, user-defined compiler
# transformations, default or user-defined memory planning, and more.
#
# At a high level, the workflow looks as follows:
#
# .. image:: ../executorch_stack.png
#   :width: 560
#
# In this tutorial, we will cover the APIs in the "Program preparation" steps to
# lower a PyTorch model to a format which can be loaded onto a device and run on
# the ExecuTorch runtime.

######################################################################
# Prerequisites
# -------------
#
# To run this tutorial, you’ll first need to
# `Set up your ExecuTorch environment <../getting-started-setup.html>`__.

######################################################################
# Exporting a Model
# -----------------
#
# Note: The Export APIs are still undergoing changes to align better with the
# longer-term state of export. Please refer to this
# `issue <https://github.com/pytorch/executorch/issues/290>`__ for more details.
#
# The first step of lowering to ExecuTorch is to export the given model (any
# callable or ``torch.nn.Module``) to a graph representation. This is done via
# ``torch.export``, which takes in a ``torch.nn.Module``, a tuple of
# positional arguments, optionally a dictionary of keyword arguments (not shown
# in the example), and a list of dynamic shapes (covered later).

import torch
from torch.export import export, ExportedProgram


class SimpleConv(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.conv(x)
        return self.relu(a)


example_args = (torch.randn(1, 3, 256, 256),)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
print(aten_dialect)

######################################################################
# The output of ``torch.export.export`` is a fully flattened graph (meaning the
# graph does not contain any module hierarchy, except in the case of control
# flow operators). Additionally, the graph is purely functional, meaning it does
# not contain operations with side effects such as mutations or aliasing.
#
# A more detailed specification of the result of ``torch.export`` can be found
# `here <https://pytorch.org/docs/main/export.html>`__.
#
# The graph returned by ``torch.export`` only contains functional ATen operators
# (~2000 ops), which we will call the ``ATen Dialect``.

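######################################################################
# As a quick sanity check (a small sketch, not part of the original flow), we
# can walk the nodes of the exported graph to see exactly which ATen operators
# it contains:

for node in aten_dialect.graph.nodes:
    if node.op == "call_function":
        print(node.target)
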
######################################################################
# Expressing Dynamism
# ^^^^^^^^^^^^^^^^^^^
#
# By default, the exporting flow will trace the program assuming that all input
# shapes are static, so if we run the program with input shapes that are
# different from the ones we used while tracing, we will run into an error:

import traceback as tb


class Basic(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


example_args = (torch.randn(3, 3), torch.randn(3, 3))
aten_dialect: ExportedProgram = export(Basic(), example_args)

# Works correctly
print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))

# Errors
try:
    print(aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2)))
except Exception:
    tb.print_exc()

######################################################################
# To express that some input shapes are dynamic, we can pass dynamic
# shapes to the exporting flow. This is done through the ``Dim`` API:

from torch.export import Dim


class Basic(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


example_args = (torch.randn(3, 3), torch.randn(3, 3))
dim1_x = Dim("dim1_x", min=1, max=10)
dynamic_shapes = {"x": {1: dim1_x}, "y": {1: dim1_x}}
aten_dialect: ExportedProgram = export(
    Basic(), example_args, dynamic_shapes=dynamic_shapes
)
print(aten_dialect)

######################################################################
# Note that the inputs ``arg0_1`` and ``arg1_1`` now have shapes (3, s0),
# with ``s0`` being a symbol representing that this dimension can take a range
# of values.
#
# Additionally, we can see in the **Range constraints** that the value of
# ``s0`` has the range [1, 10], which we specified with our dynamic shapes.
#
# Now let's try running the model with different shapes:

# Works correctly
print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))
print(aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2)))

# Errors because it violates our constraint that input 0, dim 1 <= 10
try:
    print(aten_dialect.module()(torch.ones(3, 15), torch.ones(3, 15)))
except Exception:
    tb.print_exc()

# Errors because it violates our constraint that input 0, dim 1 == input 1, dim 1
try:
    print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 2)))
except Exception:
    tb.print_exc()


######################################################################
# Addressing Untraceable Code
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# As our goal is to capture the entire computational graph from a PyTorch
# program, we might ultimately run into untraceable parts of our programs. To
# address these issues, the
# `torch.export documentation <https://pytorch.org/docs/main/export.html#limitations-of-torch-export>`__
# and the
# `torch.export tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__
# are the best places to look.

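######################################################################
# As a hedged illustration (a sketch, not part of the original tutorial, using
# the ``torch.cond`` higher-order operator available in recent PyTorch
# releases), data-dependent control flow that a plain Python ``if`` would make
# untraceable can often be expressed with ``torch.cond``. The ``MaybeSinOrCos``
# module below is a hypothetical example:


class MaybeSinOrCos(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def true_fn(x):
            return torch.sin(x)

        def false_fn(x):
            return torch.cos(x)

        # The predicate depends on the data, so both branches are captured.
        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])


print(export(MaybeSinOrCos(), (torch.randn(3, 3),)))
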
######################################################################
# Performing Quantization
# -----------------------
#
# To quantize a model, we first need to capture the graph with
# ``torch.export.export_for_training``, perform quantization, and then
# call ``torch.export``. ``torch.export.export_for_training`` returns a
# graph containing ATen operators which are Autograd-safe, meaning they are
# safe for eager-mode training, which is needed for quantization. We will call
# the graph at this level the ``Pre-Autograd ATen Dialect`` graph.
#
# Compared to
# `FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__,
# we will need to call two new APIs: ``prepare_pt2e`` and ``convert_pt2e``
# instead of ``prepare_fx`` and ``convert_fx``. The main difference is that
# ``prepare_pt2e`` takes a backend-specific ``Quantizer`` as an argument, which
# will annotate the nodes in the graph with the information needed to quantize
# the model properly for a specific backend.

from torch.export import export_for_training

example_args = (torch.randn(1, 3, 256, 256),)
pre_autograd_aten_dialect = export_for_training(SimpleConv(), example_args).module()
print("Pre-Autograd ATen Dialect Graph")
print(pre_autograd_aten_dialect)

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
# Calibrate with a sample dataset; here we simply reuse the example inputs as a
# stand-in for representative calibration data
prepared_graph(*example_args)
converted_graph = convert_pt2e(prepared_graph)
print("Quantized Graph")
print(converted_graph)

aten_dialect: ExportedProgram = export(converted_graph, example_args)
print("ATen Dialect Graph")
print(aten_dialect)

######################################################################
# More information on how to quantize a model, and how a backend can implement
# a ``Quantizer``, can be found
# `here <https://pytorch.org/docs/main/quantization.html#prototype-pytorch-2-export-quantization>`__.

######################################################################
# Lowering to Edge Dialect
# ------------------------
#
# After exporting and lowering the graph to the ``ATen Dialect``, the next step
# is to lower to the ``Edge Dialect``, in which specializations that are useful
# for edge devices but not necessary for general (server) environments will be
# applied.
# Some of these specializations include:
#
# - DType specialization
# - Scalar to tensor conversion
# - Converting all ops to the ``executorch.exir.dialects.edge`` namespace.
#
# Note that this dialect is still backend (or target) agnostic.
#
# The lowering is done through the ``to_edge`` API.

from executorch.exir import EdgeProgramManager, to_edge

example_args = (torch.randn(1, 3, 256, 256),)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args)

edge_program: EdgeProgramManager = to_edge(aten_dialect)
print("Edge Dialect Graph")
print(edge_program.exported_program())

######################################################################
# ``to_edge()`` returns an ``EdgeProgramManager`` object, which contains the
# exported programs that will be placed on the device. This data structure
# allows users to export multiple programs and combine them into one binary. If
# there is only one program, it will be saved under the name "forward" by
# default.


class Encode(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.linear(x, torch.randn(5, 10))


class Decode(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.linear(x, torch.randn(10, 5))


encode_args = (torch.randn(1, 10),)
aten_encode: ExportedProgram = export(Encode(), encode_args)

decode_args = (torch.randn(1, 5),)
aten_decode: ExportedProgram = export(Decode(), decode_args)

edge_program: EdgeProgramManager = to_edge(
    {"encode": aten_encode, "decode": aten_decode}
)
for method in edge_program.methods:
    print(f"Edge Dialect graph of {method}")
    print(edge_program.exported_program(method))

######################################################################
# We can also run additional passes on the exported program through
# the ``transform`` API. In-depth documentation on how to write
# transformations can be found
# `here <../compiler-custom-compiler-passes.html>`__.
#
# Note that since the graph is now in the Edge Dialect, all passes must also
# result in a valid Edge Dialect graph (in particular, the operators are now in
# the ``executorch.exir.dialects.edge`` namespace, rather than the
# ``torch.ops.aten`` namespace).

example_args = (torch.randn(1, 3, 256, 256),)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
print("Edge Dialect Graph")
print(edge_program.exported_program())

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ConvertReluToSigmoid(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        if op == exir_ops.edge.aten.relu.default:
            return super().call_operator(
                exir_ops.edge.aten.sigmoid.default, args, kwargs, meta
            )
        else:
            return super().call_operator(op, args, kwargs, meta)


transformed_edge_program = edge_program.transform((ConvertReluToSigmoid(),))
print("Transformed Edge Dialect Graph")
print(transformed_edge_program.exported_program())

######################################################################
# Note: if you see an error like ``torch._export.verifier.SpecViolationError:
# Operator torch._ops.aten._native_batch_norm_legit_functional.default is not
# Aten Canonical``,
# please file an issue at https://github.com/pytorch/executorch/issues and we're happy to help!


######################################################################
# Delegating to a Backend
# -----------------------
#
# We can now delegate parts of the graph or the whole graph to a third-party
# backend through the ``to_backend`` API. In-depth documentation on the
# specifics of backend delegation, including how to delegate to a backend and
# how to implement a backend, can be found
# `here <../compiler-delegate-and-partitioner.html>`__.
#
# There are three ways to use this API:
#
# 1. We can lower the whole module.
# 2. We can take the lowered module, and insert it into another, larger module.
# 3. We can partition the module into subgraphs that are lowerable, and then
#    lower those subgraphs to a backend.

######################################################################
# Lowering the Whole Module
# ^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To lower an entire module, we can pass ``to_backend`` the backend name, the
# module to be lowered, and a list of compile specs to help the backend with the
# lowering process.


class LowerableModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sin(x)


# Export and lower the module to Edge Dialect
example_args = (torch.ones(1),)
aten_dialect: ExportedProgram = export(LowerableModule(), example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
to_be_lowered_module = edge_program.exported_program()

from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend

# Import the backend
from executorch.exir.backend.test.backend_with_compiler_demo import (  # noqa
    BackendWithCompilerDemo,
)

# Lower the module
lowered_module: LoweredBackendModule = to_backend(
    "BackendWithCompilerDemo", to_be_lowered_module, []
)
print(lowered_module)
print(lowered_module.backend_id)
print(lowered_module.processed_bytes)
print(lowered_module.original_module)

# Serialize and save it to a file
save_path = "delegate.pte"
with open(save_path, "wb") as f:
    f.write(lowered_module.buffer())

######################################################################
# In this call, ``to_backend`` will return a ``LoweredBackendModule``. Some
# important attributes of the ``LoweredBackendModule`` are:
#
# - ``backend_id``: the name of the backend this lowered module will run on in
#   the runtime
# - ``processed_bytes``: a binary blob which will tell the backend how to run
#   this program in the runtime
# - ``original_module``: the original exported module

######################################################################
# Compose the Lowered Module into Another Module
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# In cases where we want to reuse this lowered module in multiple programs, we
# can compose this lowered module with another module.


class NotLowerableModule(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.bias = bias

    def forward(self, a, b):
        return torch.add(torch.add(a, b), self.bias)


class ComposedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.non_lowerable = NotLowerableModule(torch.ones(1) * 0.3)
        self.lowerable = lowered_module

    def forward(self, x):
        a = self.lowerable(x)
        b = self.lowerable(a)
        ret = self.non_lowerable(a, b)
        return a, b, ret


example_args = (torch.ones(1),)
aten_dialect: ExportedProgram = export(ComposedModule(), example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
print("Edge Dialect graph")
print(exported_program)
print("Lowered Module within the graph")
print(exported_program.graph_module.lowered_module_0.backend_id)
print(exported_program.graph_module.lowered_module_0.processed_bytes)
print(exported_program.graph_module.lowered_module_0.original_module)

######################################################################
# Notice that there is now a ``torch.ops.higher_order.executorch_call_delegate``
# node in the graph, which is calling ``lowered_module_0``. Additionally, the
# contents of ``lowered_module_0`` are the same as the ``lowered_module`` we
# created previously.

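######################################################################
# As a quick check (a small sketch, not part of the original tutorial), we can
# count the delegate call nodes in the graph; comparing targets by name keeps
# the check independent of how the higher-order op is registered:

delegate_calls = [
    node
    for node in exported_program.graph.nodes
    if node.op == "call_function" and "executorch_call_delegate" in str(node.target)
]
print(f"Number of delegate calls: {len(delegate_calls)}")
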
######################################################################
# Partition and Lower Parts of a Module
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# A separate lowering flow is to pass ``to_backend`` the module that we want to
# lower, and a backend-specific partitioner. ``to_backend`` will use the
# backend-specific partitioner to tag nodes in the module which are lowerable,
# partition those nodes into subgraphs, and then create a
# ``LoweredBackendModule`` for each of those subgraphs.


class Foo(torch.nn.Module):
    def forward(self, a, x, b):
        y = torch.mm(a, x)
        z = y + b
        a = z - a
        y = torch.mm(a, x)
        z = y + b
        return z


example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
aten_dialect: ExportedProgram = export(Foo(), example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
print("Edge Dialect graph")
print(exported_program)

from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo

delegated_program = to_backend(exported_program, AddMulPartitionerDemo())
print("Delegated program")
print(delegated_program)
print(delegated_program.graph_module.lowered_module_0.original_module)
print(delegated_program.graph_module.lowered_module_1.original_module)

######################################################################
# Notice that there are now two ``torch.ops.higher_order.executorch_call_delegate``
# nodes in the graph, one containing the operations ``add, mul`` and the other
# containing the operations ``mul, add``.
#
# Alternatively, a more cohesive API to lower parts of a module is to call
# ``to_backend`` directly on the ``EdgeProgramManager``:


class Foo(torch.nn.Module):
    def forward(self, a, x, b):
        y = torch.mm(a, x)
        z = y + b
        a = z - a
        y = torch.mm(a, x)
        z = y + b
        return z


example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
aten_dialect: ExportedProgram = export(Foo(), example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
delegated_program = edge_program.to_backend(AddMulPartitionerDemo())

print("Delegated program")
print(delegated_program.exported_program())

######################################################################
# Running User-Defined Passes and Memory Planning
# -----------------------------------------------
#
# As a final step of lowering, we can use the ``to_executorch()`` API to pass in
# backend-specific passes, such as replacing sets of operators with a custom
# backend operator, and a memory planning pass, which tells the runtime how to
# allocate memory ahead of time when running the program.
#
# A default memory planning pass is provided, but we can also choose a
# backend-specific memory planning pass if it exists. More information on
# writing a custom memory planning pass can be found
# `here <../compiler-memory-planning.html>`__.

from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager
from executorch.exir.passes import MemoryPlanningPass

executorch_program: ExecutorchProgramManager = edge_program.to_executorch(
    ExecutorchBackendConfig(
        passes=[],  # User-defined passes
        memory_planning_pass=MemoryPlanningPass(),  # Default memory planning pass
    )
)

print("ExecuTorch Dialect")
print(executorch_program.exported_program())

import executorch.exir as exir

######################################################################
# Notice that in the graph we now see operators like ``torch.ops.aten.sub.out``
# and ``torch.ops.aten.div.out`` rather than ``torch.ops.aten.sub.Tensor`` and
# ``torch.ops.aten.div.Tensor``.
#
# This is because, between running the backend passes and the memory planning
# passes, an out-variant pass is run on the graph to prepare it for memory
# planning by converting all of the operators to their out variants. Instead of
# allocating returned tensors in the kernel implementations, an operator's
# ``out`` variant will take in a preallocated tensor through its ``out`` kwarg
# and store the result there, making it easier for memory planners to do tensor
# lifetime analysis.
#
# We also insert ``alloc`` nodes into the graph, which call a special
# ``executorch.exir.memory.alloc`` operator. These tell us how much memory is
# needed to allocate for each tensor output by an out-variant operator.
#

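######################################################################
# To make this concrete (a small sketch, not part of the original tutorial), we
# can list the operator targets in the ExecuTorch dialect graph; the out
# variants and the ``executorch.exir.memory.alloc`` calls should show up here:

for node in executorch_program.exported_program().graph.nodes:
    if node.op == "call_function":
        print(node.target)
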
######################################################################
# Saving to a File
# ----------------
#
# Finally, we can save the ExecuTorch Program to a file and load it onto a
# device to be run.
#
# Here is an example of an entire end-to-end workflow:

import torch
from torch.export import export, export_for_training, ExportedProgram


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


example_args = (torch.randn(3, 4),)
pre_autograd_aten_dialect = export_for_training(M(), example_args).module()
# Optionally do quantization:
# pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)
# Optionally do delegation:
# edge_program = edge_program.to_backend(CustomBackendPartitioner)
executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
    ExecutorchBackendConfig(
        passes=[],  # User-defined passes
    )
)

with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)

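######################################################################
# As a hedged sketch of the loading side (the exact pybinding module path and
# API below are assumptions to verify against your install, and the bindings
# may not be built in every environment), the saved ``model.pte`` can be
# smoke-tested from Python before deploying it to a device:

try:
    from executorch.extension.pybindings.portable_lib import _load_for_executorch

    loaded_program = _load_for_executorch("model.pte")
    print(loaded_program.forward((torch.randn(3, 4),)))
except Exception:  # pybindings may be unavailable or the API may differ
    print("ExecuTorch Python runtime bindings are not available.")
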
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we went over the APIs and steps required to lower a PyTorch
# program to a file that can be run on the ExecuTorch runtime.
#
# Links Mentioned
# ^^^^^^^^^^^^^^^
#
# - `torch.export Documentation <https://pytorch.org/docs/2.1/export.html>`__
# - `Quantization Documentation <https://pytorch.org/docs/main/quantization.html#prototype-pytorch-2-export-quantization>`__
# - `IR Spec <../ir-exir.html>`__
# - `Writing Compiler Passes + Partitioner Documentation <../compiler-custom-compiler-passes.html>`__
# - `Backend Delegation Documentation <../compiler-delegate-and-partitioner.html>`__
# - `Memory Planning Documentation <../compiler-memory-planning.html>`__
