xref: /aosp_15_r20/external/pytorch/scripts/release_notes/categorize.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import os
3import textwrap
4from pathlib import Path
5
6import common
7
# Imports for working with the commit classifier
9from classifier import (
10    CategoryConfig,
11    CommitClassifier,
12    CommitClassifierInputs,
13    get_author_map,
14    get_file_map,
15    XLMR_BASE,
16)
17from commitlist import CommitList
18from common import get_commit_data_cache, topics
19
20import torch
21
22
class Categorizer:
    """Interactive terminal helper that assigns a category and topic to commits.

    The commit list is loaded through ``CommitList.from_existing``; every
    answer given at the prompts is stored on the commit and written back
    immediately via ``CommitList.write_result``.
    """

    def __init__(self, path, category="Uncategorized", use_classifier: bool = False):
        """Load the commit list and, optionally, the commit classifier.

        Args:
            path: Location passed to ``CommitList.from_existing``.
            category: Only commits returned by
                ``self.commits.filter(category=category)`` are visited by
                :meth:`categorize`.
            use_classifier: When true, build a ``CommitClassifier`` from the
                files under ``results/classifier`` and use its prediction as
                the default answer of the category prompt.
        """
        self.cache = get_commit_data_cache()
        self.commits = CommitList.from_existing(path)
        if use_classifier:
            print("Using a classifier to aid with categorization.")
            device = "cuda" if torch.cuda.is_available() else "cpu"
            classifier_dir = Path("results/classifier")
            classifier_config = CategoryConfig(common.categories)
            author_map = get_author_map(
                classifier_dir, regen_data=False, assert_stored=True
            )
            file_map = get_file_map(
                classifier_dir, regen_data=False, assert_stored=True
            )
            self.classifier = CommitClassifier(
                XLMR_BASE, author_map, file_map, classifier_config
            ).to(device)
            # map_location remaps the stored tensors onto the device chosen
            # above, so a checkpoint saved from a GPU still loads on a
            # CPU-only machine.
            self.classifier.load_state_dict(
                torch.load(
                    classifier_dir / "commit_classifier.pt", map_location=device
                )
            )
            self.classifier.eval()
        else:
            self.classifier = None
        # Special categories: 'Uncategorized'
        # All other categories must be real
        self.category = category

    def categorize(self):
        """Prompt for every commit in the selected category, in list order."""
        commits = self.commits.filter(category=self.category)
        total_commits = len(self.commits.commits)
        # Commits outside the selected category count as already handled in
        # the "[i/total]" progress counter.
        already_done = total_commits - len(commits)
        i = 0
        while i < len(commits):
            cur_commit = commits[i]
            next_commit = commits[i + 1] if i + 1 < len(commits) else None
            jump_to = self.handle_commit(
                cur_commit, already_done + i + 1, total_commits, commits
            )

            # Advance: an explicit index returned by handle_commit wins,
            # otherwise continue with the commit that followed this one.
            if jump_to is not None:
                i = jump_to
            elif next_commit is None:
                i = len(commits)
            else:
                i = commits.index(next_commit)

    def features(self, commit):
        """Return the cached commit data stored under ``commit.commit_hash``."""
        return self.cache.get(commit.commit_hash)

    @staticmethod
    def _strip_pr_suffix(title):
        """Return ``title`` without a trailing `` (#<digits>)`` PR reference."""
        stripped = title.rstrip()
        head, sep, tail = stripped.rpartition(" (#")
        if sep and tail.endswith(")") and tail[:-1].isdigit():
            return head
        return stripped

    def potential_reverts_of(self, commit, commits):
        """Collect hints that ``commit`` may have been reverted.

        Args:
            commit: The commit being inspected; must be an element of
                ``commits``.
            commits: The list of commits currently being walked.

        Returns:
            ``[]`` for submodule-update commits. Otherwise a dict that may
            contain ``"GithubBot": "Reverted"`` (when the commit carries the
            ``Reverted`` label) plus one entry per later commit whose title
            contains this commit's title (PR suffix removed), keyed by that
            commit's 1-based position in ``commits``.
        """
        submodule_update_str = [
            "Update TensorPipe submodule",
            "Updating submodules",
            "Automated submodule update",
        ]
        if any(marker in commit.title for marker in submodule_update_str):
            return []

        features = self.features(commit)
        reasons = {"GithubBot": "Reverted"} if "Reverted" in features.labels else {}

        # Remove only a real trailing " (#12345)" reference. Slicing a fixed
        # number of characters eats into titles whose suffix is shorter or
        # missing, and turns short titles into "" which matches everything.
        cleaned_title = self._strip_pr_suffix(commit.title)
        if not cleaned_title:
            return reasons

        index = commits.index(commit)
        # NB: keys are 1-based positions inside `commits` (index + 2 for the
        # commit right after this one). NOTE(review): the progress counter
        # shown by handle_commit is additionally offset by the commits that
        # were filtered out - confirm which numbering is intended.
        reasons.update(
            {
                position: cand
                for position, cand in enumerate(commits[index + 1 :], start=index + 2)
                if cleaned_title in cand.title
                and commit.commit_hash != cand.commit_hash
            }
        )
        return reasons

    @staticmethod
    def _match_choice(value, options):
        """Resolve user input against ``options``.

        Returns ``(choice, candidates)``. ``choice`` is the selected option,
        or ``None`` when the input is ambiguous or matches nothing. An exact
        match always wins, so an option that is a prefix of another option
        remains selectable.
        """
        if value in options:
            return value, [value]
        candidates = [opt for opt in options if opt.startswith(value)]
        if len(candidates) == 1:
            return candidates[0], candidates
        return None, candidates

    def _prompt_choice(self, prompt, default, options, banner=None):
        """Ask until the input selects one of ``options`` or accepts ``default``.

        ``banner`` (if given) is printed before every attempt. Empty input
        returns ``default``; a ``None`` default cannot be accepted and simply
        re-prompts.
        """
        while True:
            if banner is not None:
                print(banner)
            value = input(prompt).strip()
            if not value:
                # The user just pressed enter and likes the default value
                if default is not None:
                    return default
                continue
            choice, candidates = self._match_choice(value, options)
            if choice is not None:
                return choice
            print(f"Possible matches: {candidates}, try again")

    def handle_commit(self, commit, i, total, commits):
        """Show one commit, prompt for its category and topic, and store them.

        Args:
            commit: The commit to display and update.
            i: Progress counter shown as ``[i/total]``.
            total: Total number of commits shown in the progress counter.
            commits: The list being walked; used to look for potential reverts.

        Returns:
            Always ``None``. (``categorize`` would treat any other value as
            the next index to visit.)
        """
        potential_reverts = self.potential_reverts_of(commit, commits)
        if potential_reverts:
            potential_reverts = f"!!!POTENTIAL REVERTS!!!: {potential_reverts}"
        else:
            potential_reverts = ""

        features = self.features(commit)
        if self.classifier is not None:
            # Some commits don't have authors:
            author = features.author if features.author else "Unknown"
            files = " ".join(features.files_changed)
            classifier_input = CommitClassifierInputs(
                title=[features.title], files=[files], author=[author]
            )
            classifier_category = self.classifier.get_most_likely_category_name(
                classifier_input
            )[0]

        else:
            # Without a classifier the commit's current category is the default.
            classifier_category = commit.category

        breaking_alarm = ""
        if "module: bc-breaking" in features.labels:
            breaking_alarm += "\n!!!!!! BC BREAKING !!!!!!"

        if "module: deprecation" in features.labels:
            breaking_alarm += "\n!!!!!! DEPRECATION !!!!!!"

        os.system("clear")
        view = textwrap.dedent(
            f"""\
[{i}/{total}]
================================================================================
{features.title}

{potential_reverts} {breaking_alarm}

{features.body}

Files changed: {features.files_changed}

Labels: {features.labels}

Current category: {commit.category}

Select from: {', '.join(common.categories)}

        """
        )
        print(view)
        cat_choice = self._prompt_choice(
            f"{classifier_category} ",
            classifier_category,
            common.categories,
            banner="Enter category: ",
        )
        print(f"\nSelected: {cat_choice}")
        print(f"\nCurrent topic: {commit.topic}")
        print(f"""Select from: {', '.join(topics)}""")
        topic_choice = self._prompt_choice("topic> ", commit.topic, topics)
        print(f"\nSelected: {topic_choice}")
        self.update_commit(commit, cat_choice, topic_choice)
        return None

    def update_commit(self, commit, category, topic):
        """Store ``category`` and ``topic`` on ``commit`` and write the list out.

        Both values must be members of ``common.categories`` and ``topics``
        respectively.
        """
        assert category in common.categories, f"unknown category: {category!r}"
        assert topic in topics, f"unknown topic: {topic!r}"
        commit.category = category
        commit.topic = topic
        # Persist after every commit so progress survives an interrupted session.
        self.commits.write_result()
190
191
def main():
    """Parse the command-line options and start an interactive categorization run."""
    cli = argparse.ArgumentParser(description="Tool to help categorize commits")

    # Category whose commits should be walked.
    cli.add_argument(
        "--category",
        default="Uncategorized",
        type=str,
        help='Which category to filter by. "Uncategorized", None, or a category name',
    )
    # CSV file holding the commit list.
    cli.add_argument(
        "--file",
        default="results/commitlist.csv",
        help="The location of the commits CSV",
    )
    # Opt-in flag for classifier-suggested defaults.
    cli.add_argument(
        "--use_classifier",
        action="store_true",
        help="Whether or not to use a classifier to aid in categorization.",
    )

    opts = cli.parse_args()
    Categorizer(opts.file, opts.category, opts.use_classifier).categorize()
214
215
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
218