# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

import logging
import random
import sys
from abc import ABC, abstractmethod
from collections import Counter, OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
from executorch.backends.xnnpack._passes import XNNPACKPassManager
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchBackendConfig,
    ExecutorchProgramManager,
    to_edge,
    to_edge_transform_and_lower,
)
from executorch.exir.backend.backend_api import validation_disabled
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.exir.print_program import pretty_print, print_program
from torch.export import export_for_training

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
try:
    from executorch.extension.pybindings.portable_lib import (  # @manual
        _load_for_executorch_from_buffer,
    )
except ImportError as e:
    logger.warning(f"{e=}")

from executorch.exir.program._program import _transform
from torch._export.pass_base import PassType
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
from torch.export import export, ExportedProgram
from torch.testing import FileCheck
from torch.utils._pytree import tree_flatten


class Stage(ABC):
    """
    Interface for a Stage in the PT2.0 lowering pipeline.
    """

    @abstractmethod
    def run(self, artifact, inputs):
        """
        Executes this stage and generates the artifact consumed by later stages.
        """
        pass

    @property
    @abstractmethod
    def artifact(self):
        """
        Returns the artifact generated by this stage, to be used by the next stage in the pipeline.
        """
        pass

    @property
    @abstractmethod
    def graph_module(self):
        """
        Returns the artifact's graph module for this stage.
        """
        pass

    def run_artifact(self, inputs):
        """
        Returns the output of calling the artifact generated by this stage with the given inputs.
        """
        if isinstance(self.artifact, ExportedProgram):
            return self.artifact(*inputs)
        else:
            return self.artifact.exported_program().module()(*inputs)

    # Debug tools for stages
    def artifact_str(self):
        """
        Returns a printable representation of this stage's artifact.
        """
        if isinstance(self.artifact, EdgeProgramManager):
            return self.artifact.exported_program()
        return self.artifact

    def stage_banner(self):
        """
        Returns a banner string for this stage.
        """
        return "#" * 36 + " " + str(self.__class__.__name__) + " " + "#" * 36 + "\n"

    def dump_artifact(self, path_to_dump: Optional[str]):
        """
        Dumps the printable artifact to path_to_dump. If path_to_dump is not
        provided, the artifact is printed to stdout instead.
        """
        if path_to_dump:
            with open(path_to_dump, "a") as fp:
                fp.write(self.stage_banner() + "\n")
                fp.write(str(self.artifact_str()))
        else:
            print(self.stage_banner() + "\n")
            print(self.artifact_str())


_stages_: Dict[str, Stage] = {}


def register_stage(stage: Type[Stage]):
    """
    Register a Stage class to be used in the Tester.
    """
    assert isinstance(stage, type)
    name = stage.__qualname__
    if name in _stages_:
        raise RuntimeError(f"Duplicate stage in Tester, {name}")
    _stages_[name] = stage
    return stage


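# Stages are registered by decoration. A hypothetical custom stage would look
# like the sketch below (`MyStage` is illustrative and would also need to be
# wired into Tester.pipeline before it could be run):
#
#   @register_stage
#   class MyStage(Stage):
#       ...  # implement run(), artifact, and graph_module

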
@register_stage
class Quantize(Stage):
    def __init__(
        self,
        quantizer: Optional[Quantizer] = None,
        quantization_config: Optional[QuantizationConfig] = None,
        calibrate: bool = True,
    ):
        self.quantizer = quantizer or XNNPACKQuantizer()
        self.quantization_config = (
            quantization_config or get_symmetric_quantization_config()
        )
        self.calibrate = calibrate

        self.quantizer.set_global(self.quantization_config)

        self.converted_graph = None

    def run(
        self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
    ) -> None:
        assert inputs is not None
        captured_graph = export_for_training(artifact, inputs).module()

        assert isinstance(captured_graph, torch.fx.GraphModule)
        prepared = prepare_pt2e(captured_graph, self.quantizer)

        if self.calibrate:
            # Calibrate the prepared model to provide data to the quantization observers.
            prepared(*inputs)

        converted = convert_pt2e(prepared)
        self.converted_graph = converted

    @property
    def artifact(self) -> torch.fx.GraphModule:
        return self.converted_graph

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.converted_graph

    def run_artifact(self, inputs):
        return self.converted_graph.forward(*inputs)


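# A minimal sketch of driving the Quantize stage by itself (assumes a
# user-provided eager `model` and matching `example_inputs`; both names are
# hypothetical):
#
#   quantize = Quantize()
#   quantize.run(model, example_inputs)   # prepare_pt2e -> calibrate -> convert_pt2e
#   gm = quantize.artifact                # fx.GraphModule with quant/dequant nodes
#   outputs = quantize.run_artifact(example_inputs)

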
@register_stage
class Export(Stage):
    def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None):
        self.exported_program = None
        self.dynamic_shapes = dynamic_shapes

    def run(
        self,
        artifact: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
    ) -> None:
        self.exported_program = export(
            artifact, inputs, dynamic_shapes=self.dynamic_shapes
        )

    @property
    def artifact(self) -> ExportedProgram:
        return self.exported_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.exported_program.graph_module


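# The dynamic_shapes tuple is forwarded to torch.export.export. A sketch
# (`batch` and `model` are illustrative names):
#
#   batch = torch.export.Dim("batch", min=1, max=16)
#   Export(dynamic_shapes=({0: batch},)).run(model, example_inputs)

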
@register_stage
class ToEdge(Stage):
    def __init__(self, edge_compile_config: Optional[EdgeCompileConfig] = None):
        self.edge_compile_conf = (
            edge_compile_config or get_xnnpack_edge_compile_config()
        )
        self.edge_dialect_program = None

    def run(self, artifact: ExportedProgram, inputs=None) -> None:
        self.edge_dialect_program = to_edge(
            artifact, compile_config=self.edge_compile_conf
        )

    @property
    def artifact(self) -> EdgeProgramManager:
        return self.edge_dialect_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.edge_dialect_program.exported_program().graph_module


@register_stage
class RunPasses(Stage):
    def __init__(
        self,
        pass_list: Optional[List[Type[PassType]]] = None,
        pass_functions: Optional[List[Callable]] = None,
    ):
        self.pass_list = pass_list
        self.pass_functions = pass_functions
        self.edge_or_aten_program = None

    def run(
        self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None
    ) -> None:
        if isinstance(artifact, EdgeProgramManager):
            self.edge_or_aten_program = artifact
            if self.pass_list:
                pass_manager = XNNPACKPassManager(
                    artifact.exported_program(), self.pass_list
                )
                self.edge_or_aten_program._edge_programs["forward"] = (
                    pass_manager.transform()
                )
            if self.pass_functions:
                assert isinstance(self.pass_functions, list)
                for pass_function in self.pass_functions:
                    self.edge_or_aten_program._edge_programs["forward"] = pass_function(
                        self.edge_or_aten_program.exported_program()
                    )
        else:
            transformed_ep = artifact
            if self.pass_list:
                assert isinstance(self.pass_list, list)
                for pass_ in self.pass_list:
                    transformed_ep = _transform(transformed_ep, pass_())

            if self.pass_functions:
                assert isinstance(self.pass_functions, list)
                for pass_function in self.pass_functions:
                    transformed_ep = pass_function(transformed_ep)

            self.edge_or_aten_program = transformed_ep

    @property
    def artifact(self) -> Union[EdgeProgramManager, ExportedProgram]:
        return self.edge_or_aten_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        if isinstance(self.edge_or_aten_program, EdgeProgramManager):
            return self.edge_or_aten_program.exported_program().graph_module
        else:
            return self.edge_or_aten_program.graph_module


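# RunPasses accepts pass classes (run through the XNNPACKPassManager or
# _transform, depending on the artifact type) and/or free functions applied to
# the ExportedProgram. A sketch (`MyGraphPass` and `my_fn` are illustrative
# placeholders, not real XNNPACK passes):
#
#   RunPasses(pass_list=[MyGraphPass])
#   RunPasses(pass_functions=[my_fn])   # my_fn(ep: ExportedProgram) -> ExportedProgram

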
@register_stage
class ToEdgeTransformAndLower(Stage):
    def __init__(
        self,
        partitioners: Optional[List[Partitioner]] = None,
        edge_compile_config: Optional[EdgeCompileConfig] = None,
    ):
        self.partitioners = partitioners or [XnnpackPartitioner()]
        self.edge_compile_conf = (
            edge_compile_config or get_xnnpack_edge_compile_config()
        )
        self.edge_dialect_program = None

    def run(self, artifact: ExportedProgram, inputs=None) -> None:
        artifact_to_run = copy.deepcopy(artifact)
        self.edge_dialect_program = to_edge_transform_and_lower(
            artifact_to_run,
            compile_config=self.edge_compile_conf,
            partitioner=self.partitioners,
        )

    @property
    def artifact(self) -> EdgeProgramManager:
        return self.edge_dialect_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.edge_dialect_program.exported_program().graph_module


@register_stage
class Partition(Stage):
    def __init__(self, partitioner: Optional[Partitioner] = None):
        self.partitioner = partitioner or XnnpackPartitioner()
        self.delegate_module = None

    def run(self, artifact: EdgeProgramManager, inputs=None):
        with validation_disabled():
            self.delegate_module = artifact.to_backend(self.partitioner)

    @property
    def artifact(self) -> EdgeProgramManager:
        return self.delegate_module

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.delegate_module.exported_program().graph_module


@register_stage
class ToExecutorch(Stage):
    def __init__(
        self,
        config: Optional[ExecutorchBackendConfig] = None,
    ):
        self.config = config or ExecutorchBackendConfig(
            extract_delegate_segments=True,
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
        )
        self.executorch_program = None

    def run(self, artifact: EdgeProgramManager, inputs=None):
        self.executorch_program = artifact.to_executorch(self.config)

    @property
    def artifact(self) -> ExecutorchProgramManager:
        return self.executorch_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.executorch_program.exported_program().graph_module

    def dump_artifact(self, path_to_dump: Optional[str]):
        """
        dump_artifact is overridden to dump the serialized program.
        """
        original_stdout = sys.stdout
        try:
            sys.stdout = open(path_to_dump, "a") if path_to_dump else sys.stdout
            print(self.stage_banner() + "\n")
            pretty_print(self.artifact._emitter_output.program)
            print_program(
                self.artifact._emitter_output.program,
                show_meminfo=True,
                mark_dynamic_shape_tensor=True,
            )
        finally:
            # Close the dump file (if one was opened) and restore stdout.
            if sys.stdout is not original_stdout:
                sys.stdout.close()
            sys.stdout = original_stdout


@register_stage
class Serialize(Stage):
    def __init__(self):
        self.buffer = None

    def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None:
        self.buffer = artifact.buffer

    @property
    def artifact(self) -> bytes:
        return self.buffer

    @property
    def graph_module(self) -> None:
        return None

    def run_artifact(self, inputs):
        inputs_flattened, _ = tree_flatten(inputs)
        executorch_module = _load_for_executorch_from_buffer(self.buffer)
        executorch_output = copy.deepcopy(
            executorch_module.run_method("forward", tuple(inputs_flattened))
        )
        return executorch_output

    def dump_artifact(self, path_to_dump: Optional[str]):
        """
        dump_artifact is overridden to dump the serialized bytes into a .pte file.
        """
        if not path_to_dump:
            raise RuntimeError("path_to_dump file not provided")
        with open(path_to_dump, "wb") as f:
            f.write(self.artifact)


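# After Serialize, run_artifact executes the serialized flatbuffer through the
# portable pybindings imported above. A sketch (`program_manager` and
# `example_inputs` are illustrative):
#
#   serialize = Serialize()
#   serialize.run(program_manager)                     # grab the serialized buffer
#   outputs = serialize.run_artifact(example_inputs)   # runs "forward" in the runtime

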
class Tester:
    def __init__(
        self,
        module: torch.nn.Module,
        example_inputs: Tuple[torch.Tensor],
        dynamic_shapes: Optional[Tuple[Any]] = None,
    ):
        module.eval()

        self.original_module = module
        self.example_inputs = example_inputs
        self.dynamic_shapes = dynamic_shapes
        self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys()))
        self.pipeline = {
            self.stage_name(Quantize): [self.stage_name(Export)],
            self.stage_name(Export): [
                self.stage_name(RunPasses),
                self.stage_name(ToEdge),
                self.stage_name(ToEdgeTransformAndLower),
            ],
            self.stage_name(ToEdgeTransformAndLower): [
                self.stage_name(RunPasses),
                self.stage_name(ToExecutorch),
            ],
            self.stage_name(ToEdge): [
                self.stage_name(Partition),
                self.stage_name(RunPasses),
            ],
            self.stage_name(RunPasses): [
                self.stage_name(Partition),
                self.stage_name(ToEdgeTransformAndLower),
            ],
            # TODO Make this Stage optional
            self.stage_name(Partition): [self.stage_name(ToExecutorch)],
            self.stage_name(ToExecutorch): [self.stage_name(Serialize)],
            self.stage_name(Serialize): [],
        }
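        # The pipeline above is a DAG of legal stage transitions. Two common
        # paths through it are:
        #   Quantize -> Export -> ToEdge -> Partition -> ToExecutorch -> Serialize
        #   Export -> ToEdgeTransformAndLower -> ToExecutorch -> Serialize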
        assert all(
            stage in self.pipeline for stage in self.stages
        ), "Invalid Tester internal state!"

        # Current stage name
        self.cur: str = ""

        # Reference output from eager mode
        self.reference_output = None

        # Quantization scale from eager mode
        self.quantization_scale: Optional[float] = None

        # Artifact output from stage
        self.stage_output = None

    def generate_random_inputs(self):
        # Get shapes of inputs
        input_shapes = []
        if self.dynamic_shapes is None:
            for tensor_arg in self.example_inputs:
                assert isinstance(tensor_arg, torch.Tensor)
                input_shapes.append(tensor_arg.shape)
        else:
            # Random shapes depending on dynamic shape constraints
            dim_name_to_size = {}
            for arg_idx in range(len(self.example_inputs)):
                assert isinstance(self.example_inputs[arg_idx], torch.Tensor)
                ex_shape = list(self.example_inputs[arg_idx].shape)
                dynamic_dim_spec = self.dynamic_shapes[arg_idx]
                for dim_idx, dim_spec in dynamic_dim_spec.items():
                    assert dim_idx < len(ex_shape)
                    if isinstance(dim_spec, torch.export.dynamic_shapes._DerivedDim):
                        # Derived dims are of the form {0: 2 * torch.export.Dim() // 2}.
                        # The root contains the min/max of the export dim, and fn is the
                        # function that computes the derived dim from the root. Capture
                        # fn before replacing dim_spec with its root.
                        fn = dim_spec.fn
                        dim_spec = dim_spec.root
                    elif isinstance(dim_spec, torch.export.dynamic_shapes._Dim):
                        # Not a derived dim, so fn is the identity
                        def fn(x):
                            return x

                    else:
                        raise RuntimeError(
                            f"Expected Dynamic Dims to be of type _DerivedDim or _Dim but got {type(dim_spec)}"
                        )
                    dim_name = dim_spec.__name__
                    if dim_name not in dim_name_to_size:
                        upper_bound = min(
                            dim_spec.max, 1000
                        )  # unbounded int max is too large
                        lower_bound = (
                            dim_spec.min if dim_spec.min >= 2 else 1
                        )  # 0/1 specialization means dim_spec.min can never be 1
                        dim_name_to_size[dim_name] = fn(
                            random.randint(lower_bound, upper_bound)
                        )
                    ex_shape[dim_idx] = dim_name_to_size[dim_name]
                input_shapes.append(torch.Size(ex_shape))
        # Create random tensor inputs with the shapes given above
        random_inputs = []
        for arg_idx in range(len(self.example_inputs)):
            random_inputs.append(
                torch.randn(input_shapes[arg_idx]).to(
                    dtype=self.example_inputs[arg_idx].dtype
                )
            )

        yield tuple(random_inputs)

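    # A sketch of the dynamic_shapes format consumed above (illustrative only):
    #
    #   batch = torch.export.Dim("batch", min=2, max=8)
    #   dynamic_shapes = ({0: batch}, {0: 2 * batch})  # second input's dim 0 is derived
    #
    # For the derived dim, fn (here fn(x) = 2 * x) is applied to the randomly
    # drawn root size.
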
    @staticmethod
    def stage_name(stage) -> str:
        t = stage if isinstance(stage, type) else type(stage)
        return t.__qualname__

    def _pre(self, stage):
        name: str = self.stage_name(stage)
        assert isinstance(name, str) and name in self.stages and not self.stages[name]

        last_artifact = self.original_module
        if self.cur:
            assert self.cur in self.pipeline, f"Invalid state: {self.cur}"
            allowed_next_stages = self.pipeline[self.cur]
            assert name in allowed_next_stages, f"Invalid next stage: {name}"
            last_artifact = self.get_artifact()
        self.cur = name
        return last_artifact

    def _post(self, stage):
        name = self.stage_name(stage)
        assert name in self.stages
        self.stages[name] = stage

    def _run_stage(self, stage_instance, inputs=None):
        assert isinstance(stage_instance, Stage)
        prev_stage_artifact = self._pre(stage_instance)
        stage_instance.run(prev_stage_artifact, inputs=inputs)
        self._post(stage_instance)
        return self

    # Stages
    def quantize(self, quantize_stage: Optional[Quantize] = None):
        return self._run_stage(quantize_stage or Quantize(), self.example_inputs)

    def export(self, export_stage: Optional[Export] = None):
        return self._run_stage(
            export_stage or Export(dynamic_shapes=self.dynamic_shapes),
            self.example_inputs,
        )

    def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
        # TODO(T182187531): Skip dim order for now. Support dim order and its op after alpha release.
        if not to_edge_stage:
            to_edge_stage = ToEdge()
        to_edge_stage.edge_compile_conf._skip_dim_order = True
        res = self._run_stage(to_edge_stage)
        return res

    def to_edge_transform_and_lower(
        self, to_edge_and_transform_stage: Optional[ToEdgeTransformAndLower] = None
    ):
        return self._run_stage(to_edge_and_transform_stage or ToEdgeTransformAndLower())

    def run_passes(self, run_passes_stage: Optional[RunPasses] = None):
        return self._run_stage(run_passes_stage or RunPasses())

    def partition(self, partition_stage: Optional[Partition] = None):
        return self._run_stage(partition_stage or Partition())

    def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] = None):
        return self._run_stage(to_executorch_stage or ToExecutorch())

    def serialize(self, serialize_stage: Optional[Serialize] = None):
        return self._run_stage(serialize_stage or Serialize())

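    # A sketch of the typical fluent flow end to end (`MyModule` is a
    # hypothetical eager module; Tester.__init__ calls eval() on it):
    #
    #   (
    #       Tester(MyModule(), (torch.randn(1, 3, 8, 8),))
    #       .quantize()
    #       .export()
    #       .to_edge_transform_and_lower()
    #       .to_executorch()
    #       .serialize()
    #       .run_method_and_compare_outputs(qtol=1)
    #   )
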
    # Util functions
    def dump_artifact(self, path: Optional[str] = None, stage: Optional[str] = None):
        stage = stage or self.cur
        self.stages[stage].dump_artifact(path)
        return self

    def get_artifact(self, stage: Optional[str] = None):
        stage = stage or self.cur
        return self.stages[stage].artifact

    def check(self, input: List[str]):
        for key in input:
            FileCheck().check(key).run(self.stages[self.cur].graph_module.code)
        return self

    def check_not(self, input: List[str]):
        for key in input:
            FileCheck().check_not(key).run(self.stages[self.cur].graph_module.code)
        return self

    def check_count(self, input: Dict[Any, int]):
        # TODO target checks similar to checkGraphModuleNodes()
        for key, count in input.items():
            FileCheck().check_count(key, count, exactly=True).run(
                self.stages[self.cur].graph_module.code
            )
        return self

    def check_node_count(self, input: Dict[Any, int]):
        # Count the occurrences of each target in the graph.
        target_ops = [
            node.target
            for node in self.stages[self.cur].graph_module.graph.nodes
            if node.op == "call_function"
        ]
        op_counts = Counter(target_ops)

        for key, count in input.items():
            if count != op_counts[key]:
                print(f"Nodes: {op_counts}")
                raise AssertionError(
                    f"Expected {count} {key} nodes but found {op_counts[key]}."
                )

        return self

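    # The check helpers above run torch.testing.FileCheck over the current
    # stage's graph code, while check_node_count counts call_function targets
    # directly. A sketch (the patterns shown are illustrative):
    #
    #   tester.check(["torch.ops.higher_order.executorch_call_delegate"])
    #   tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
    #   tester.check_node_count({torch.ops.aten.add.Tensor: 0})
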
    def run_method_and_compare_outputs(
        self,
        stage: Optional[str] = None,
        inputs: Optional[Tuple[torch.Tensor]] = None,
        num_runs=1,
        atol=1e-03,
        rtol=1e-03,
        qtol=0,
    ):
        number_of_runs = 1 if inputs is not None else num_runs
        reference_stage = self.stages[self.stage_name(Export)]

        stage = stage or self.cur

        print(f"Comparing Stage {stage} with Stage {self.stage_name(reference_stage)}")
        for run_iteration in range(number_of_runs):
            inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
            input_shapes = [generated_input.shape for generated_input in inputs_to_run]
            print(f"Run {run_iteration} with input shapes: {input_shapes}")

            # Reference output (and quantization scale)
            (
                reference_output,
                quantization_scale,
            ) = self._calculate_reference_output(
                reference_stage.artifact, inputs_to_run
            )

            # Output from running the artifact at the given stage
            stage_output = self.stages[stage].run_artifact(inputs_to_run)
            self._compare_outputs(
                reference_output, stage_output, quantization_scale, atol, rtol, qtol
            )

        return self

    @staticmethod
    def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
        """
        Helper testing function that asserts that the model output and the reference output
        are equal with some tolerance. Due to numerical differences between eager mode and
        the XNNPACK backend, we relax the delta such that the absolute tolerance is 1e-3 and
        the relative tolerance is 1e-3. In the event that the computation was quantized, we
        further relax the tolerance by one quantized step (equal to the quantization scale).
        This allows the quantized value to differ by 1 between the reference and model output.
        """

        assert len(model_output) == len(ref_output)

        for i in range(len(model_output)):
            model = model_output[i]
            ref = ref_output[i]
            assert torch.allclose(
                model,
                ref,
                atol=atol,
                rtol=rtol,
            ), (
                f"Output {i} does not match reference output.\n"
                f"\tGiven atol: {atol}, rtol: {rtol}.\n"
                f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
                f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n"
                f"\t-- Model vs. Reference --\n"
                f"\t Numel: {model.numel()}, {ref.numel()}\n"
                f"\tMedian: {model.median()}, {ref.median()}\n"
                f"\t  Mean: {model.mean()}, {ref.mean()}\n"
                f"\t   Max: {model.max()}, {ref.max()}\n"
                f"\t   Min: {model.min()}, {ref.min()}\n"
            )

    @staticmethod
    def _compare_outputs(
        reference_output,
        stage_output,
        quantization_scale=None,
        atol=1e-03,
        rtol=1e-03,
        qtol=0,
    ):
        """
        Compares the output of the original nn.Module with the output of the generated artifact.
        This requires calling run_method before calling compare_outputs, as that runs the generated
        artifact on the sample inputs and sets the stage output to be compared against the reference.
        """
        # Wrap both outputs as tuples, since executor output is always a tuple even for a single tensor
        if isinstance(reference_output, torch.Tensor):
            reference_output = (reference_output,)
        if isinstance(stage_output, torch.Tensor):
            stage_output = (stage_output,)

        # If a qtol is provided and we found a dequantization node prior to the output, relax the
        # atol by qtol quant units.
        if quantization_scale is not None:
            atol += quantization_scale * qtol

        Tester._assert_outputs_equal(
            stage_output,
            reference_output,
            atol=atol,
            rtol=rtol,
        )

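    # Worked example of the qtol relaxation above: with quantization_scale = 0.05,
    # qtol = 1, and the default atol = 1e-3, the effective absolute tolerance
    # becomes 1e-3 + 0.05 * 1 = 0.051, i.e. one quantization step of slack.
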
    @staticmethod
    def _calculate_reference_output(
        program: ExportedProgram, inputs
    ) -> Tuple[torch.Tensor, Optional[float]]:
        """
        Execute the reference program and return the output. If the output comes from a dequantize node,
        return the quantization scale as well.
        """

        # Locate the output node.
        output_node = None
        for node in program.graph.nodes:
            if node.op == "output":
                output_node = node
                break
        assert output_node is not None

        # Look for a dequantization node in the output node args. Returned values are found in the first
        # argument of the output node.
        dequant_node = None
        for arg_node in output_node.args[0]:
            if (
                arg_node.op == "call_function"
                and arg_node.target
                == torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ):
                dequant_node = arg_node
                break

        scale = None
        if dequant_node is not None:
            original_target = dequant_node.target

            # Replace the dequant node with a shim to intercept the quantization parameters.
            # It will be invoked when we evaluate the program to find the reference outputs.
            def dequant_shim(*args):
                nonlocal scale
                scale = args[1]
                result = original_target(*args)
                return result

            dequant_node.target = dequant_shim

        output = program.module()(*inputs)
        return output, scale

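# The dequant shim above relies on the PT2E output pattern, where the last op
# before the output is, schematically:
#
#   dequantize_per_tensor(quantized_out, scale, zero_point, qmin, qmax, dtype)
#
# which is why args[1] in dequant_shim is the quantization scale.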