"""
This file is used to generate models for testing operator changes. Please refer to
https://github.com/pytorch/rfcs/blob/master/RFC-0017-PyTorch-Operator-Versioning.md for more details.

A systematic workflow to change operator is needed to ensure
Backwards Compatibility (BC) / Forwards Compatibility (FC) for operator changes. For BC-breaking operator change,
an upgrader is needed. Here is the flow to properly land a BC-breaking operator change.

1. Write an upgrader in caffe2/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp file. The softly enforced
naming format is <operator_name>_<operator_overload>_<start>_<end>. For example, the below example means that
div.Tensor at version from 0 to 3 needs to be replaced by this upgrader.

```
/*
div_Tensor_0_3 is added for a change of operator div in pr xxxxxxx.
Create date: 12/02/2021
Expire date: 06/02/2022
*/
     {"div_Tensor_0_3", R"SCRIPT(
def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
  if (self.is_floating_point() or other.is_floating_point()):
    return self.true_divide(other)
  return self.divide(other, rounding_mode='trunc')
)SCRIPT"},
```

2. In caffe2/torch/csrc/jit/operator_upgraders/version_map.h, add changes like below.
You will need to make sure that the entry is SORTED according to the version bump number.
```
    {"div.Tensor",
      {{4,
        "div_Tensor_0_3",
        "aten::div.Tensor(Tensor self, Tensor other) -> Tensor"}}},
```

3. After rebuilding PyTorch, run the following command and it will auto generate a change to
fbcode/caffe2/torch/csrc/jit/mobile/upgrader_mobile.cpp

```
python pytorch/torchgen/operator_versions/gen_mobile_upgraders.py
```

4. Generate the test to cover upgrader.

4.1 Switch to the commit before the operator change, and add a module in
`test/jit/fixtures_srcs/fixtures_src.py`. The reason for switching to that commit is that
an old model with the old operator before the change is needed to ensure the upgrader
is working as expected. In `test/jit/fixtures_srcs/generate_models.py`, add the module and
its corresponding changed operator like following
```
ALL_MODULES = {
    TestVersionedDivTensorExampleV7(): "aten::div.Tensor",
}
```
This module should include the changed operator. If the operator isn't covered in the model,
the model export process in step 4.2 will fail.

4.2 Export the model to `test/jit/fixtures` by running
```
python test/jit/fixtures_srcs/generate_models.py
```

4.3 In `test/jit/test_save_load_for_op_version.py`, add a test to cover the old models and
ensure the result is equivalent between current module and old module + upgrader.

4.4 Save all changes in 4.1, 4.2 and 4.3, as well as previous changes made in step 1, 2, 3.
Submit a pr

"""

import io
import logging
import sys
import zipfile
from pathlib import Path

# Use asterisk symbol so developer doesn't need to import here when they add tests for upgraders.
from test.jit.fixtures_srcs.fixtures_src import *  # noqa: F403
from typing import Set

import torch
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter


logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# A map of test modules and their corresponding changed operator
# key: test module (an instance, not the class)
# value: changed operator
ALL_MODULES = {
    TestVersionedDivTensorExampleV7(): "aten::div.Tensor",
    TestVersionedLinspaceV7(): "aten::linspace",
    TestVersionedLinspaceOutV7(): "aten::linspace.out",
    TestVersionedLogspaceV8(): "aten::logspace",
    TestVersionedLogspaceOutV8(): "aten::logspace.out",
    TestVersionedGeluV9(): "aten::gelu",
    TestVersionedGeluOutV9(): "aten::gelu.out",
    TestVersionedRandomV10(): "aten::random_.from",
    TestVersionedRandomFuncV10(): "aten::random.from",
    TestVersionedRandomOutV10(): "aten::random.from_out",
}


def get_fixtures_path() -> Path:
    """
    Get the path to `test/jit/fixtures`, where all test models for operator changes
    (upgrader/downgrader) are stored
    """
    # The repository root is taken to be three directories above the one
    # containing this file.
    pytorch_dir = Path(__file__).resolve().parents[3]
    return pytorch_dir / "test" / "jit" / "fixtures"


def get_all_models(model_directory_path: Path) -> Set[str]:
    """
    Get all models' name (file stem, i.e. without extension) found recursively
    under `model_directory_path`
    """
    return {
        fixture.stem
        for fixture in model_directory_path.glob("**/*")
        if fixture.is_file()
    }


def model_exist(model_file_name: str, all_models: Set[str]) -> bool:
    """
    Check if a given model already exist in `all_models`
    """
    return model_file_name in all_models


def get_operator_list(script_module: torch.jit.ScriptModule) -> Set[str]:
    """
    Get the operator list given a scripted module, by round-tripping it through
    the lite interpreter format
    """
    # A BytesIO built from initial bytes is already positioned at offset 0.
    buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
    mobile_module = _load_for_lite_interpreter(buffer)
    return _export_operator_list(mobile_module)


def get_output_model_version(script_module: torch.nn.Module) -> int:
    """
    Get the output model operator version, given a module

    The module is saved to an in-memory archive and the version is read from
    `archive/version`, falling back to `archive/.data/version` when the former
    entry is absent.
    """
    buffer = io.BytesIO()
    torch.jit.save(script_module, buffer)
    buffer.seek(0)
    # Use a context manager so the archive handle is always released.
    with zipfile.ZipFile(buffer) as zipped_model:
        try:
            raw_version = zipped_model.read("archive/version")
        except KeyError:
            raw_version = zipped_model.read("archive/.data/version")
    return int(raw_version.decode("utf-8"))


def _module_name_to_model_name(torch_module_name: str) -> str:
    """
    Convert a CamelCase class name to the snake_case model file name, e.g.
    TestVersionedDivTensorExampleV7 -> test_versioned_div_tensor_example_v7
    """
    return "".join(
        "_" + char.lower() if char.isupper() else char for char in torch_module_name
    ).lstrip("_")


def generate_models(model_directory_path: Path) -> None:
    """
    Loop through all test modules. If the corresponding model doesn't exist in
    `model_directory_path`, generate one. For the following reasons, a model won't be exported:

    1. The registered object is not a torch.nn.Module instance.

    2. The model already exists in `model_directory_path`.

    3. The output model version is not the same as expected version. For example,
    test_versioned_div_tensor_example_v4 is used to test an operator change aten::div.Tensor,
    and the operator version will be bumped to v5. This script is supposed to run before the
    operator change (before the commit to make the change). If the actual model version is v5,
    likely this script is running with the commit to make the change.

    4. The test module doesn't cover the changed operator. For example,
    test_versioned_div_tensor_example_v4 is supposed to test the operator aten::div.Tensor.
    If the model doesn't include this operator, it will fail. The error message includes the
    actual operator list from the model.
    """
    all_models = get_all_models(model_directory_path)
    for a_module, expect_operator in ALL_MODULES.items():
        # For example: TestVersionedDivTensorExampleV7
        torch_module_name = type(a_module).__name__

        if not isinstance(a_module, torch.nn.Module):
            logger.error(
                "The module %s "
                "is not a torch.nn.module instance. "
                "Please ensure it's a subclass of torch.nn.module in fixtures_src.py "
                "and it's registered as an instance in ALL_MODULES in generate_models.py",
                torch_module_name,
            )
            # Nothing sensible can be scripted/exported for a non-module entry.
            continue

        # The corresponding model name is: test_versioned_div_tensor_example_v7
        model_name = _module_name_to_model_name(torch_module_name)

        # Some models may not compile anymore, so skip the ones
        # that already has pt file for them.
        logger.info("Processing %s", torch_module_name)
        if model_exist(model_name, all_models):
            logger.info("Model %s already exists, skipping", model_name)
            continue

        script_module = torch.jit.script(a_module)
        actual_model_version = get_output_model_version(script_module)

        current_operator_version = torch._C._get_max_operator_version()
        if actual_model_version >= current_operator_version + 1:
            logger.error(
                "Actual model version %s "
                "is equal or larger than %s + 1. "
                "Please run the script before the commit to change operator.",
                actual_model_version,
                current_operator_version,
            )
            continue

        actual_operator_list = get_operator_list(script_module)
        if expect_operator not in actual_operator_list:
            logger.error(
                "The model includes operator: %s, "
                "however it doesn't cover the operator %s. "
                "Please ensure the output model includes the tested operator.",
                actual_operator_list,
                expect_operator,
            )
            continue

        export_model_path = str(model_directory_path / (str(model_name) + ".ptl"))
        script_module._save_for_lite_interpreter(export_model_path)
        logger.info(
            "Generating model %s and it's save to %s", model_name, export_model_path
        )


def main() -> None:
    """Generate any missing fixture models into `test/jit/fixtures`."""
    model_directory_path = get_fixtures_path()
    generate_models(model_directory_path)


if __name__ == "__main__":
    main()