xref: /aosp_15_r20/external/executorch/backends/cadence/runtime/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker
8*523fa7a6SAndroid Build Coastguard Workerimport logging
9*523fa7a6SAndroid Build Coastguard Workerimport typing
10*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable, Union
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport numpy as np
13*523fa7a6SAndroid Build Coastguard Workerimport torch
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Worker
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def distance(fn: Callable[[np.ndarray, np.ndarray], float]) -> Callable[
    [
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        typing.Union[np.ndarray, torch.Tensor],
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        typing.Union[np.ndarray, torch.Tensor],
    ],
    float,
]:
    """Wrap a distance metric with the checks shared by all distance functions.

    The returned callable accepts NumPy arrays or torch tensors, converts both
    to float64 NumPy arrays and then:

    * returns ``nan`` if the shapes differ;
    * returns ``0.0`` if the (common) shape is empty;
    * returns ``nan`` if inf/nan values do not occur at the same positions in
      both inputs, or if matching infinities have opposite signs;
    * otherwise drops every inf/nan position (logging a warning if any were
      found) and calls ``fn`` on the remaining finite values, returning
      ``0.0`` if nothing is left.

    ``fn`` therefore always receives two non-empty, equally sized, finite
    float64 arrays. It can be an RMS, a max-abs-diff, or any other metric.

    Args:
        fn: The metric to wrap; called as ``fn(a_masked, b_masked)``.

    Returns:
        The wrapped metric. It carries ``fn``'s name and docstring.
    """
    from functools import wraps

    @wraps(fn)
    def wrapper(
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        a: Union[np.ndarray, torch.Tensor],
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        b: Union[np.ndarray, torch.Tensor],
    ) -> float:
        # Convert a and b to fp64 np.ndarray.
        a = to_np_arr_fp64(a)
        b = to_np_arr_fp64(b)

        # Return NaN if the shapes mismatch.
        if a.shape != b.shape:
            return np.nan

        # The shapes match, so if one is empty both are: nothing to compare.
        if a.size == 0:
            return 0.0

        a_inf, b_inf = np.isinf(a), np.isinf(b)
        a_nan, b_nan = np.isnan(a), np.isnan(b)
        # Inf and NaN must occur at the same places in a and b.
        if np.any(a_inf != b_inf) or np.any(a_nan != b_nan):
            return np.nan
        # Matching positions are not enough for infinities: +inf vs -inf is a
        # genuine mismatch that would otherwise be silently masked away.
        if np.any(a[a_inf] != b[a_inf]):
            return np.nan

        # Mask out all the values that are either Inf or NaN.
        mask = a_inf | a_nan
        if np.any(mask):
            logging.warning("Found inf/nan in tensor when calculating the distance")

        keep = ~mask
        a_masked = a[keep]
        b_masked = b[keep]

        # After masking, the result might be empty. If so, return 0.
        if a_masked.size == 0:
            return 0.0

        # Only compare the remaining finite numbers using the metric.
        return fn(a_masked, b_masked)

    return wrapper
68*523fa7a6SAndroid Build Coastguard Worker
69*523fa7a6SAndroid Build Coastguard Worker
@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def rms(a: np.ndarray, b: np.ndarray) -> float:
    """Return the root-mean-square of the element-wise difference ``a - b``.

    Shape checking, fp64 conversion and removal of inf/nan entries are done by
    the ``@distance`` wrapper before this body runs.
    """
    residual = a - b
    mean_square = (residual * residual).mean()
    return mean_square ** 0.5
74*523fa7a6SAndroid Build Coastguard Worker
75*523fa7a6SAndroid Build Coastguard Worker
@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Return the largest element-wise absolute difference ``max |a - b|``.

    Inputs are pre-validated and stripped of inf/nan by the ``@distance``
    wrapper.
    """
    gap = np.abs(np.subtract(a, b))
    return np.max(gap)
80*523fa7a6SAndroid Build Coastguard Worker
81*523fa7a6SAndroid Build Coastguard Worker
@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def max_rel_diff(x: np.ndarray, x_ref: np.ndarray) -> float:
    """Return the largest element-wise relative difference ``|x - x_ref| / |x_ref|``.

    Elements where ``x`` equals ``x_ref`` contribute 0, even when ``x_ref`` is
    0 there. A plain division would give ``0 / 0 = nan`` at such elements, and
    ``max()`` would propagate that nan, reporting identical inputs as
    incomparable. Elements where ``x_ref`` is 0 but ``x`` is not yield ``inf``.

    Inputs are pre-validated and stripped of inf/nan by the ``@distance``
    wrapper.
    """
    abs_diff = np.abs(x - x_ref)
    # Zero reference values are handled explicitly below, so silence the
    # divide-by-zero / invalid-value warnings for this one division.
    with np.errstate(divide="ignore", invalid="ignore"):
        rel = abs_diff / np.abs(x_ref)
    # Exact matches are never a relative error, including 0 vs 0.
    rel[abs_diff == 0] = 0.0
    return rel.max()
86*523fa7a6SAndroid Build Coastguard Worker
87*523fa7a6SAndroid Build Coastguard Worker
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def to_np_arr_fp64(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Return a float64 NumPy copy of ``x``.

    Args:
        x: A torch tensor (any device, any dtype, with or without autograd),
            a NumPy array, or any array-like accepted by ``np.asarray``
            (lists, scalars).

    Returns:
        A new float64 ``np.ndarray``. It never shares memory with a tensor or
        ndarray passed in.
    """
    if isinstance(x, torch.Tensor):
        # Upcast in torch first: NumPy has no bfloat16, so calling `.numpy()`
        # directly on such tensors would raise.
        x = x.detach().cpu().to(torch.float64).numpy()
    if isinstance(x, np.ndarray):
        # astype copies, so the result never aliases the caller's buffer
        # (including the storage of a float64 CPU tensor exposed by `.numpy()`).
        return x.astype(np.float64)
    # Lists, scalars and other array-likes.
    return np.asarray(x, dtype=np.float64)
95*523fa7a6SAndroid Build Coastguard Worker
96*523fa7a6SAndroid Build Coastguard Worker
def normalized_rms(
    # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
    predicted: Union[np.ndarray, torch.Tensor],
    # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
    ground_truth: Union[np.ndarray, torch.Tensor],
) -> float:
    """Return ``rms(predicted, ground_truth)`` divided by the L2 norm of ``ground_truth``.

    ``rms`` ignores inf/nan positions (they must match between the two inputs).
    For consistency, the denominator is computed over the finite ground-truth
    values only. Otherwise a matching inf would force the result to 0 and a
    matching nan would force it to nan.

    NOTE(review): the numerator is an RMS (scaled by 1/sqrt(N)) while the
    denominator is an unscaled L2 norm, so the ratio shrinks as the element
    count grows. Confirm this is intended before comparing against thresholds.

    Returns:
        ``0.0`` when the RMS is 0. ``nan`` when the inputs are incomparable
        (e.g. shape mismatch). Otherwise the ratio, which is ``inf`` if the
        finite ground truth is all zeros.
    """
    num = rms(predicted, ground_truth)
    if num == 0:
        return 0.0
    gt = to_np_arr_fp64(ground_truth)
    finite = np.isfinite(gt)
    # Only re-index when needed, so the all-finite case is computed exactly as before.
    if not finite.all():
        gt = gt[finite]
    den = np.linalg.norm(gt)
    return np.float64(num) / np.float64(den)
109