xref: /aosp_15_r20/external/pytorch/test/export/test_export_training_ir_to_run_decomp.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2import torch
3
4
5try:
6    from . import test_export, testing
7except ImportError:
8    import test_export
9
10    import testing
11
12
# Registry of the dynamically generated test classes, keyed by class name.
# Filled in by make_dynamic_cls() below.
test_classes = {}
14
15
def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs):
    """Export with ``export_for_training`` and then run decompositions.

    Every positional and keyword argument is forwarded unchanged to
    ``torch.export.export_for_training``.

    Returns:
        The result of calling ``run_decompositions`` on the exported program
        with an empty decomposition table and ``_preserve_ops`` set to
        ``testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY``.
    """
    training_ep = torch.export.export_for_training(*args, **kwargs)
    preserved_ops = testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
    return training_ep.run_decompositions({}, _preserve_ops=preserved_ops)
21
22
def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs):
    """Export with ``export_for_training`` (non-strict by default), then decompose.

    Arguments are forwarded to ``torch.export.export_for_training``. When the
    caller did not pass ``strict`` explicitly, ``strict=False`` is supplied;
    an explicit ``strict`` value from the caller is left as is.

    Returns:
        The result of calling ``run_decompositions`` on the exported program
        with an empty decomposition table and ``_preserve_ops`` set to
        ``testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY``.
    """
    # ``kwargs`` is a fresh dict for each call, so filling in the default
    # here is not visible to the caller.
    kwargs.setdefault("strict", False)
    training_ep = torch.export.export_for_training(*args, **kwargs)
    preserved_ops = testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
    return training_ep.run_decompositions({}, _preserve_ops=preserved_ops)
31
32
def make_dynamic_cls(cls, strict):
    """Build a variant of ``cls`` that uses a mocked export and register it.

    Args:
        cls: The test class handed to
            ``testing.make_test_cls_with_mocked_export``.
        strict: Truthy selects the strict mocked export function and its
            associated name/suffix/xfail settings; falsy selects the
            non-strict counterparts.

    Returns:
        The generated test class. As side effects it is recorded in
        ``test_classes``, published in this module's globals under its own
        ``__name__``, and its ``__module__`` is set to this module's name.
    """
    # Pick the per-mode settings first, then make a single factory call.
    if strict:
        cls_prefix = "TrainingIRToRunDecompExport"
        fn_suffix = test_export.TRAINING_IR_DECOMP_STRICT_SUFFIX
        mocked_export = mocked_training_ir_to_run_decomp_export_strict
        xfail_attr = "_expected_failure_training_ir_to_run_decomp"
    else:
        cls_prefix = "TrainingIRToRunDecompExportNonStrict"
        fn_suffix = test_export.TRAINING_IR_DECOMP_NON_STRICT_SUFFIX
        mocked_export = mocked_training_ir_to_run_decomp_export_non_strict
        xfail_attr = "_expected_failure_training_ir_to_run_decomp_non_strict"

    generated = testing.make_test_cls_with_mocked_export(
        cls,
        cls_prefix,
        fn_suffix,
        mocked_export,
        xfail_prop=xfail_attr,
    )

    generated_name = generated.__name__
    test_classes[generated_name] = generated
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[generated_name] = generated
    generated.__module__ = __name__
    return generated
56
57
# Base test classes from test_export for which mocked-export variants are
# generated.
tests = [
    test_export.TestDynamismExpression,
    test_export.TestExport,
]
# Generate the strict variant first, then the non-strict one, for each class.
for test in tests:
    for strict in (True, False):
        make_dynamic_cls(test, strict)
# Drop the loop variables so they do not linger as module globals
# (presumably so the original base class is not collected from this module
# a second time -- TODO confirm).
del test, strict
66
if __name__ == "__main__":
    # Imported only when this file is executed directly as a script.
    from torch._dynamo import test_case

    test_case.run_tests()
71