# Owner(s): ["oncall: distributed"]

import pathlib
import sys


REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent

# The `tools.flight_recorder` imports below are resolved relative to REPO_ROOT,
# so it is placed at the front of sys.path only for the duration of the imports.
sys.path.insert(0, str(REPO_ROOT))
try:
    from tools.flight_recorder.components.types import MatchState
    from tools.flight_recorder.components.utils import match_one_event
finally:
    # Make sure to remove REPO_ROOT after import is done -- also when an import
    # above fails -- so the rest of the test process sees an unmodified sys.path.
    sys.path.remove(str(REPO_ROOT))

from torch.testing._internal.common_utils import run_tests, TestCase


def create_one_event(
    collective_name,
    pg_info,
    input_sizes,
    output_sizes,
    state="scheduled",
    collective_seq_id=0,
    p2p_seq_id=0,
    output_dtypes="float32",
    input_dtypes="float32",
):
    """Build a minimal flight-recorder event dict for ``match_one_event`` tests.

    Args:
        collective_name: collective op name; stored as ``"nccl:<name>"`` under
            the ``"profiling_name"`` key.
        pg_info: process-group identifier, stored verbatim under
            ``"process_group"`` (the tests below pass a ``(pg_id, pg_name)``
            tuple).
        input_sizes: list of input tensor shapes, stored verbatim.
        output_sizes: list of output tensor shapes, stored verbatim.
        state: event state string, e.g. ``"scheduled"`` or ``"completed"``.
        collective_seq_id: collective sequence id; stored as a string.
        p2p_seq_id: point-to-point sequence id; stored as a string.
        output_dtypes: value stored under ``"output_dtypes"``.
        input_dtypes: value stored under ``"input_dtypes"``; defaults to
            ``"float32"``.

    Returns:
        A dict with the keys ``profiling_name``, ``state``, ``process_group``,
        ``input_sizes``, ``output_sizes``, ``input_dtypes``, ``output_dtypes``,
        ``collective_seq_id`` and ``p2p_seq_id``.
    """
    return {
        "profiling_name": f"nccl:{collective_name}",
        "state": state,
        "process_group": pg_info,
        "input_sizes": input_sizes,
        "output_sizes": output_sizes,
        "input_dtypes": input_dtypes,
        "output_dtypes": output_dtypes,
        "collective_seq_id": str(collective_seq_id),
        "p2p_seq_id": str(p2p_seq_id),
    }


class FlightRecorderEventTest(TestCase):
    def test_match_one_event(self):
        """Check the MatchState reported by ``match_one_event`` for event pairs."""
        e1 = create_one_event(
            "all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
        )
        # Presumably maps process-group id -> set of member ranks; TODO confirm
        # against match_one_event's signature.
        membership = {"0": {0, 1}}
        # An event compared with itself matches fully.
        self.assertEqual(
            match_one_event(e1, e1, membership, "0"), MatchState.FULLY_MATCHED
        )

        # Different collective type (all_reduce vs all_gather).
        e2 = create_one_event(
            "all_gather", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
        )
        self.assertEqual(
            match_one_event(e1, e2, membership, "0"),
            MatchState.COLLECTIVE_TYPE_MISMATCH,
        )

        # all_to_all pairs are reported as UNDECIDED even when the events agree.
        e3 = create_one_event(
            "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
        )
        e4 = create_one_event(
            "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
        )
        self.assertEqual(match_one_event(e3, e4, membership, "0"), MatchState.UNDECIDED)

        # Input sizes differ from e1.
        e5 = create_one_event(
            "all_reduce", ("0", "default"), [[5, 4]], [[4, 4]], "scheduled", 1, 1
        )
        self.assertEqual(
            match_one_event(e1, e5, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
        )

        # Output sizes differ from e1.
        e6 = create_one_event(
            "all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 1, 2
        )
        self.assertEqual(
            match_one_event(e1, e6, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
        )

        # Compared with itself, but this all_reduce's input and output sizes
        # disagree with each other, which is reported as a size mismatch.
        e7 = create_one_event(
            "all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 2
        )
        self.assertEqual(
            match_one_event(e7, e7, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH
        )

        # Same op and sizes as e1, but "completed" instead of "scheduled".
        e9 = create_one_event(
            "all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "completed", 1
        )
        self.assertEqual(
            match_one_event(e1, e9, membership, "0"),
            MatchState.COLLECTIVE_STATE_MISMATCH,
        )

        # Same as e9 except for the output dtype (float16 vs float32).
        e10 = create_one_event(
            "all_reduce",
            ("0", "default"),
            [[4, 4]],
            [[4, 4]],
            "completed",
            1,
            output_dtypes="float16",
        )
        self.assertEqual(
            match_one_event(e10, e9, membership, "0"),
            MatchState.COLLECTIVE_DTYPE_MISMATCH,
        )


if __name__ == "__main__":
    run_tests()