# Owner(s): ["oncall: export"]
import torch


try:
    from . import test_export, testing
except ImportError:
    import test_export

    import testing


# Registry of every generated test class, keyed by class name.
test_classes = {}


def _export_for_training_and_decompose(args, kwargs, default_strict):
    """Export via ``torch.export.export_for_training`` then run decompositions.

    ``default_strict`` is applied only when the caller did not pass its own
    ``strict`` keyword, so an individual test can still pick its mode
    explicitly.

    Args:
        args: Positional arguments forwarded to ``export_for_training``.
        kwargs: Keyword arguments forwarded to ``export_for_training``. This is
            the wrapper's own ``**kwargs`` dict, so updating it is local.
        default_strict: Strictness used when ``kwargs`` has no ``strict`` key.

    Returns:
        The result of ``ep.run_decompositions`` on the exported program, with
        the testing-only set of preservable composite ops kept intact.
    """
    kwargs.setdefault("strict", default_strict)
    ep = torch.export.export_for_training(*args, **kwargs)
    return ep.run_decompositions(
        {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
    )


def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs):
    """Mocked export for the strict suite (``strict=True`` unless overridden)."""
    # Pass strict explicitly rather than relying on export_for_training's
    # default, so this suite stays strict even if that default changes.
    return _export_for_training_and_decompose(args, kwargs, default_strict=True)


def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs):
    """Mocked export for the non-strict suite (``strict=False`` unless overridden)."""
    return _export_for_training_and_decompose(args, kwargs, default_strict=False)


def make_dynamic_cls(cls, strict):
    """Create and register a copy of ``cls`` whose export is mocked.

    The copy is built by ``testing.make_test_cls_with_mocked_export`` using
    either the strict or the non-strict mocked export above.

    Args:
        cls: The test class from ``test_export`` to derive from.
        strict: Selects the strict (True) or non-strict (False) variant.

    Returns:
        The generated test class. It is also stored in ``test_classes`` and
        bound as a global of this module.
    """
    if strict:
        cls_prefix = "TrainingIRToRunDecompExport"
        fn_suffix = test_export.TRAINING_IR_DECOMP_STRICT_SUFFIX
        mocked_export = mocked_training_ir_to_run_decomp_export_strict
        xfail_prop = "_expected_failure_training_ir_to_run_decomp"
    else:
        cls_prefix = "TrainingIRToRunDecompExportNonStrict"
        fn_suffix = test_export.TRAINING_IR_DECOMP_NON_STRICT_SUFFIX
        mocked_export = mocked_training_ir_to_run_decomp_export_non_strict
        xfail_prop = "_expected_failure_training_ir_to_run_decomp_non_strict"

    test_class = testing.make_test_cls_with_mocked_export(
        cls,
        cls_prefix,
        fn_suffix,
        mocked_export,
        xfail_prop=xfail_prop,
    )

    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__
    return test_class


tests = [
    test_export.TestDynamismExpression,
    test_export.TestExport,
]
for test in tests:
    make_dynamic_cls(test, strict=True)
    make_dynamic_cls(test, strict=False)
# Drop the loop variable so the last (unmocked) test class does not linger as
# a module global, where a test loader could collect it from this module too.
del test

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()