# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

import logging
import random
import sys
from abc import ABC, abstractmethod
from collections import Counter, OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
from executorch.backends.xnnpack._passes import XNNPACKPassManager
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchBackendConfig,
    ExecutorchProgramManager,
    to_edge,
    to_edge_transform_and_lower,
)
from executorch.exir.backend.backend_api import validation_disabled
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.exir.print_program import pretty_print, print_program
from torch.export import export_for_training

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
try:
    from executorch.extension.pybindings.portable_lib import (  # @manual
        _load_for_executorch_from_buffer,
    )
except ImportError as e:
    logger.warning(f"{e=}")

from executorch.exir.program._program import _transform
from torch._export.pass_base import PassType
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
from torch.export import export, ExportedProgram
from torch.testing import FileCheck
from torch.utils._pytree import tree_flatten


class Stage(ABC):
    """
    Interface for a Stage in the PT2.0 lowering pipeline.
    """

    @abstractmethod
    def run(self, artifact, inputs):
        """
        Executes this stage and generates the 'artifact' consumed by later stages.
        """
        pass

    @property
    @abstractmethod
    def artifact(self):
        """
        Returns the artifact generated by this stage, to be used by the next stage in the pipeline.
        """
        pass

    @property
    @abstractmethod
    def graph_module(self):
        """
        Returns the artifact's graph module for this stage.
        """
        pass

    def run_artifact(self, inputs):
        """
        Returns the output of calling the artifact generated by this stage with the given inputs.
        """
        if isinstance(self.artifact, ExportedProgram):
            return self.artifact(*inputs)
        else:
            return self.artifact.exported_program().module()(*inputs)

    # Debug tools for stages.
    def artifact_str(self):
        """
        Returns a printable representation of this stage's artifact.
        """
        if isinstance(self.artifact, EdgeProgramManager):
            return self.artifact.exported_program()
        return self.artifact

    def stage_banner(self):
        """
        Returns the banner string for this stage.
        """
        return "#" * 36 + " " + str(self.__class__.__name__) + " " + "#" * 36 + "\n"

    def dump_artifact(self, path_to_dump: Optional[str]):
        """
        Dumps the printable artifact to path_to_dump. If path_to_dump is not
        provided, the artifact is printed to the terminal instead.
        """
        if path_to_dump:
            with open(path_to_dump, "a") as fp:
                fp.write(str(self.stage_banner() + "\n"))
                fp.write(str(self.artifact_str()))
        else:
            print(self.stage_banner() + "\n")
            print(self.artifact_str())
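

# Illustrative sketch only: a Stage is normally driven by the Tester defined below, but it
# can also be exercised standalone. Assuming a module `m` with example inputs `ex`
# (both hypothetical placeholders), the flow for any stage looks like:
#
#   stage = Export()              # Export is defined later in this file
#   stage.run(m, ex)              # produces the stage's artifact
#   out = stage.run_artifact(ex)  # runs the artifact on the given inputs
#   stage.dump_artifact(None)     # prints the artifact to the terminal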


_stages_: Dict[str, Type[Stage]] = {}


def register_stage(stage: Type[Stage]):
    """
    Register a Stage to be used in the Tester.
    """
    assert isinstance(stage, type)
    name = stage.__qualname__
    if name in _stages_:
        raise RuntimeError(f"Duplicate stage in Tester, {name}")
    _stages_[name] = stage
    return stage


@register_stage
class Quantize(Stage):
    def __init__(
        self,
        quantizer: Optional[Quantizer] = None,
        quantization_config: Optional[QuantizationConfig] = None,
        calibrate: bool = True,
    ):
        self.quantizer = quantizer or XNNPACKQuantizer()
        self.quantization_config = (
            quantization_config or get_symmetric_quantization_config()
        )
        self.calibrate = calibrate

        self.quantizer.set_global(self.quantization_config)

        self.converted_graph = None

    def run(
        self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
    ) -> None:
        assert inputs is not None
        captured_graph = export_for_training(artifact, inputs).module()

        assert isinstance(captured_graph, torch.fx.GraphModule)
        prepared = prepare_pt2e(captured_graph, self.quantizer)

        if self.calibrate:
            # Calibrate the prepared model to provide data to the quantization observers.
            prepared(*inputs)

        converted = convert_pt2e(prepared)
        self.converted_graph = converted

    @property
    def artifact(self) -> torch.fx.GraphModule:
        return self.converted_graph

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.converted_graph

    def run_artifact(self, inputs):
        return self.converted_graph.forward(*inputs)


@register_stage
class Export(Stage):
    def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None):
        self.exported_program = None
        self.dynamic_shapes = dynamic_shapes

    def run(
        self,
        artifact: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
    ) -> None:
        self.exported_program = export(
            artifact, inputs, dynamic_shapes=self.dynamic_shapes
        )

    @property
    def artifact(self) -> ExportedProgram:
        return self.exported_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.exported_program.graph_module


@register_stage
class ToEdge(Stage):
    def __init__(self, edge_compile_config: Optional[EdgeCompileConfig] = None):
        self.edge_compile_conf = (
            edge_compile_config or get_xnnpack_edge_compile_config()
        )
        self.edge_dialect_program = None

    def run(self, artifact: ExportedProgram, inputs=None) -> None:
        self.edge_dialect_program = to_edge(
            artifact, compile_config=self.edge_compile_conf
        )

    @property
    def artifact(self) -> EdgeProgramManager:
        return self.edge_dialect_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.edge_dialect_program.exported_program().graph_module
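

# Illustrative sketch only: each stage accepts optional overrides for its defaults. For
# example, a per-channel symmetric quantization config for Quantize, or a custom edge
# compile config for ToEdge (the specific values below are assumptions, not project defaults):
#
#   Quantize(
#       quantization_config=get_symmetric_quantization_config(is_per_channel=True)
#   )
#   ToEdge(edge_compile_config=EdgeCompileConfig(_check_ir_validity=False))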


@register_stage
class RunPasses(Stage):
    def __init__(
        self,
        pass_list: Optional[List[Type[PassType]]] = None,
        pass_functions: Optional[List[Callable]] = None,
    ):
        self.pass_list = pass_list
        self.pass_functions = pass_functions
        self.edge_or_aten_program = None

    def run(
        self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None
    ) -> None:
        if isinstance(artifact, EdgeProgramManager):
            self.edge_or_aten_program = artifact
            if self.pass_list:
                pass_manager = XNNPACKPassManager(
                    artifact.exported_program(), self.pass_list
                )
                self.edge_or_aten_program._edge_programs["forward"] = (
                    pass_manager.transform()
                )
            if self.pass_functions:
                assert isinstance(self.pass_functions, list)
                for pass_function in self.pass_functions:
                    self.edge_or_aten_program._edge_programs["forward"] = pass_function(
                        self.edge_or_aten_program.exported_program()
                    )
        else:
            transformed_ep = artifact
            if self.pass_list:
                assert isinstance(self.pass_list, list)
                for pass_ in self.pass_list:
                    transformed_ep = _transform(transformed_ep, pass_())

            if self.pass_functions:
                assert isinstance(self.pass_functions, list)
                for pass_function in self.pass_functions:
                    transformed_ep = pass_function(transformed_ep)

            self.edge_or_aten_program = transformed_ep

    @property
    def artifact(self) -> Union[EdgeProgramManager, ExportedProgram]:
        return self.edge_or_aten_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        if isinstance(self.edge_or_aten_program, EdgeProgramManager):
            return self.edge_or_aten_program.exported_program().graph_module
        else:
            return self.edge_or_aten_program.graph_module


@register_stage
class ToEdgeTransformAndLower(Stage):
    def __init__(
        self,
        partitioners: Optional[List[Partitioner]] = None,
        edge_compile_config: Optional[EdgeCompileConfig] = None,
    ):
        self.partitioners = partitioners or [XnnpackPartitioner()]
        self.edge_compile_conf = (
            edge_compile_config or get_xnnpack_edge_compile_config()
        )
        self.edge_dialect_program = None

    def run(self, artifact: ExportedProgram, inputs=None) -> None:
        artifact_to_run = copy.deepcopy(artifact)
        self.edge_dialect_program = to_edge_transform_and_lower(
            artifact_to_run,
            compile_config=self.edge_compile_conf,
            partitioner=self.partitioners,
        )

    @property
    def artifact(self) -> EdgeProgramManager:
        return self.edge_dialect_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.edge_dialect_program.exported_program().graph_module


@register_stage
class Partition(Stage):
    def __init__(self, partitioner: Optional[Partitioner] = None):
        self.partitioner = partitioner or XnnpackPartitioner()
        self.delegate_module = None

    def run(self, artifact: EdgeProgramManager, inputs=None):
        with validation_disabled():
            self.delegate_module = artifact
            self.delegate_module = self.delegate_module.to_backend(self.partitioner)

    @property
    def artifact(self) -> EdgeProgramManager:
        return self.delegate_module

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.delegate_module.exported_program().graph_module
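

# Note: the Tester pipeline defined below admits two lowering routes that converge on
# ToExecutorch:
#
#   Export -> ToEdge -> Partition -> ToExecutorch       (partition the edge program)
#   Export -> ToEdgeTransformAndLower -> ToExecutorch   (transform and lower in one step)
#
# Both routes default to the XnnpackPartitioner when no partitioner is supplied.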


@register_stage
class ToExecutorch(Stage):
    def __init__(
        self,
        config: Optional[ExecutorchBackendConfig] = None,
    ):
        self.config = config or ExecutorchBackendConfig(
            extract_delegate_segments=True,
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
        )
        self.executorch_program = None

    def run(self, artifact: EdgeProgramManager, inputs=None):
        self.executorch_program = artifact.to_executorch(self.config)

    @property
    def artifact(self) -> ExecutorchProgramManager:
        return self.executorch_program

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.executorch_program.exported_program().graph_module

    def dump_artifact(self, path_to_dump: Optional[str]):
        """
        dump_artifact is overridden to dump the serialized program.
        """
        original_stdout = sys.stdout

        sys.stdout = open(path_to_dump, "a") if path_to_dump else sys.stdout
        print(self.stage_banner() + "\n")
        pretty_print(self.artifact._emitter_output.program)
        print_program(
            self.artifact._emitter_output.program,
            show_meminfo=True,
            mark_dynamic_shape_tensor=True,
        )
        sys.stdout = original_stdout


@register_stage
class Serialize(Stage):
    def __init__(self):
        self.buffer = None

    def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None:
        self.buffer = artifact.buffer

    @property
    def artifact(self) -> bytes:
        return self.buffer

    @property
    def graph_module(self) -> None:
        return None

    def run_artifact(self, inputs):
        inputs_flattened, _ = tree_flatten(inputs)
        executorch_module = _load_for_executorch_from_buffer(self.buffer)
        executorch_output = copy.deepcopy(
            executorch_module.run_method("forward", tuple(inputs_flattened))
        )
        return executorch_output

    def dump_artifact(self, path_to_dump: Optional[str]):
        """
        dump_artifact is overridden to dump the serialized bytes into a .pte file.
        """
        if not path_to_dump:
            raise RuntimeError("path_to_dump file not provided")
        else:
            with open(path_to_dump, "wb") as f:
                f.write(self.artifact)
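

# Illustrative sketch only: the buffer produced by Serialize can be written to a .pte file
# and loaded through the portable pybindings, mirroring Serialize.run_artifact above.
# `executorch_program`, `example_inputs`, and the output path are hypothetical placeholders:
#
#   serialize = Serialize()
#   serialize.run(executorch_program)           # artifact from the ToExecutorch stage
#   serialize.dump_artifact("/tmp/model.pte")
#   module = _load_for_executorch_from_buffer(serialize.artifact)
#   outputs = module.run_method("forward", example_inputs)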


class Tester:
    def __init__(
        self,
        module: torch.nn.Module,
        example_inputs: Tuple[torch.Tensor],
        dynamic_shapes: Optional[Tuple[Any]] = None,
    ):
        module.eval()

        self.original_module = module
        self.example_inputs = example_inputs
        self.dynamic_shapes = dynamic_shapes
        self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys()))
        self.pipeline = {
            self.stage_name(Quantize): [self.stage_name(Export)],
            self.stage_name(Export): [
                self.stage_name(RunPasses),
                self.stage_name(ToEdge),
                self.stage_name(ToEdgeTransformAndLower),
            ],
            self.stage_name(ToEdgeTransformAndLower): [
                self.stage_name(RunPasses),
                self.stage_name(ToExecutorch),
            ],
            self.stage_name(ToEdge): [
                self.stage_name(Partition),
                self.stage_name(RunPasses),
            ],
            self.stage_name(RunPasses): [
                self.stage_name(Partition),
                self.stage_name(ToEdgeTransformAndLower),
            ],
            # TODO Make this Stage optional
            self.stage_name(Partition): [self.stage_name(ToExecutorch)],
            self.stage_name(ToExecutorch): [self.stage_name(Serialize)],
            self.stage_name(Serialize): [],
        }
        assert all(
            stage in self.pipeline for stage in self.stages
        ), "Invalid Tester internal state!"

        # Current stage name
        self.cur: str = ""

        # Reference output from eager mode
        self.reference_output = None

        # Quantization scale from eager mode
        self.quantization_scale: Optional[float] = None

        # Artifact output from stage
        self.stage_output = None

    def generate_random_inputs(self):
        # Get the shapes of the inputs.
        input_shapes = []
        if self.dynamic_shapes is None:
            for tensor_arg in self.example_inputs:
                assert isinstance(tensor_arg, torch.Tensor)
                input_shapes.append(tensor_arg.shape)
        else:
            # Random shapes depending on the dynamic shape constraints.
            dim_name_to_size = {}
            for arg_idx in range(len(self.example_inputs)):
                assert isinstance(self.example_inputs[arg_idx], torch.Tensor)
                ex_shape = list(self.example_inputs[arg_idx].shape)
                dynamic_dim_spec = self.dynamic_shapes[arg_idx]
                for dim_idx, dim_spec in dynamic_dim_spec.items():
                    assert dim_idx < len(ex_shape)
                    if isinstance(dim_spec, torch.export.dynamic_shapes._DerivedDim):
                        # Derived dims are of the form {0: 2 * torch.export.Dim() // 2}.
                        # The root contains the min/max of the exported dim and fn contains
                        # the function to compute the derived dim.
                        fn = dim_spec.fn
                        dim_spec = dim_spec.root
                    elif isinstance(dim_spec, torch.export.dynamic_shapes._Dim):
                        # Not a derived dim, so fn is just the identity.
                        def fn(x):
                            return x

                    else:
                        raise RuntimeError(
                            f"Expected Dynamic Dims to be of type _DerivedDim or _Dim but got {type(dim_spec)}"
                        )
                    dim_name = dim_spec.__name__
                    if dim_name not in dim_name_to_size:
                        upper_bound = min(
                            dim_spec.max, 1000
                        )  # unbounded int max is too large
                        lower_bound = (
                            dim_spec.min if dim_spec.min >= 2 else 1
                        )  # 0/1 specialization means dim_spec.min can never be 1
                        dim_name_to_size[dim_name] = fn(
                            random.randint(lower_bound, upper_bound)
                        )
                    ex_shape[dim_idx] = dim_name_to_size[dim_spec.__name__]
                input_shapes.append(torch.Size(ex_shape))
        # Create random tensor inputs with the shapes given above:
        random_inputs = []
        for arg_idx in range(len(self.example_inputs)):
            random_inputs.append(
                torch.randn(input_shapes[arg_idx]).to(
                    dtype=self.example_inputs[arg_idx].dtype
                )
            )

        yield tuple(random_inputs)
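
    # Illustrative sketch only: dynamic_shapes follows the torch.export convention of one
    # spec per positional input. A shared batch dimension bounded to [2, 16] plus a derived
    # dimension would look like (names are placeholders):
    #
    #   batch = torch.export.Dim("batch", min=2, max=16)
    #   Tester(module, example_inputs, dynamic_shapes=({0: batch}, {0: 2 * batch}))
    #
    # generate_random_inputs() above samples a size for "batch" within its bounds and
    # applies the derived dim's fn (here, 2 * x) to compute any derived dimensions.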

    @staticmethod
    def stage_name(stage) -> str:
        t = stage if isinstance(stage, type) else type(stage)
        return t.__qualname__

    def _pre(self, stage):
        name: str = self.stage_name(stage)
        assert isinstance(name, str) and name in self.stages and not self.stages[name]

        last_artifact = self.original_module
        if self.cur:
            assert self.cur in self.pipeline, f"Invalid state: {self.cur}"
            allowed_next_stages = self.pipeline[self.cur]
            assert name in allowed_next_stages, f"Invalid next stage: {name}"
            last_artifact = self.get_artifact()
        self.cur = name
        return last_artifact

    def _post(self, stage):
        name = self.stage_name(stage)
        assert name in self.stages
        self.stages[name] = stage

    def _run_stage(self, stage_instance, inputs=None):
        assert isinstance(stage_instance, Stage)
        prev_stage_artifact = self._pre(stage_instance)
        stage_instance.run(prev_stage_artifact, inputs=inputs)
        self._post(stage_instance)
        return self

    # Stages
    def quantize(self, quantize_stage: Optional[Quantize] = None):
        return self._run_stage(quantize_stage or Quantize(), self.example_inputs)

    def export(self, export_stage: Optional[Export] = None):
        return self._run_stage(
            export_stage or Export(dynamic_shapes=self.dynamic_shapes),
            self.example_inputs,
        )

    def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
        # TODO(T182187531): Skip dim order for now. Support dim order and its op after alpha release.
        if not to_edge_stage:
            to_edge_stage = ToEdge()
        to_edge_stage.edge_compile_conf._skip_dim_order = True
        res = self._run_stage(to_edge_stage)
        return res

    def to_edge_transform_and_lower(
        self, to_edge_and_transform_stage: Optional[ToEdgeTransformAndLower] = None
    ):
        return self._run_stage(
            to_edge_and_transform_stage or ToEdgeTransformAndLower()
        )

    def run_passes(self, run_passes_stage: Optional[RunPasses] = None):
        return self._run_stage(run_passes_stage or RunPasses())

    def partition(self, partition_stage: Optional[Partition] = None):
        return self._run_stage(partition_stage or Partition())

    def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] = None):
        return self._run_stage(to_executorch_stage or ToExecutorch())

    def serialize(self, serialize_stage: Optional[Serialize] = None):
        return self._run_stage(serialize_stage or Serialize())

    # Util functions
    def dump_artifact(self, path: Optional[str] = None, stage: Optional[str] = None):
        stage = stage or self.cur
        self.stages[stage].dump_artifact(path)
        return self

    def get_artifact(self, stage: Optional[str] = None):
        stage = stage or self.cur
        return self.stages[stage].artifact

    def check(self, input: List[str]):
        for key in input:
            FileCheck().check(key).run(self.stages[self.cur].graph_module.code)
        return self

    def check_not(self, input: List[str]):
        for key in input:
            FileCheck().check_not(key).run(self.stages[self.cur].graph_module.code)
        return self

    def check_count(self, input: Dict[Any, int]):
        # TODO target checks similar to checkGraphModuleNodes()
        for key, count in input.items():
            FileCheck().check_count(key, count, exactly=True).run(
                self.stages[self.cur].graph_module.code
            )
        return self

    def check_node_count(self, input: Dict[Any, int]):
        # Count the occurrences of each target in the graph.
        target_ops = [
            node.target
            for node in self.stages[self.cur].graph_module.graph.nodes
            if node.op == "call_function"
        ]
        op_counts = Counter(target_ops)

        for key, count in input.items():
            if count != op_counts[key]:
                print(f"Nodes: {op_counts}")
                raise AssertionError(
                    f"Expected {count} {key} nodes but found {op_counts[key]}."
                )

        return self
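
    # Illustrative sketch only: the check helpers run FileCheck patterns against the current
    # stage's graph module code. The op names below are assumed examples, not an exhaustive list:
    #
    #   tester.check(["torch.ops.higher_order.executorch_call_delegate"])
    #   tester.check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
    #   tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})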

    def run_method_and_compare_outputs(
        self,
        stage: Optional[str] = None,
        inputs: Optional[Tuple[torch.Tensor]] = None,
        num_runs=1,
        atol=1e-03,
        rtol=1e-03,
        qtol=0,
    ):
        number_of_runs = 1 if inputs is not None else num_runs
        reference_stage = self.stages[self.stage_name(Export)]

        stage = stage or self.cur

        print(f"Comparing Stage {stage} with Stage {reference_stage}")
        for run_iteration in range(number_of_runs):
            inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
            input_shapes = [generated_input.shape for generated_input in inputs_to_run]
            print(f"Run {run_iteration} with input shapes: {input_shapes}")

            # Reference output (and quantization scale)
            (
                reference_output,
                quantization_scale,
            ) = self._calculate_reference_output(
                reference_stage.artifact, inputs_to_run
            )

            # Output from running the artifact at the given stage
            stage_output = self.stages[stage].run_artifact(inputs_to_run)
            self._compare_outputs(
                reference_output, stage_output, quantization_scale, atol, rtol, qtol
            )

        return self

    @staticmethod
    def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
        """
        Helper testing function that asserts that the model output and the reference output
        are equal with some tolerance. Due to numerical differences between eager mode and
        the XNNPACK backend, we relax the tolerances so that both the absolute and relative
        tolerance are 1e-3. In the event that the computation was quantized, we further relax
        the tolerance by one quantized step (equal to the quantization scale). This allows
        the quantized value to differ by 1 between the reference and model output.
        """

        assert len(model_output) == len(ref_output)

        for i in range(len(model_output)):
            model = model_output[i]
            ref = ref_output[i]
            assert torch.allclose(
                model,
                ref,
                atol=atol,
                rtol=rtol,
            ), (
                f"Output {i} does not match reference output.\n"
                f"\tGiven atol: {atol}, rtol: {rtol}.\n"
                f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
                f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n"
                f"\t-- Model vs. Reference --\n"
                f"\t Numel: {model.numel()}, {ref.numel()}\n"
                f"\tMedian: {model.median()}, {ref.median()}\n"
                f"\t  Mean: {model.mean()}, {ref.mean()}\n"
                f"\t   Max: {model.max()}, {ref.max()}\n"
                f"\t   Min: {model.min()}, {ref.min()}\n"
            )
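
    # Worked example (illustrative): with atol=1e-3, a quantization scale of 0.02, and qtol=1,
    # _compare_outputs below relaxes the absolute tolerance to 1e-3 + 0.02 * 1 = 0.021,
    # i.e. the outputs may differ by one quantized step.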

    @staticmethod
    def _compare_outputs(
        reference_output,
        stage_output,
        quantization_scale=None,
        atol=1e-03,
        rtol=1e-03,
        qtol=0,
    ):
        """
        Compares the output of the original nn.Module with the output of the generated artifact.
        This requires calling run_method before calling compare_outputs, as that runs the generated
        artifact on the sample inputs and sets the stage output to be compared against the reference.
        """
        # Wrap both outputs as tuples, since the executor output is always a tuple even for a single tensor.
        if isinstance(reference_output, torch.Tensor):
            reference_output = (reference_output,)
        if isinstance(stage_output, torch.Tensor):
            stage_output = (stage_output,)

        # If a qtol is provided and we found a dequantization node prior to the output, relax the
        # atol by qtol quant units.
        if quantization_scale is not None:
            atol += quantization_scale * qtol

        Tester._assert_outputs_equal(
            stage_output,
            reference_output,
            atol=atol,
            rtol=rtol,
        )

    @staticmethod
    def _calculate_reference_output(
        program: ExportedProgram, inputs
    ) -> Tuple[torch.Tensor, Optional[float]]:
        """
        Execute the reference program and return the output. If the output comes from a dequantize
        node, return the quantization scale as well.
        """

        # Locate the output node.
        output_node = None
        for node in program.graph.nodes:
            if node.op == "output":
                output_node = node
                break
        assert output_node is not None

        # Look for a dequantization node among the output node args. Returned values are found in
        # the first argument of the output node.
        dequant_node = None
        for arg_node in output_node.args[0]:
            if (
                arg_node.op == "call_function"
                and arg_node.target
                == torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ):
                dequant_node = arg_node
                break

        scale = None
        if dequant_node is not None:
            original_target = dequant_node.target

            # Replace the dequant node with a shim to intercept the quantization parameters.
            # It will be invoked when we evaluate the program to find the reference outputs.
            def dequant_shim(*args):
                nonlocal scale
                scale = args[1]
                result = original_target(*args)
                return result

            dequant_node.target = dequant_shim

        output = program.module()(*inputs)
        return output, scale
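

# Illustrative sketch only: a typical end-to-end use of the Tester, following the stage
# ordering encoded in Tester.pipeline. `MyModule` and `example_inputs` are placeholders:
#
#   (
#       Tester(MyModule(), example_inputs)
#       .quantize()
#       .export()
#       .to_edge_transform_and_lower()
#       .to_executorch()
#       .serialize()
#       .run_method_and_compare_outputs(qtol=1)
#   )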