import io
import logging
import sys
import zipfile
from pathlib import Path

# Use a wildcard import so developers don't need to add imports here when they add tests for upgraders.
from test.jit.fixtures_srcs.fixtures_src import *  # noqa: F403
from typing import Set

import torch
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter


logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

19"""
20This file is used to generate model for test operator change. Please refer to
21https://github.com/pytorch/rfcs/blob/master/RFC-0017-PyTorch-Operator-Versioning.md for more details.
22
23A systematic workflow to change operator is needed to ensure
24Backwards Compatibility (BC) / Forwards Compatibility (FC) for operator changes. For BC-breaking operator change,
25an upgrader is needed. Here is the flow to properly land a BC-breaking operator change.
26
271. Write an upgrader in caffe2/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp file. The softly enforced
28naming format is <operator_name>_<operator_overload>_<start>_<end>. For example, the below example means that
29div.Tensor at version from 0 to 3 needs to be replaced by this upgrader.
30
31```
32/*
33div_Tensor_0_3 is added for a change of operator div in pr xxxxxxx.
34Create date: 12/02/2021
35Expire date: 06/02/2022
36*/
37     {"div_Tensor_0_3", R"SCRIPT(
38def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
39  if (self.is_floating_point() or other.is_floating_point()):
40    return self.true_divide(other)
41  return self.divide(other, rounding_mode='trunc')
42)SCRIPT"},
43```

2. In caffe2/torch/csrc/jit/operator_upgraders/version_map.h, add an entry like the one below.
Make sure the entries stay SORTED by the version bump number.
```
    {"div.Tensor",
      {{4,
        "div_Tensor_0_3",
        "aten::div.Tensor(Tensor self, Tensor other) -> Tensor"}}},
```

3. After rebuilding PyTorch, run the following command; it will auto-generate a change to
fbcode/caffe2/torch/csrc/jit/mobile/upgrader_mobile.cpp

```
python pytorch/torchgen/operator_versions/gen_mobile_upgraders.py
```

4. Generate the test to cover the upgrader.

4.1 Check out the commit before the operator change, and add a module in
`test/jit/fixtures_srcs/fixtures_src.py`. Switching commits is necessary because
an old model with the old operator (before the change) is needed to ensure the upgrader
is working as expected. In `test/jit/fixtures_srcs/generate_models.py`, add the module and
its corresponding changed operator like the following
```
ALL_MODULES = {
    TestVersionedDivTensorExampleV7(): "aten::div.Tensor",
}
```
This module should include the changed operator. If the operator isn't covered in the model,
the model export in step 4.2 will fail (see the sketch below).
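For illustration, the fixture module registered above might be defined in
`test/jit/fixtures_srcs/fixtures_src.py` roughly as follows (a minimal sketch, not the exact
source; it calls aten::div.Tensor through several syntaxes so the exported model is guaranteed
to contain the changed operator):
```
class TestVersionedDivTensorExampleV7(torch.nn.Module):
    def forward(self, a, b):
        # Exercise the div operator via operator syntax, functional call,
        # and method call, so the exported model covers aten::div.Tensor.
        result_0 = a / b
        result_1 = torch.div(a, b)
        result_2 = a.div(b)
        return result_0, result_1, result_2
```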

4.2 Export the model to `test/jit/fixtures` by running the script from the PyTorch repo root:
```
python test/jit/fixtures_srcs/generate_models.py
```

4.3 In `test/jit/test_save_load_for_op_version.py`, add a test to cover the old models and
ensure the result is equivalent between the current module and the old module + upgrader;
a sketch follows.
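For example, a test along the following lines could be added (a sketch, not the exact test: it
assumes the fixture from step 4.2 was exported as `test_versioned_div_tensor_example_v7.ptl`,
that `fixtures_path` points at `test/jit/fixtures`, and that `_load_for_lite_interpreter` is
imported from `torch.jit.mobile`):
```
def test_versioned_div_tensor_example_v7(self):
    model_path = fixtures_path / "test_versioned_div_tensor_example_v7.ptl"
    # Loading the old model applies the registered upgrader to the changed operator.
    old_module = _load_for_lite_interpreter(str(model_path))
    a, b = torch.randn(5), torch.randn(5)
    # The upgraded old module must match the current operator's behavior.
    for old_result, current_result in zip(
        old_module(a, b), (a / b, torch.div(a, b), a.div(b))
    ):
        self.assertEqual(old_result, current_result)
```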

4.4 Save all changes from 4.1, 4.2 and 4.3, as well as the changes made in steps 1, 2 and 3,
and submit a PR.

"""

"""
A map from test module to its corresponding changed operator.
key: test module instance
value: changed operator
"""
ALL_MODULES = {
    TestVersionedDivTensorExampleV7(): "aten::div.Tensor",
    TestVersionedLinspaceV7(): "aten::linspace",
    TestVersionedLinspaceOutV7(): "aten::linspace.out",
    TestVersionedLogspaceV8(): "aten::logspace",
    TestVersionedLogspaceOutV8(): "aten::logspace.out",
    TestVersionedGeluV9(): "aten::gelu",
    TestVersionedGeluOutV9(): "aten::gelu.out",
    TestVersionedRandomV10(): "aten::random_.from",
    TestVersionedRandomFuncV10(): "aten::random.from",
    TestVersionedRandomOutV10(): "aten::random.from_out",
}

107"""
108Get the path to `test/jit/fixtures`, where all test models for operator changes
109(upgrader/downgrader) are stored
110"""
111
112
113def get_fixtures_path() -> Path:
114    pytorch_dir = Path(__file__).resolve().parents[3]
115    fixtures_path = pytorch_dir / "test" / "jit" / "fixtures"
116    return fixtures_path
117
118
119"""
120Get all models' name in `test/jit/fixtures`
121"""
122
123
124def get_all_models(model_directory_path: Path) -> Set[str]:
125    files_in_fixtures = model_directory_path.glob("**/*")
126    all_models_from_fixtures = [
127        fixture.stem for fixture in files_in_fixtures if fixture.is_file()
128    ]
129    return set(all_models_from_fixtures)
130
131
132"""
133Check if a given model already exist in `test/jit/fixtures`
134"""
135
136
137def model_exist(model_file_name: str, all_models: Set[str]) -> bool:
138    return model_file_name in all_models
139
140
141"""
142Get the operator list given a module
143"""
144
145
146def get_operator_list(script_module: torch) -> Set[str]:
147    buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
148    buffer.seek(0)
149    mobile_module = _load_for_lite_interpreter(buffer)
150    operator_list = _export_operator_list(mobile_module)
151    return operator_list
152
153
154"""
155Get the output model operator version, given a module
156"""
157
158
159def get_output_model_version(script_module: torch.nn.Module) -> int:
160    buffer = io.BytesIO()
161    torch.jit.save(script_module, buffer)
162    buffer.seek(0)
163    zipped_model = zipfile.ZipFile(buffer)
164    try:
165        version = int(zipped_model.read("archive/version").decode("utf-8"))
166        return version
167    except KeyError:
168        version = int(zipped_model.read("archive/.data/version").decode("utf-8"))
169        return version
170
171
172"""
173Loop through all test modules. If the corresponding model doesn't exist in
174`test/jit/fixtures`, generate one. For the following reason, a model won't be exported:
175
1761. The test module doens't cover the changed operator. For example, test_versioned_div_tensor_example_v4
177is supposed to test the operator aten::div.Tensor. If the model doesn't include this operator, it will fail.
178The error message includes the actual operator list from the model.
179
1802. The output model version is not the same as expected version. For example, test_versioned_div_tensor_example_v4
181is used to test an operator change aten::div.Tensor, and the operator version will be bumped to v5. This script is
182supposed to run before the operator change (before the commit to make the change). If the actual model version is v5,
183likely this script is running with the commit to make the change.
184
1853. The model already exists in `test/jit/fixtures`.
186
187"""
188
189
def generate_models(model_directory_path: Path):
    all_models = get_all_models(model_directory_path)
    for a_module, expect_operator in ALL_MODULES.items():
        # For example: TestVersionedDivTensorExampleV7
        torch_module_name = type(a_module).__name__

        if not isinstance(a_module, torch.nn.Module):
            logger.error(
                "The module %s "
                "is not a torch.nn.Module instance. "
                "Please ensure it's a subclass of torch.nn.Module in fixtures_src.py "
                "and it's registered as an instance in ALL_MODULES in generate_models.py",
                torch_module_name,
            )

        # The corresponding model name is: test_versioned_div_tensor_example_v7
        model_name = "".join(
            [
                "_" + char.lower() if char.isupper() else char
                for char in torch_module_name
            ]
        ).lstrip("_")

        # Some models may not compile anymore, so skip the ones
        # that already have a .ptl file.
        logger.info("Processing %s", torch_module_name)
        if model_exist(model_name, all_models):
            logger.info("Model %s already exists, skipping", model_name)
            continue

        script_module = torch.jit.script(a_module)
        actual_model_version = get_output_model_version(script_module)

        current_operator_version = torch._C._get_max_operator_version()
        if actual_model_version >= current_operator_version + 1:
            logger.error(
                "Actual model version %s "
                "is equal to or larger than %s + 1. "
                "Please run this script before the commit that changes the operator.",
                actual_model_version,
                current_operator_version,
            )
            continue

        actual_operator_list = get_operator_list(script_module)
        if expect_operator not in actual_operator_list:
            logger.error(
                "The model includes operators %s, "
                "but it doesn't cover the operator %s. "
                "Please ensure the output model includes the tested operator.",
                actual_operator_list,
                expect_operator,
            )
            continue

        export_model_path = str(model_directory_path / (str(model_name) + ".ptl"))
        script_module._save_for_lite_interpreter(export_model_path)
        logger.info(
            "Generated model %s and saved it to %s", model_name, export_model_path
        )


def main() -> None:
    model_directory_path = get_fixtures_path()
    generate_models(model_directory_path)


if __name__ == "__main__":
    main()
