import argparse
import os
import re
import textwrap
from pathlib import Path

import common

# Imports for working with the commit classifier
from classifier import (
    CategoryConfig,
    CommitClassifier,
    CommitClassifierInputs,
    get_author_map,
    get_file_map,
    XLMR_BASE,
)
from commitlist import CommitList
from common import get_commit_data_cache, topics

import torch

# Directory containing the trained classifier weights and its author/file maps.
CLASSIFIER_DIR = Path("results/classifier")

# Trailing " (#12345)" PR-number suffix on a commit title, any number of digits.
_PR_SUFFIX_RE = re.compile(r"\s*\(#\d+\)\s*$")


class Categorizer:
    """Interactively assign a category and topic to commits in a commit list.

    Every commit whose current category equals ``category`` is shown in the
    terminal. The user's chosen category/topic is written back to the commit
    list right after each answer.
    """

    def __init__(self, path, category="Uncategorized", use_classifier: bool = False):
        """Load the commit list at ``path`` and, optionally, the classifier.

        Args:
            path: Commit list location, passed to ``CommitList.from_existing``.
            category: Only commits currently in this category are reviewed.
            use_classifier: When true, load the trained classifier from
                ``CLASSIFIER_DIR`` and use its prediction as the default answer.
        """
        self.cache = get_commit_data_cache()
        self.commits = CommitList.from_existing(path)
        if use_classifier:
            print("Using a classifier to aid with categorization.")
            device = "cuda" if torch.cuda.is_available() else "cpu"
            classifier_config = CategoryConfig(common.categories)
            author_map = get_author_map(
                CLASSIFIER_DIR, regen_data=False, assert_stored=True
            )
            file_map = get_file_map(
                CLASSIFIER_DIR, regen_data=False, assert_stored=True
            )
            self.classifier = CommitClassifier(
                XLMR_BASE, author_map, file_map, classifier_config
            ).to(device)
            # map_location lets weights saved on a GPU machine load on a
            # CPU-only one; a plain torch.load fails to deserialize there.
            self.classifier.load_state_dict(
                torch.load(
                    CLASSIFIER_DIR / "commit_classifier.pt", map_location=device
                )
            )
            self.classifier.eval()
        else:
            self.classifier = None
        # Special categories: 'Uncategorized'
        # All other categories must be real
        self.category = category

    def categorize(self):
        """Prompt for every commit currently in ``self.category``, in order."""
        commits = self.commits.filter(category=self.category)
        total_commits = len(self.commits.commits)
        # Commits outside the filtered category are counted as already done,
        # so the progress header is relative to the whole list.
        already_done = total_commits - len(commits)
        i = 0
        while i < len(commits):
            cur_commit = commits[i]
            next_commit = commits[i + 1] if i + 1 < len(commits) else None
            jump_to = self.handle_commit(
                cur_commit, already_done + i + 1, total_commits, commits
            )

            # handle_commit may request a jump to a given index; otherwise
            # continue with the commit that followed the current one.
            if jump_to is not None:
                i = jump_to
            elif next_commit is None:
                i = len(commits)
            else:
                i = commits.index(next_commit)

    def features(self, commit):
        """Return the cached commit data (title, body, labels, ...) for ``commit``."""
        return self.cache.get(commit.commit_hash)

    def potential_reverts_of(self, commit, commits):
        """Collect signs that ``commit`` was reverted.

        Returns:
            ``[]`` for submodule-update commits. Otherwise a dict that maps
            ``"GithubBot"`` to ``"Reverted"`` when the commit carries the
            "Reverted" label. It also maps the 1-based position (within
            ``commits``) of every later commit whose title contains this
            commit's title (PR suffix removed) to that later commit.
        """
        submodule_update_str = [
            "Update TensorPipe submodule",
            "Updating submodules",
            "Automated submodule update",
        ]
        if any(a in commit.title for a in submodule_update_str):
            return []

        features = self.features(commit)
        if "Reverted" in features.labels:
            reasons = {"GithubBot": "Reverted"}
        else:
            reasons = {}

        index = commits.index(commit)
        # Strip the trailing " (#35011)" PR suffix whatever its length. A fixed
        # slice cut into the title for shorter PR numbers or missing suffixes.
        cleaned_title = _PR_SUFFIX_RE.sub("", commit.title)
        if not cleaned_title:
            # "" is a substring of every title and would flag every commit.
            return reasons
        # NB: the index + 2 is sketch: it is the 1-based position within
        # ``commits`` and does not include the offset shown in the header.
        reasons.update(
            {
                (index + 2 + delta): cand
                for delta, cand in enumerate(commits[index + 1 :])
                if cleaned_title in cand.title
                and commit.commit_hash != cand.commit_hash
            }
        )
        return reasons

    def _suggested_category(self, commit, features):
        """Return the default category: the classifier's guess, else the current one."""
        if self.classifier is None:
            return commit.category
        # Some commits don't have authors:
        author = features.author if features.author else "Unknown"
        files = " ".join(features.files_changed)
        classifier_input = CommitClassifierInputs(
            title=[features.title], files=[files], author=[author]
        )
        return self.classifier.get_most_likely_category_name(classifier_input)[
            0
        ]

    @staticmethod
    def _prompt_choice(options, default, prompt, header=None):
        """Read one of ``options`` from stdin, re-prompting until resolved.

        An empty answer selects ``default`` (when it is not None). An answer
        that exactly names an option is accepted even if it is also a prefix
        of other options. Otherwise the answer must be a prefix of exactly one
        option.

        Args:
            options: The allowed values.
            default: Value returned when the user just presses enter.
            prompt: Text passed to ``input``.
            header: Optional line printed before every prompt.
        """
        while True:
            if header is not None:
                print(header)
            value = input(prompt).strip()
            if not value:
                # The user just pressed enter and likes the default value.
                if default is not None:
                    return default
                continue
            if value in options:
                return value
            matches = [opt for opt in options if opt.startswith(value)]
            if len(matches) == 1:
                return matches[0]
            print(f"Possible matches: {matches}, try again")

    def handle_commit(self, commit, i, total, commits):
        """Show ``commit``, ask for its category and topic, and save the answer.

        Args:
            commit: The commit to review.
            i: 1-based progress position shown in the header.
            total: Total shown in the header.
            commits: The filtered commits under review, searched for reverts.

        Returns:
            None, meaning ``categorize`` continues with the next commit.
        """
        potential_reverts = self.potential_reverts_of(commit, commits)
        if potential_reverts:
            potential_reverts = f"!!!POTENTIAL REVERTS!!!: {potential_reverts}"
        else:
            potential_reverts = ""

        features = self.features(commit)
        classifier_category = self._suggested_category(commit, features)

        breaking_alarm = ""
        if "module: bc-breaking" in features.labels:
            breaking_alarm += "\n!!!!!! BC BREAKING !!!!!!"

        if "module: deprecation" in features.labels:
            breaking_alarm += "\n!!!!!! DEPRECATION !!!!!!"

        os.system("clear")
        view = textwrap.dedent(
            f"""\
[{i}/{total}]
================================================================================
{features.title}

{potential_reverts} {breaking_alarm}

{features.body}

Files changed: {features.files_changed}

Labels: {features.labels}

Current category: {commit.category}

Select from: {', '.join(common.categories)}

        """
        )
        print(view)
        cat_choice = self._prompt_choice(
            common.categories,
            default=classifier_category,
            prompt=f"{classifier_category} ",
            header="Enter category: ",
        )
        print(f"\nSelected: {cat_choice}")
        print(f"\nCurrent topic: {commit.topic}")
        print(f"""Select from: {', '.join(topics)}""")
        topic_choice = self._prompt_choice(
            topics, default=commit.topic, prompt="topic> "
        )
        print(f"\nSelected: {topic_choice}")
        self.update_commit(commit, cat_choice, topic_choice)
        return None

    def update_commit(self, commit, category, topic):
        """Set ``commit``'s category and topic, then write the commit list out."""
        assert category in common.categories
        assert topic in topics
        commit.category = category
        commit.topic = topic
        self.commits.write_result()


def main():
    """Parse command-line options and run the interactive categorizer."""
    parser = argparse.ArgumentParser(description="Tool to help categorize commits")
    parser.add_argument(
        "--category",
        type=str,
        default="Uncategorized",
        help='Which category to filter by. "Uncategorized", None, or a category name',
    )
    parser.add_argument(
        "--file",
        help="The location of the commits CSV",
        default="results/commitlist.csv",
    )
    parser.add_argument(
        "--use_classifier",
        action="store_true",
        help="Whether or not to use a classifier to aid in categorization.",
    )

    args = parser.parse_args()
    categorizer = Categorizer(args.file, args.category, args.use_classifier)
    categorizer.categorize()


if __name__ == "__main__":
    main()