xref: /aosp_15_r20/external/pytorch/test/distributed/flight_recorder/test_fr_analysis.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import pathlib
4import sys
5
6
# Directory four levels above this file (presumably the checkout root that
# contains `tools/` -- confirm against the repository layout).
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent

# Put REPO_ROOT first on sys.path only for the duration of the two
# `tools.flight_recorder` imports below; the order of these statements matters.
sys.path.insert(0, str(REPO_ROOT))
from tools.flight_recorder.components.types import MatchState
from tools.flight_recorder.components.utils import match_one_event


# Make sure to remove REPO_ROOT after import is done
sys.path.remove(str(REPO_ROOT))
16
17from torch.testing._internal.common_utils import run_tests, TestCase
18
19
def create_one_event(
    collectcive_name,
    pg_info,
    input_sizes,
    output_sizes,
    state="scheduled",
    collective_seq_id=0,
    p2p_seq_id=0,
    output_dtypes="float32",
    input_dtypes="float32",
):
    """Build the dict describing a single event used by the tests in this file.

    Args:
        collectcive_name: Name of the collective; stored under
            ``"profiling_name"`` with an ``"nccl:"`` prefix. The spelling of
            this parameter is kept as-is so keyword callers keep working.
        pg_info: Value stored under ``"process_group"``.
        input_sizes: Value stored under ``"input_sizes"``.
        output_sizes: Value stored under ``"output_sizes"``.
        state: Value stored under ``"state"``.
        collective_seq_id: Stored under ``"collective_seq_id"`` after
            conversion with ``str()``.
        p2p_seq_id: Stored under ``"p2p_seq_id"`` after conversion with
            ``str()``.
        output_dtypes: Value stored under ``"output_dtypes"``.
        input_dtypes: Value stored under ``"input_dtypes"``. Defaults to
            ``"float32"``, the value that used to be hard-coded, so existing
            callers get the same event as before.

    Returns:
        A new dict with the nine keys listed above.
    """
    return {
        "profiling_name": f"nccl:{collectcive_name}",
        "state": state,
        "process_group": pg_info,
        "input_sizes": input_sizes,
        "output_sizes": output_sizes,
        "input_dtypes": input_dtypes,
        "output_dtypes": output_dtypes,
        # Sequence ids are stored as strings, not ints.
        "collective_seq_id": str(collective_seq_id),
        "p2p_seq_id": str(p2p_seq_id),
    }
41
42
class FlightRecorderEventTest(TestCase):
    """Exercises ``match_one_event`` on hand-built pairs of events."""

    def test_match_one_event(self):
        """Check the MatchState reported for a baseline event and its variants."""
        group = ("0", "default")
        ranks_by_pg = {"0": {0, 1}}

        def compare(first, second):
            # Every comparison shares the same membership map and group name.
            return match_one_event(first, second, ranks_by_pg, "0")

        # Baseline: a scheduled all_reduce with identical input/output sizes.
        baseline = create_one_event(
            "all_reduce",
            group,
            [[4, 4]],
            [[4, 4]],
            state="scheduled",
            collective_seq_id=1,
        )
        self.assertEqual(compare(baseline, baseline), MatchState.FULLY_MATCHED)

        # Same shapes, but a different collective name.
        gathered = create_one_event(
            "all_gather",
            group,
            [[4, 4]],
            [[4, 4]],
            state="scheduled",
            collective_seq_id=1,
        )
        self.assertEqual(
            compare(baseline, gathered), MatchState.COLLECTIVE_TYPE_MISMATCH
        )

        # Two equal all_to_all events are expected to come back UNDECIDED.
        a2a_left = create_one_event(
            "all_to_all",
            group,
            [[4, 4]],
            [[4, 4]],
            state="scheduled",
            collective_seq_id=1,
        )
        a2a_right = create_one_event(
            "all_to_all",
            group,
            [[4, 4]],
            [[4, 4]],
            state="scheduled",
            collective_seq_id=1,
        )
        self.assertEqual(compare(a2a_left, a2a_right), MatchState.UNDECIDED)

        # Input sizes differ from the baseline.
        other_input = create_one_event(
            "all_reduce",
            group,
            [[5, 4]],
            [[4, 4]],
            state="scheduled",
            collective_seq_id=1,
            p2p_seq_id=1,
        )
        self.assertEqual(
            compare(baseline, other_input), MatchState.SIZE_OR_SYNTAX_MISMATCH
        )

        # Output sizes differ from the baseline.
        other_output = create_one_event(
            "all_reduce",
            group,
            [[4, 4]],
            [[5, 4]],
            state="scheduled",
            collective_seq_id=1,
            p2p_seq_id=2,
        )
        self.assertEqual(
            compare(baseline, other_output), MatchState.SIZE_OR_SYNTAX_MISMATCH
        )

        # An all_reduce whose own input and output sizes disagree is compared
        # with itself and is still expected to be reported as a mismatch.
        lopsided = create_one_event(
            "all_reduce",
            group,
            [[4, 4]],
            [[5, 4]],
            state="scheduled",
            collective_seq_id=2,
        )
        self.assertEqual(
            compare(lopsided, lopsided), MatchState.SIZE_OR_SYNTAX_MISMATCH
        )

        # Same as the baseline except for the recorded state.
        finished = create_one_event(
            "all_reduce",
            group,
            [[4, 4]],
            [[4, 4]],
            state="completed",
            collective_seq_id=1,
        )
        self.assertEqual(
            compare(baseline, finished), MatchState.COLLECTIVE_STATE_MISMATCH
        )

        # Same as `finished` except for the output dtype.
        finished_fp16 = create_one_event(
            "all_reduce",
            group,
            [[4, 4]],
            [[4, 4]],
            state="completed",
            collective_seq_id=1,
            output_dtypes="float16",
        )
        self.assertEqual(
            compare(finished_fp16, finished), MatchState.COLLECTIVE_DTYPE_MISMATCH
        )
112
# When this file is executed as a script, hand control to `run_tests`
# (imported above from torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()
115