# mypy: ignore-errors

import warnings

import numpy as np
import pandas as pd  # type: ignore[import-untyped]
from scipy.stats import gmean  # type: ignore[import-untyped]
from sklearn.model_selection import train_test_split  # type: ignore[import-untyped]
from sklearn.tree import DecisionTreeRegressor  # type: ignore[import-untyped]
from train import AHTrain

from torch._inductor.autoheuristic.autoheuristic_utils import CHOICE_COL, FEEDBACK_COL


# TODO (AlnisM): Fix these warnings
warnings.filterwarnings(
    "ignore",
    message="The behavior of DataFrame concatenation with empty or all-NA entries is deprecated",
)
warnings.filterwarnings(
    "ignore",
    message="DataFrameGroupBy.apply operated on the grouping columns.",
)


class AHTrainRegressionTree(AHTrain):
    """
    This class is responsible for generating a heuristic by using data collected with AutoHeuristic. It will learn a
    regression tree that predicts a score that represents how well a specific choice will perform given an input.
    A higher score means a better choice. The heuristic will be generated in a file named <heuristic_name>.py in the
    torch/_inductor/autoheuristic/artifacts/ directory.
    """
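    # Typical usage (mirrors the __main__ block at the bottom of this file;
    # generate_heuristic() is inherited from AHTrain and is assumed to parse
    # command-line arguments before invoking main()):
    #     trainer = AHTrainRegressionTree()
    #     trainer.generate_heuristic()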

    def __init__(self):
        super().__init__()

    def main(
        self,
        log_path,
        other_datasets,
        nrows,
        heuristic_name,
        save_dot=False,
        ranking=False,
    ):
        """
        Main function that trains a decision tree and generates a heuristic.
        """
        (df, choices, cat_feature2cats, dummy_col_2_col_val, metadata) = self.get_df(
            log_path, nrows=nrows, apply_filters=True
        )
        df_train, df_val, df_test, feature_columns = self.custom_train_test_split(df)
        datasets = {"train": df_train, "val": df_val, "test": df_test}
        self.add_real_datasets(datasets, other_datasets, cat_feature2cats)

        # We do a grid search over these values.
        # We only try max_depths of 5, 6, and 7: we want to keep the tree, and
        # thus the generated code, small, and depths below 5 do not perform
        # well enough.
        max_depths = [5, 6, 7]
        min_samples_leafs = [1, 2, 5, 10]
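        # The choice column was one-hot encoded in get_df, so each choice
        # appears as a dummy column named "<CHOICE_COL>_<choice>".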
        choice_columns = [f"{CHOICE_COL}_{choice}" for choice in choices]
        (results_df, best_model, threshold) = self.train_and_evaluate_models(
            datasets, feature_columns, choice_columns, max_depths, min_samples_leafs
        )

        # prints results for all models and datasets
        print(results_df.to_string())

        # prints results grouped by dataset
        for set_name in results_df["dataset"].unique():
            dataset_results = results_df[results_df["dataset"] == set_name]
            dataset_results = dataset_results.sort_values(by="correct")
            print(dataset_results.to_string() + "\n")

        feature_names = feature_columns + choice_columns
        self.dt_to_python(
            best_model,
            metadata,
            feature_names,
            dummy_col_2_col_val,
            heuristic_name,
            threshold,
        )

    def get_df(self, log_path, cat_feature2cats=None, nrows=None, apply_filters=False):
        """
        Parses the log file and processes the data into a dataframe that can be used for training.
        """
        (df, metadata, feature_columns, categorical_features, choices) = self.parse_log(
            log_path, nrows
        )

        def process_data(
            df,
            feature_columns,
            apply_filters,
            min_count_measurements=3,
            max_relative_std=5,
        ):
            # Calculate statistics for each input and choice combination
            def calculate_stats(group):
                count = len(group)
                mean = group[FEEDBACK_COL].mean()
                std = group[FEEDBACK_COL].std()
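                # The relative std is the coefficient of variation in percent;
                # it is used below to filter out noisy measurements.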
                relative_std = (std / mean) * 100 if mean != 0 else np.inf
                median = group[FEEDBACK_COL].median()
                return pd.Series(
                    {
                        "count": count,
                        "median_execution_time": median,
                        "relative_std": relative_std,
                    }
                )

            stats = (
                df.groupby(feature_columns + [CHOICE_COL])
                .apply(calculate_stats)
                .reset_index()
            )

            if apply_filters:
                # Remove unstable measurements
                valid_stats = stats[
                    (stats["count"] >= min_count_measurements)
                    & (stats["relative_std"] <= max_relative_std)
                ]
                # Keep only inputs with at least two valid choices
                valid_inputs = valid_stats.groupby(feature_columns).filter(
                    lambda x: len(x) >= 2
                )
            else:
                valid_inputs = stats

            # Compute the winner and ratios for each input
            def get_winner_and_speedups(group):
                mean_time = group["median_execution_time"].mean()
                winner = group.loc[group["median_execution_time"].idxmin(), CHOICE_COL]
                min_time = group["median_execution_time"].min()
                max_time = group["median_execution_time"].max()

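                # The regression target for a choice is the mean median-runtime
                # across all choices for this input divided by this choice's
                # own median runtime: values above 1.0 mean the choice is
                # faster than the average choice.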
                group["winner"] = winner
                group["speedup"] = max_time / min_time
                group["target"] = mean_time / group["median_execution_time"]

                return group[
                    feature_columns + [CHOICE_COL, "winner", "speedup", "target"]
                ]

            results = (
                valid_inputs.groupby(feature_columns)
                .apply(get_winner_and_speedups)
                .reset_index(drop=True)
            )

            return results

        results = process_data(df, feature_columns, apply_filters)
        (results, added_categorical_features) = self.add_new_features(results)
        categorical_features += added_categorical_features
        categorical_features += [CHOICE_COL]

        (
            results,
            cat_feature2cats,
            dummy_col_2_col_val,
        ) = self.handle_categorical_features(
            cat_feature2cats, categorical_features, results
        )
        return (results, choices, cat_feature2cats, dummy_col_2_col_val, metadata)

    def custom_train_test_split(
        self, df, test_size=0.2, val_size=0.25, random_state=42
    ):
        """
        Splits the dataframe into train, val, and test sets.
        We have to be careful to keep rows with the same input but different
        choice in the same set, e.g. rows that look like this
        input_1,choice1,...
        input_1,choice2,...
        should end up in the same set.
        """
        # Columns that identify an input (everything except the targets and
        # the one-hot choice columns)
        exclude_columns = ["speedup", "winner", "target"]
        feature_columns = [
            col
            for col in df.columns
            if col not in exclude_columns and not col.startswith(CHOICE_COL + "_")
        ]
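        # ngroup() assigns one integer id per distinct combination of feature
        # values, so all rows belonging to the same input share an input_id.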
        df["input_id"] = df.groupby(feature_columns).ngroup()

        # Get unique input IDs
        unique_inputs = df["input_id"].unique()

        # Split unique inputs into train+val and test
        train_val_inputs, test_inputs = train_test_split(
            unique_inputs, test_size=test_size, random_state=random_state
        )

        # Split train+val inputs into train and val
        train_inputs, val_inputs = train_test_split(
            train_val_inputs, test_size=val_size, random_state=random_state
        )

        # Create masks for each set
        train_mask = df["input_id"].isin(train_inputs)
        val_mask = df["input_id"].isin(val_inputs)
        test_mask = df["input_id"].isin(test_inputs)

        # Split the dataframe
        df_train = df[train_mask]
        df_val = df[val_mask]
        df_test = df[test_mask]

        # Remove the temporary input_id column
        df_train = df_train.drop("input_id", axis=1)
        df_val = df_val.drop("input_id", axis=1)
        df_test = df_test.drop("input_id", axis=1)

        return df_train, df_val, df_test, feature_columns

    def train_and_evaluate_models(
        self,
        datasets,
        feature_columns,
        choice_columns,
        max_depths,
        min_samples_leafs,
        threshold=0.99,
    ):
        """
        Does a grid search over max_depths and min_samples_leafs, and returns the best model.
        """

        results = []
        df_train = datasets["train"]
        df_val = datasets["val"]

        best_model = None
        best_model_threshold = 0
        max_correct_predictions = -1
        for max_depth in max_depths:
            for min_samples_leaf in min_samples_leafs:
                print(
                    f"Evaluating max_depth={max_depth}, min_samples_leaf={min_samples_leaf}"
                )
                model = DecisionTreeRegressor(
                    random_state=42,
                    max_depth=max_depth,
                    min_samples_leaf=min_samples_leaf,
                )
                model.fit(
                    df_train[feature_columns + choice_columns], df_train["target"]
                )

                # We first compute a safe threshold: it ensures that, on the
                # validation set, whenever the heuristic returns a choice, the
                # choice is correct. A high threshold can, however, lead to a
                # lot of 'unsure' choices.
                eval_result = self.evaluate_model(
                    model, df_val, feature_columns, choice_columns, threshold
                )
                safe_threshold = eval_result["wrong_max_ratio"]
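                # Hypothetical example: if the largest best/second-best ratio
                # among wrong validation predictions is 1.7, using 1.7 as the
                # threshold turns all of those wrong predictions into "unsure".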
                for dataset_name, dataset in datasets.items():
                    eval_result = self.evaluate_model(
                        model, dataset, feature_columns, choice_columns, safe_threshold
                    )
                    print(eval_result)
                    if dataset_name == "val":
                        eval_correct = eval_result["correct"]
                        if eval_correct > max_correct_predictions:
                            best_model = model
                            best_model_threshold = safe_threshold
                            max_correct_predictions = eval_correct
                    results.append(
                        {
                            "max_depth": max_depth,
                            "min_samples_leaf": min_samples_leaf,
                            "dataset": dataset_name,
                            "correct": eval_result["correct"],
                            "wrong": eval_result["wrong"],
                            "unsure": eval_result["unsure"],
                            "total": eval_result["total"],
                            "max_wrong_speedup": eval_result["max_wrong_speedup"],
                            "gman_wrong_speedup": eval_result["gman_wrong_speedup"],
                            "threshold": safe_threshold,
                        }
                    )

        return (pd.DataFrame(results), best_model, best_model_threshold)

    def evaluate_model(self, model, df, feature_columns, choice_columns, threshold):
        """
        Custom evaluation function that evaluates a learned decision tree.
        """

        def predict_winner(group):
            predictions = model.predict(group[feature_columns + choice_columns])

            # Find the index of the maximum prediction (best choice)
            best_choice_index = np.argmax(predictions)

            # Get the corresponding choice
            predicted_choice = (
                group[choice_columns].iloc[best_choice_index].idxmax().split("_")[-1]
            )

            # Calculate the ratio between the best and second-best prediction
            sorted_predictions = np.sort(predictions)[::-1]
            top_pred_ratio = (
                sorted_predictions[0] / sorted_predictions[1]
                if len(sorted_predictions) > 1
                else np.inf
            )
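            # For example, top predictions of 0.9 and 0.6 give a ratio of 1.5;
            # the prediction only counts if this ratio exceeds the threshold.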

            # If the best choice is not "significantly" better than the
            # second-best choice, the learned heuristic will return "unsure"
            if top_pred_ratio <= threshold:
                predicted_winner = "unsure"
            else:
                predicted_winner = predicted_choice

            actual_winner = group["winner"].iloc[0]
            is_correct = (
                predicted_winner == actual_winner
                if predicted_winner != "unsure"
                else "unsure"
            )

            return pd.Series(
                {
                    "predicted_winner": predicted_winner,
                    "ratio": top_pred_ratio,
                    "actual_winner": actual_winner,
                    "is_correct": is_correct,
                    "speedup": group["speedup"].iloc[
                        0
                    ],  # Speedup is the same for all rows in the group
                }
            )

        results = df.groupby(feature_columns).apply(predict_winner).reset_index()
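        # "is_correct" is tri-valued (True, False, or the string "unsure"),
        # so the counts below distinguish the boolean cases from the string.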
        correct = (results["is_correct"].eq(True)).sum()
        unsure = (results["is_correct"] == "unsure").sum()
        wrong_results = results[results["is_correct"].eq(False)]
        wrong = len(wrong_results)

        # Calculate max and geometric mean of speedup for wrong predictions
        # Used for debugging purposes
        wrong_speedups = wrong_results["speedup"]
        max_wrong_speedup = wrong_speedups.max() if not wrong_speedups.empty else np.nan
        geo_mean_wrong_speedup = (
            gmean(wrong_speedups) if not wrong_speedups.empty else np.nan
        )
        wrong_max_ratio = wrong_results["ratio"].max()

        total = correct + wrong + unsure
        return {
            "correct": correct,
            "wrong": wrong,
            "unsure": unsure,
            "total": total,
            "max_wrong_speedup": max_wrong_speedup,
            "gman_wrong_speedup": geo_mean_wrong_speedup,
            "wrong_max_ratio": wrong_max_ratio,
        }

    def dt_to_python(
        self,
        dt,
        metadata,
        feature_names,
        dummy_col_2_col_val,
        heuristic_name,
        threshold,
        unsafe_leaves=None,
    ):
        tree_ = dt.tree_
        feature_name = [
            feature_names[i] if i != -1 else "undefined!" for i in tree_.feature
        ]

        lines = []
        device_capa = metadata["device_capa"]
        device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
        opt_name = metadata["name"]
        lines.append(
            self.codegen_boilerplate(
                heuristic_name,
                opt_name,
                threshold,
                metadata["shared_memory"],
                device_capa_str,
                dt,
            )
        )
        fn_def = f"\n    {self.gen_predict_fn_def()}"
        lines.append(fn_def)

        def dt_to_python(node, depth):
            indent = "    " * (depth + 1)
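            # tree_.feature is -2 (sklearn's TREE_UNDEFINED) for leaf nodes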
            if tree_.feature[node] != -2:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                if name in dummy_col_2_col_val:
                    (orig_name, value) = dummy_col_2_col_val[name]
                    predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
                    assert (
                        threshold == 0.5
                    ), f"expected threshold to be 0.5 but is {threshold}"
                else:
                    predicate = (
                        f"{indent}if context.get_value('{name}') <= {threshold}:"
                    )
                lines.append(predicate)
                dt_to_python(tree_.children_left[node], depth + 1)
                lines.append(f"{indent}else:")
                dt_to_python(tree_.children_right[node], depth + 1)
            else:
                lines.append(self.handle_leaf(tree_, node, indent, unsafe_leaves))

        dt_to_python(0, 1)
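        # At this point, `lines` holds the body of the generated predict()
        # function as nested if/else statements, e.g. (hypothetical output):
        #     if context.get_value('m') <= 512.0:
        #         if str(context.get_value('choice')) != 'mm':
        #             return 0.8837
        #         ...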

        self.write_heuristic_to_file(lines, heuristic_name)

    def handle_leaf(self, tree_, node, indent, unsafe_leaves):
        """
        Generates the code for a leaf node. This is just the value predicted by the regression tree.
        """
        value = tree_.value[node][0][0]
        return f"{indent}return {value}"

    def gen_predict_fn_def(self):
        return "def predict(self, context: AHContext) -> float:"

    def codegen_boilerplate(
        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes
    ):
        """
        Generates the boilerplate code for the generated heuristic. This includes
        things like imports, the class definition, etc.
        """

        boiler_plate = f"""# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/{opt_name}/
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
from torch._inductor.autoheuristic.learnedheuristic_interface import (
    LearnedHeuristicRegression,
)


class {heuristic_name}(LearnedHeuristicRegression):

    def __init__(self) -> None:
        pass

{self.gen_precondition(opt_name, shared_memory, device_capa)}

    def get_feedback(self, context: AHContext, choice: Choice) -> float:
        context.context_dict[CHOICE_COL] = choice
        return self.predict(context)

    def get_confidence_threshold(self) -> float:
        return {threshold}

    def get_name(self) -> str:
        return '{opt_name}'"""
        return boiler_plate


if __name__ == "__main__":
    train = AHTrainRegressionTree()
    train.generate_heuristic()