# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import functools
import logging
import typing
from typing import Callable, Union

import numpy as np
import torch


# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def distance(fn: Callable[[np.ndarray, np.ndarray], float]) -> Callable[
    [
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        typing.Union[np.ndarray, torch._tensor.Tensor],
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        typing.Union[np.ndarray, torch._tensor.Tensor],
    ],
    float,
]:
    """Decorate a raw metric so it validates its inputs before being applied.

    A distance decorator that performs all the necessary checks before
    calculating the distance between two N-D tensors given a function. This
    can be an RMS function, maximum abs diff, or any kind of distance function.

    The returned callable:
      * converts both inputs to float64 NumPy arrays,
      * returns NaN when the shapes differ,
      * returns 0.0 when the inputs are empty,
      * returns NaN when Inf/NaN entries do not line up between the inputs
        (including infinities of opposite sign at the same index),
      * otherwise drops the matching Inf/NaN entries (logging a warning) and
        applies ``fn`` to the remaining values, or returns 0.0 if none remain.

    Args:
        fn: Metric taking two equally sized 1-D float64 arrays of finite values.

    Returns:
        The checked version of ``fn``.
    """

    # Copy the identifying metadata so decorated metrics keep their own name
    # and docstring, but leave the wrapper's (wider) annotations untouched.
    @functools.wraps(
        fn, assigned=("__module__", "__name__", "__qualname__", "__doc__")
    )
    def wrapper(
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        a: Union[np.ndarray, torch.Tensor],
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        b: Union[np.ndarray, torch.Tensor],
    ) -> float:
        # convert a and b to np.ndarray type fp64
        a = to_np_arr_fp64(a)
        b = to_np_arr_fp64(b)

        # return NaN if shape mismatches
        if a.shape != b.shape:
            return np.nan

        # After we make sure shape matches, check if it's empty. If yes, return 0
        if a.size == 0:
            return 0.0

        # np.isinf and np.isnan return Boolean masks. Check if Inf or NaN occur
        # at the same places in a and b. If not, return NaN
        a_inf = np.isinf(a)
        a_nan = np.isnan(a)
        if np.any(a_inf != np.isinf(b)) or np.any(a_nan != np.isnan(b)):
            return np.nan

        # Infinities at the same index must also agree in sign: +Inf vs -Inf is
        # a mismatch, not a pair of equal values that can be ignored.
        if np.any(a[a_inf] != b[a_inf]):
            return np.nan

        # mask out all the values that are either Inf or NaN
        mask = a_inf | a_nan
        if np.any(mask):
            logging.warning("Found inf/nan in tensor when calculating the distance")

        a_masked = a[~mask]
        b_masked = b[~mask]

        # after masking, the resulting tensor might be empty. If yes, return 0
        if a_masked.size == 0:
            return 0.0

        # only compare the rest (those that are actually numbers) using the metric
        return fn(a_masked, b_masked)

    return wrapper


@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def rms(a: np.ndarray, b: np.ndarray) -> float:
    """Return the root-mean-square of the element-wise difference ``a - b``."""
    return ((a - b) ** 2).mean() ** 0.5


@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Return the largest absolute element-wise difference between a and b."""
    return np.abs(a - b).max()


@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def max_rel_diff(x: np.ndarray, x_ref: np.ndarray) -> float:
    """Return the largest element-wise ``|x - x_ref| / |x_ref|``.

    Entries where ``x_ref`` is zero divide by zero and follow NumPy's
    floating-point rules (Inf or NaN).
    """
    return np.abs((x - x_ref) / x_ref).max()


# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def to_np_arr_fp64(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Convert a tensor, array, or array-like into a float64 NumPy array.

    Args:
        x: A torch tensor (any device, any dtype torch can cast to float64),
            a NumPy array, or anything ``np.asarray`` accepts (list, scalar).

    Returns:
        A float64 array holding the values of ``x``. For tensor and ndarray
        inputs the result is a copy that does not share memory with ``x``.
    """
    if isinstance(x, torch.Tensor):
        # Cast inside torch before crossing into NumPy: NumPy has no bfloat16
        # dtype, so calling `.numpy()` directly on such a tensor raises.
        x = x.detach().cpu().to(torch.float64).numpy()
    if isinstance(x, np.ndarray):
        # astype always copies, so callers never alias the input buffer.
        return x.astype(np.float64)
    return np.asarray(x, dtype=np.float64)


def normalized_rms(
    # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
    predicted: Union[np.ndarray, torch.Tensor],
    # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
    ground_truth: Union[np.ndarray, torch.Tensor],
) -> float:
    """Return ``rms(predicted, ground_truth)`` divided by the norm of the reference.

    Args:
        predicted: Values under test.
        ground_truth: Reference values; its 2-norm is the denominator.

    Returns:
        0.0 when the RMS is exactly zero, NaN when ``rms`` reports NaN, and
        otherwise the RMS divided by the 2-norm of the finite entries of
        ``ground_truth``.
    """
    num = rms(predicted, ground_truth)
    if num == 0:
        return 0.0
    reference = to_np_arr_fp64(ground_truth)
    # `rms` ignores Inf/NaN entries that line up in both inputs, so the norm
    # must ignore them too; otherwise one Inf collapses the result to 0 and
    # one NaN turns it into NaN.
    den = np.linalg.norm(reference[np.isfinite(reference)])
    return np.float64(num) / np.float64(den)