# mypy: allow-untyped-defs
import gc
import logging
import os
import random
import traceback

import numpy

import torch
import torch.optim as optim

from .. import config


logger: logging.Logger = logging.getLogger(__name__)

MAIN_RANDOM_SEED = 1337

# Set CUBLAS_WORKSPACE_CONFIG so cuBLAS behaves deterministically when
# torch.use_deterministic_algorithms(True) is enabled on CUDA.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


# Even with fixed seeds, the two forward functions can still produce different
# outputs if they involve non-deterministic operations, such as certain kinds
# of parallelism or asynchronous execution.
def set_deterministic() -> None:
    """Seed the torch, random, and numpy RNGs and enable deterministic algorithms."""

    torch.manual_seed(MAIN_RANDOM_SEED)
    random.seed(MAIN_RANDOM_SEED)
    numpy.random.seed(MAIN_RANDOM_SEED)
    torch.use_deterministic_algorithms(True)
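
# A brief usage note (illustrative): run_model below calls this right before each
# forward pass so that both model variants see identical RNG state, e.g.
#
#   set_deterministic()
#   out_base = model_base(*inputs)
#   set_deterministic()
#   out_control = model_control(*inputs)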


def clean_memory() -> None:
    """Clean memory to avoid OOM."""
    gc.collect()
    torch.cuda.empty_cache()


# We compare the numerical results before and after the pre/post grad fx pass
# transformations to make sure they stay the same.
def compare_dict_tensors(dict_base, dict_control, precision):
    if len(set(dict_base.keys())) != len(set(dict_control.keys())):
        logger.warning("Mismatch keys found before and after pre/post grad fx passes.")
        logger.debug("keys before pre/post grad fx passes %s", dict_base.keys())
        logger.debug("keys after pre/post grad fx passes %s", dict_control.keys())
        return False
    is_allclose = True
    for key in dict_base.keys():
        if key not in dict_control:
            logger.warning(
                "Mismatch: parameter name %s does not exist after pre/post grad fx passes",
                key,
            )
            is_allclose = False
            continue
        # Some entries are `None` (not every param has a valid .grad field); skip them
        if dict_base[key] is None or dict_control[key] is None:
            continue
        if not torch.allclose(
            dict_base[key],
            dict_control[key],
            rtol=precision,
            atol=precision,
            equal_nan=True,
        ):
            logger.warning(
                "Mismatch parameter values found before and after pre/post grad fx passes."
            )
            logger.debug("value before pre/post grad fx passes %s", dict_base[key])
            logger.debug("value after pre/post grad fx passes %s", dict_control[key])
            is_allclose = False
    return is_allclose
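
# A minimal usage sketch (hypothetical values, not part of this module): both dicts
# are expected to map the same keys to comparable tensors, and `precision` is used
# as both rtol and atol.
#
#   base = {"weight": torch.ones(2, 2)}
#   control = {"weight": torch.ones(2, 2) + 1e-6}
#   assert compare_dict_tensors(base, control, precision=1e-4)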


def compare_tuple_tensors(tuple_base, tuple_control, precision):
    if len(tuple_base) != len(tuple_control):
        logger.warning(
            "Mismatch fw output length. before transformation: %s, after transformation: %s",
            len(tuple_base),
            len(tuple_control),
        )
        return False
    is_allclose = True
    for i in range(len(tuple_base)):
        # Some outputs are `None`; skip them
        if tuple_base[i] is None or tuple_control[i] is None:
            continue
        if not torch.allclose(
            tuple_base[i],
            tuple_control[i],
            rtol=precision,
            atol=precision,
            equal_nan=True,
        ):
            logger.debug(
                "forward output before pre/post grad fx passes %s", tuple_base[i]
            )
            logger.debug(
                "forward output after pre/post grad fx passes %s", tuple_control[i]
            )
            is_allclose = False
    return is_allclose
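
# A minimal usage sketch (hypothetical values): the two sequences are compared
# element-wise and `None` entries are skipped.
#
#   out_base = (torch.zeros(3), None)
#   out_control = (torch.zeros(3) + 1e-7, None)
#   assert compare_tuple_tensors(out_base, out_control, precision=1e-4)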


def compare_parameters(model_base, model_control, precision):
    return compare_dict_tensors(
        dict(model_base.named_parameters()),
        dict(model_control.named_parameters()),
        precision,
    )


def compare_forward_output(pred_base, pred_control, precision):
    return compare_tuple_tensors(
        pred_base,
        pred_control,
        precision,
    )


def compare_gradients(model_base, model_control, precision):
    grad_base = {key: param.grad for key, param in model_base.named_parameters()}
    grad_pt2 = {key: param.grad for key, param in model_control.named_parameters()}
    return compare_dict_tensors(
        grad_base,
        grad_pt2,
        precision,
    )
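
# A hypothetical end-to-end sketch of how the three helpers fit together for two
# variants of the same model (names below are illustrative, not part of this module):
#
#   pred_base = model_base(*inputs)
#   pred_control = model_control(*inputs)
#   compare_parameters(model_base, model_control, precision=1e-4)
#   compare_forward_output(pred_base, pred_control, precision=1e-4)
#   # after running backward on both predictions:
#   compare_gradients(model_base, model_control, precision=1e-4)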


def run_model(
    model_base, model_control, model_input, num_iterations=10, precision=1e-4
):
    clean_memory()
    for i in range(num_iterations):
        logger.info("start iteration %s", i)
        set_deterministic()
        pred_base = model_base(*model_input)
        set_deterministic()
        pred_control = model_control(*model_input)

        res = compare_parameters(model_base, model_control, precision)
        logger.info("compare parameters. Numerical result: %s", res)

        res = compare_forward_output(pred_base, pred_control, precision)
        logger.info("compare loss/predict. Numerical result: %s", res)
        # The output tensor may not have a grad_fn, so backward() can raise
        try:
            pred_base[0].sum().backward(retain_graph=True)
            pred_control[0].sum().backward(retain_graph=True)
            res = compare_gradients(model_base, model_control, precision)
            logger.info("compare param grad. Numerical result: %s", res)
        except Exception:
            logger.exception("Exception when comparing gradients")
            traceback.print_exc()

        if config.fx_passes_numeric_check["requires_optimizer"]:
            try:
                optimizer_base = optim.SGD(
                    [param for name, param in model_base.named_parameters()], lr=0.01
                )
                optimizer_base.step()

                optimizer_control = optim.SGD(
                    [param for name, param in model_control.named_parameters()], lr=0.01
                )
                optimizer_control.step()

                res = compare_parameters(model_base, model_control, precision)
                logger.info(
                    "compare parameters with optimizer added. Numerical result: %s",
                    res,
                )
            except Exception:
                logger.exception(
                    "Exception when an optimizer step is added to the parameter check"
                )
                traceback.print_exc()
        else:
            logger.warning(
                "requires_optimizer is disabled, skipping the post-optimizer-step comparison"
                " (%s parameters before transformation, %s after transformation)",
                len(dict(model_base.named_parameters())),
                len(dict(model_control.named_parameters())),
            )
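
# A hedged usage sketch: run_model expects two callables (typically fx GraphModules
# before/after a transformation) and a sequence of example inputs that is unpacked
# with *model_input. The names below are illustrative.
#
#   run_model(gm_eager, gm_transformed, (example_input,), num_iterations=2, precision=1e-4)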


def numeric_check_if_enabled(
    gm_before_fx_passes,
    gm_after_fx_passes,
    example_inputs,
    num_iterations,
    precision,
):
    # The graph module needs to be topologically sorted before we run the model,
    # otherwise it may fail with use-before-definition errors.
    # Fail silently so the numeric check never blocks the model run.
    try:
        with torch.autograd.set_detect_anomaly(True):
            run_model(
                gm_before_fx_passes,
                gm_after_fx_passes,
                example_inputs,
                num_iterations=num_iterations,
                precision=precision,
            )
    except Exception as e:
        logger.warning(
            "Runtime numeric check failed in pre grad fx passes with error: %s", e
        )

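# A hedged usage sketch of how a caller might gate this check on the Inductor
# config. Only the "requires_optimizer" key is used in this file; the other keys
# shown below are assumptions and purely illustrative.
#
#   if config.fx_passes_numeric_check.get("pre_grad", False):
#       numeric_check_if_enabled(
#           gm_before_fx_passes,
#           gm_after_fx_passes,
#           example_inputs,
#           num_iterations=config.fx_passes_numeric_check.get("num_iterations", 1),
#           precision=config.fx_passes_numeric_check.get("precision", 1e-4),
#       )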
213