# mypy: allow-untyped-defs
import gc
import logging
import os
import random
import traceback

import numpy

import torch
import torch.optim as optim

from .. import config


logger: logging.Logger = logging.getLogger(__name__)

MAIN_RANDOM_SEED = 1337

# Set the CUBLAS_WORKSPACE_CONFIG environment variable so cuBLAS behaves deterministically
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


# If the two forward functions involve any non-deterministic operations,
# such as certain types of parallelism or asynchronous execution,
# their outputs can differ even when the seeds below are identical.
def set_deterministic() -> None:
    """Seed torch, random, and numpy, and enable deterministic algorithms."""

    torch.manual_seed(MAIN_RANDOM_SEED)
    random.seed(MAIN_RANDOM_SEED)
    numpy.random.seed(MAIN_RANDOM_SEED)
    torch.use_deterministic_algorithms(True)


def clean_memory() -> None:
    """Clean memory to avoid OOM."""
    gc.collect()
    torch.cuda.empty_cache()


# We compare the numerical results before and after the pre/post grad fx pass
# transformations to make sure they stay the same.
def compare_dict_tensors(dict_base, dict_control, precision):
    if len(set(dict_base.keys())) != len(set(dict_control.keys())):
        logger.warning("Mismatch keys found before and after pre/post grad fx passes.")
        logger.debug("keys before pre/post grad fx passes %s", dict_base.keys())
        logger.debug("keys after pre/post grad fx passes %s", dict_control.keys())
        return False
    is_allclose = True
    for key in dict_base.keys():
        if key not in dict_control:
            logger.warning(
                "Mismatch parameter name %s does not exist after pre/post grad fx passes",
                key,
            )
            is_allclose = False
            continue
        # Some entries are `None` (e.g. not every param has a valid .grad field); skip them
        if dict_base[key] is None or dict_control[key] is None:
            continue
        if not torch.allclose(
            dict_base[key],
            dict_control[key],
            rtol=precision,
            atol=precision,
            equal_nan=True,
        ):
            logger.warning(
                "Mismatch parameter values found before and after pre/post grad fx passes."
            )
            logger.debug("value before pre/post grad fx passes %s", dict_base[key])
            logger.debug("value after pre/post grad fx passes %s", dict_control[key])
            is_allclose = False
    return is_allclose
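
# Illustrative sketch (comments only, not executed): `compare_dict_tensors`
# returns True only when the key sets agree and every non-None entry matches
# within `precision`. The dict names below are made up for the example.
#
#     base = {"w": torch.tensor([1.0, 2.0]), "b": None}
#     ctrl = {"w": torch.tensor([1.0, 2.0]), "b": None}
#     compare_dict_tensors(base, ctrl, precision=1e-4)  # -> True ("b" is skipped)
#     ctrl["w"] = torch.tensor([1.0, 2.1])
#     compare_dict_tensors(base, ctrl, precision=1e-4)  # -> False, mismatch logged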


def compare_tuple_tensors(tuple_base, tuple_control, precision):
    if len(tuple_base) != len(tuple_control):
        logger.warning(
            "Mismatch fw output length. before transformation: %s,"
            " after transformation: %s",
            len(tuple_base),
            len(tuple_control),
        )
        return False
    is_allclose = True
    for i in range(len(tuple_base)):
        # Some outputs may be `None`; skip them
        if tuple_base[i] is None or tuple_control[i] is None:
            continue
        if not torch.allclose(
            tuple_base[i],
            tuple_control[i],
            rtol=precision,
            atol=precision,
            equal_nan=True,
        ):
            logger.debug(
                "forward output before pre/post grad fx passes %s", tuple_base[i]
            )
            logger.debug(
                "forward output after pre/post grad fx passes %s", tuple_control[i]
            )
            is_allclose = False
    return is_allclose


def compare_parameters(model_base, model_control, precision):
    return compare_dict_tensors(
        dict(model_base.named_parameters()),
        dict(model_control.named_parameters()),
        precision,
    )


def compare_forward_output(pred_base, pred_control, precision):
    return compare_tuple_tensors(
        pred_base,
        pred_control,
        precision,
    )


def compare_gradients(model_base, model_control, precision):
    grad_base = {key: param.grad for key, param in model_base.named_parameters()}
    grad_pt2 = {key: param.grad for key, param in model_control.named_parameters()}
    return compare_dict_tensors(
        grad_base,
        grad_pt2,
        precision,
    )
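
# Illustrative sketch (comments only, not executed): the compare_* helpers are
# meant to be driven with a model and a transformed copy of it; a model
# compared against a deep copy of itself should pass all three checks. The
# variable names below are made up for the example.
#
#     import copy
#     base = torch.nn.Linear(4, 2)
#     control = copy.deepcopy(base)
#     inp = (torch.randn(3, 4),)
#     out_base, out_control = base(*inp), control(*inp)
#     compare_parameters(base, control, precision=1e-4)          # -> True
#     compare_forward_output((out_base,), (out_control,), 1e-4)  # -> True
#     out_base.sum().backward()
#     out_control.sum().backward()
#     compare_gradients(base, control, precision=1e-4)           # -> True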


def run_model(
    model_base, model_control, model_input, num_iterations=10, precision=1e-4
):
    clean_memory()
    for i in range(num_iterations):
        logger.info("starting iteration %s", i)
        set_deterministic()
        pred_base = model_base(*model_input)
        set_deterministic()
        pred_control = model_control(*model_input)

        res = compare_parameters(model_base, model_control, precision)
        logger.info("compare parameters. Numerical result: %s", res)

        res = compare_forward_output(pred_base, pred_control, precision)
        logger.info("compare loss/predict. Numerical result: %s", res)
        # the first output tensor may not have a grad_fn, so guard the backward pass
        try:
            pred_base[0].sum().backward(retain_graph=True)
            pred_control[0].sum().backward(retain_graph=True)
            res = compare_gradients(model_base, model_control, precision)
            logger.info("compare param grad. Numerical result: %s", res)
        except Exception:
            logger.exception("Exception when comparing gradients")
            traceback.print_exc()

        if config.fx_passes_numeric_check["requires_optimizer"]:
            try:
                optimizer_base = optim.SGD(
                    [param for name, param in model_base.named_parameters()], lr=0.01
                )
                optimizer_base.step()

                optimizer_control = optim.SGD(
                    [param for name, param in model_control.named_parameters()], lr=0.01
                )
                optimizer_control.step()

                res = compare_parameters(model_base, model_control, precision)
                logger.info(
                    "compare parameters with optimizer added. Numerical result: %s",
                    res,
                )
            except Exception:
                logger.exception(
                    "Exception when an optimizer is added to check parameter values"
                )
                traceback.print_exc()
        else:
            logger.warning(
                "no optimizer step requested; skipping comparison of %s parameters"
                " before transformation and %s parameters after transformation",
                len(dict(model_base.named_parameters())),
                len(dict(model_control.named_parameters())),
            )


def numeric_check_if_enabled(
    gm_before_fx_passes,
    gm_after_fx_passes,
    example_inputs,
    num_iterations,
    precision,
):
    # The graphmodule needs to be topo-sorted before we run the model,
    # otherwise it may fail with a use-before-definition error.
    # Fail silently so the numeric check never blocks the model run.
    try:
        with torch.autograd.set_detect_anomaly(True):
            run_model(
                gm_before_fx_passes,
                gm_after_fx_passes,
                example_inputs,
                num_iterations=num_iterations,
                precision=precision,
            )
    except Exception as e:
        logger.warning(
            "Runtime numeric check failed in pre grad fx passes with error: %s", e
        )
        traceback.print_exc()
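
# Illustrative sketch (comments only, not executed): a caller that owns the
# GraphModules before and after an fx pass could gate the check on a config
# flag and invoke it roughly as below. The config keys other than
# "requires_optimizer" are hypothetical here and only shown for shape.
#
#     if config.fx_passes_numeric_check.get("pre_grad", False):
#         numeric_check_if_enabled(
#             gm_before_fx_passes,
#             gm_after_fx_passes,
#             example_inputs,
#             num_iterations=config.fx_passes_numeric_check.get("num_iterations", 1),
#             precision=config.fx_passes_numeric_check.get("precision", 1e-4),
#         )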