xref: /aosp_15_r20/external/pytorch/test/torch_np/test_reductions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3from unittest import skipIf, SkipTest
4
5import numpy
6import pytest
7from pytest import raises as assert_raises
8
9from torch.testing._internal.common_utils import (
10    instantiate_parametrized_tests,
11    parametrize,
12    run_tests,
13    TEST_WITH_TORCHDYNAMO,
14    TestCase,
15    xpassIfTorchDynamo,
16)
17
18
19# If we are going to trace through these, we should use NumPy
20# If testing on eager mode, we use torch._numpy
21if TEST_WITH_TORCHDYNAMO:
22    import numpy as np
23    import numpy.core.numeric as _util  # for normalize_axis_tuple
24    from numpy.testing import (
25        assert_allclose,
26        assert_almost_equal,
27        assert_array_equal,
28        assert_equal,
29    )
30else:
31    import torch._numpy as np
32    from torch._numpy import _util
33    from torch._numpy.testing import (
34        assert_allclose,
35        assert_almost_equal,
36        assert_array_equal,
37        assert_equal,
38    )
39
40
class TestFlatnonzero(TestCase):
    """Tests for ``np.flatnonzero``."""

    def test_basic(self):
        # Of the values -2, -1, 0, 1, 2 only the entry at index 2 is zero.
        values = np.arange(-2, 3)
        expected_indices = [0, 1, 3, 4]
        assert_equal(np.flatnonzero(values), expected_indices)
45
46
class TestAny(TestCase):
    """Tests for ``np.any`` on flat and 2-D inputs."""

    def test_basic(self):
        # Inputs with at least one nonzero entry are "any"-true.
        for truthy in ([0, 0, 1, 0], [1, 0, 1, 0]):
            assert np.any(truthy)
        # An all-zero input is not.
        all_zeros = [0, 0, 0, 0]
        assert not np.any(all_zeros)

    def test_nd(self):
        grid = [[0, 0, 0], [0, 1, 0], [1, 1, 0]]
        assert np.any(grid)

        per_column = np.any(grid, axis=0)
        per_row = np.any(grid, axis=1)
        assert_equal(per_column, [1, 1, 0])
        assert_equal(per_row, [0, 1, 1])
        assert_equal(np.any(grid), True)
        # Reducing along an axis must produce an ndarray.
        assert isinstance(per_row, np.ndarray)

    # YYY: deduplicate
    def test_method_vs_function(self):
        # The ndarray method and the free function must agree.
        arr = np.array([[0, 1, 0, 3], [1, 0, 2, 0]])
        assert_equal(np.any(arr), arr.any())
68
69
class TestAll(TestCase):
    """Tests for ``np.all`` on flat and 2-D inputs."""

    def test_basic(self):
        mixed = [0, 1, 1, 0]
        zeros = [0, 0, 0, 0]
        ones = [1, 1, 1, 1]

        assert np.all(ones)
        for falsy in (mixed, zeros):
            assert not np.all(falsy)
        # Bitwise NOT of integer zero is -1, so every entry becomes nonzero.
        assert np.all(~np.array(zeros))

    def test_nd(self):
        grid = [[0, 0, 1], [0, 1, 1], [1, 1, 1]]
        assert not np.all(grid)

        expected_by_axis = {0: [0, 0, 1], 1: [0, 0, 1]}
        for ax, expected in expected_by_axis.items():
            assert_equal(np.all(grid, axis=ax), expected)

        assert_equal(np.all(grid), False)

    def test_method_vs_function(self):
        # The ndarray method and the free function must agree.
        arr = np.array([[0, 1, 0, 3], [1, 0, 2, 0]])
        assert_equal(np.all(arr), arr.all())
90
91
class TestMean(TestCase):
    """Tests for ``np.mean`` and the ``ndarray.mean`` method."""

    def test_mean(self):
        A = [[1, 2, 3], [4, 5, 6]]
        assert np.mean(A) == 3.5
        assert np.all(np.mean(A, 0) == np.array([2.5, 3.5, 4.5]))
        assert np.all(np.mean(A, 1) == np.array([2.0, 5.0]))

        # XXX: numpy emits a warning on empty slice
        assert np.isnan(np.mean([]))

        m = np.asarray(A)
        assert np.mean(A) == m.mean()

    def test_mean_values(self):
        """mean(axis) times the number of reduced elements equals sum(axis)."""
        import warnings

        # Deterministic input rather than random data, so failures reproduce.
        rmat = np.arange(20, dtype=float).reshape((4, 5))
        cmat = rmat + 1j * rmat

        with warnings.catch_warnings():
            # Escalate warnings to errors: any warning raised here fails the test.
            warnings.simplefilter("error")
            for mat in [rmat, cmat]:
                for axis in [0, 1]:
                    tgt = mat.sum(axis=axis)
                    res = np.mean(mat, axis=axis) * mat.shape[axis]
                    assert_allclose(res, tgt)

                # axis=None reduces over every element, so scale by mat.size.
                tgt = mat.sum(axis=None)
                res = np.mean(mat, axis=None) * mat.size
                assert_allclose(res, tgt)

    def test_mean_float16(self):
        # This fail if the sum inside mean is done in float16 instead
        # of float32.
        assert np.mean(np.ones(100000, dtype="float16")) == 1

    @xpassIfTorchDynamo  # (reason="XXX: mean(..., where=...) not implemented")
    def test_mean_where(self):
        a = np.arange(16).reshape((4, 4))
        wh_full = np.array(
            [
                [False, True, False, True],
                [True, False, True, False],
                [True, True, False, False],
                [False, False, True, True],
            ]
        )
        wh_partial = np.array([[False], [True], [True], [False]])
        _cases = [
            (1, True, [1.5, 5.5, 9.5, 13.5]),
            (0, wh_full, [6.0, 5.0, 10.0, 9.0]),
            (1, wh_full, [2.0, 5.0, 8.5, 14.5]),
            (0, wh_partial, [6.0, 7.0, 8.0, 9.0]),
        ]
        for _ax, _wh, _res in _cases:
            assert_allclose(a.mean(axis=_ax, where=_wh), np.array(_res))
            assert_allclose(np.mean(a, axis=_ax, where=_wh), np.array(_res))

        a3d = np.arange(16).reshape((2, 2, 4))
        _wh_partial = np.array([False, True, True, False])
        _res = [[1.5, 5.5], [9.5, 13.5]]
        assert_allclose(a3d.mean(axis=2, where=_wh_partial), np.array(_res))
        assert_allclose(np.mean(a3d, axis=2, where=_wh_partial), np.array(_res))

        # Rows (or the whole array) with no selected elements yield NaN and
        # are expected to emit a RuntimeWarning.
        with pytest.warns(RuntimeWarning):
            assert_allclose(
                a.mean(axis=1, where=wh_partial), np.array([np.nan, 5.5, 9.5, np.nan])
            )
        with pytest.warns(RuntimeWarning):
            assert_equal(a.mean(where=False), np.nan)
        with pytest.warns(RuntimeWarning):
            assert_equal(np.mean(a, where=False), np.nan)
166
167
@instantiate_parametrized_tests
class TestSum(TestCase):
    """Tests for ``np.sum`` / ``ndarray.sum``: values, stability, dtypes,
    complex inputs, and the ``initial=`` / ``where=`` keywords."""

    def test_sum(self):
        m = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        tgt = [[6], [15], [24]]
        out = np.sum(m, axis=1, keepdims=True)
        assert_equal(tgt, out)

        am = np.asarray(m)
        assert_equal(np.sum(m), am.sum())

    def test_sum_stability(self):
        # Summing 500 copies of 0.1 must stay within a dtype-appropriate
        # tolerance of the exact result.
        a = np.ones(500, dtype=np.float32)
        zero = np.zeros(1, dtype="float32")[0]
        assert_allclose((a / 10.0).sum() - a.size / 10.0, zero, atol=1.5e-4)

        a = np.ones(500, dtype=np.float64)
        assert_allclose((a / 10.0).sum() - a.size / 10.0, 0.0, atol=1.5e-13)

    def test_sum_boolean(self):
        # 0, 2, 4 and 6 are even, so four entries are True.
        a = np.arange(7) % 2 == 0
        res = a.sum()
        assert_equal(res, 4)

        res_float = a.sum(dtype=np.float64)
        assert_allclose(res_float, 4.0, atol=1e-15)
        assert res_float.dtype == "float64"

    # Compare (major, minor) numerically: a plain string comparison of
    # version strings is lexicographic and misorders e.g. "1.9" vs "1.24".
    @skipIf(
        tuple(int(part) for part in numpy.__version__.split(".")[:2]) < (1, 24),
        reason="NP_VER: fails on NumPy 1.23.x",
    )
    @xpassIfTorchDynamo  # (reason="sum: does not warn on overflow")
    def test_sum_dtypes_warnings(self):
        """A RuntimeWarning is recorded exactly when the sum overflows."""
        import warnings

        for dt in (int, np.float16, np.float32, np.float64):
            for v in (0, 1, 2, 7, 8, 9, 15, 16, 19, 127, 128, 1024, 1235):
                # warning if sum overflows, which it does in float16
                with warnings.catch_warnings(record=True) as w:
                    warnings.simplefilter("always", RuntimeWarning)

                    # Closed form of 1 + 2 + ... + v. Building it in a narrow
                    # dtype may already overflow, which is expected to warn once.
                    tgt = dt(v * (v + 1) / 2)
                    overflow = not np.isfinite(tgt)
                    assert_equal(len(w), int(overflow))

                    d = np.arange(1, v + 1, dtype=dt)

                    assert_almost_equal(np.sum(d), tgt)
                    assert_equal(len(w), 2 * int(overflow))

                    assert_almost_equal(np.sum(np.flip(d)), tgt)
                    assert_equal(len(w), 3 * int(overflow))

    def test_sum_dtypes_2(self):
        # Strided and reversed views must sum correctly for every dtype.
        for dt in (int, np.float16, np.float32, np.float64):
            d = np.ones(500, dtype=dt)
            assert_almost_equal(np.sum(d[::2]), 250.0)
            assert_almost_equal(np.sum(d[1::2]), 250.0)
            assert_almost_equal(np.sum(d[::3]), 167.0)
            assert_almost_equal(np.sum(d[1::3]), 167.0)
            assert_almost_equal(np.sum(np.flip(d)[::2]), 250.0)
            assert_almost_equal(np.sum(np.flip(d)[1::2]), 250.0)
            assert_almost_equal(np.sum(np.flip(d)[::3]), 167.0)
            assert_almost_equal(np.sum(np.flip(d)[1::3]), 167.0)

            # sum with first reduction entry != 0
            d = np.ones((1,), dtype=dt)
            d += d
            assert_almost_equal(d, 2.0)

    @parametrize("dt", [np.complex64, np.complex128])
    def test_sum_complex_1(self, dt):
        for v in (0, 1, 2, 7, 8, 9, 15, 16, 19, 127, 128, 1024, 1235):
            # Real parts are 1..v and imaginary parts -1..-v.
            tgt = dt(v * (v + 1) / 2) - dt((v * (v + 1) / 2) * 1j)
            d = np.empty(v, dtype=dt)
            d.real = np.arange(1, v + 1)
            d.imag = -np.arange(1, v + 1)
            assert_allclose(np.sum(d), tgt, atol=1.5e-5)
            assert_allclose(np.sum(np.flip(d)), tgt, atol=1.5e-7)

    @parametrize("dt", [np.complex64, np.complex128])
    def test_sum_complex_2(self, dt):
        d = np.ones(500, dtype=dt) + 1j
        assert_allclose(np.sum(d[::2]), 250.0 + 250j, atol=1.5e-7)
        assert_allclose(np.sum(d[1::2]), 250.0 + 250j, atol=1.5e-7)
        assert_allclose(np.sum(d[::3]), 167.0 + 167j, atol=1.5e-7)
        assert_allclose(np.sum(d[1::3]), 167.0 + 167j, atol=1.5e-7)
        assert_allclose(np.sum(np.flip(d)[::2]), 250.0 + 250j, atol=1.5e-7)
        assert_allclose(np.sum(np.flip(d)[1::2]), 250.0 + 250j, atol=1.5e-7)
        assert_allclose(np.sum(np.flip(d)[::3]), 167.0 + 167j, atol=1.5e-7)
        assert_allclose(np.sum(np.flip(d)[1::3]), 167.0 + 167j, atol=1.5e-7)
        # sum with first reduction entry != 0
        d = np.ones((1,), dtype=dt) + 1j
        d += d
        assert_allclose(d, 2.0 + 2j, atol=1.5e-7)

    @xpassIfTorchDynamo  # (reason="initial=... need implementing")
    def test_sum_initial(self):
        # Integer, single axis
        assert_equal(np.sum([3], initial=2), 5)

        # Floating point
        assert_almost_equal(np.sum([0.2], initial=0.1), 0.3)

        # Multiple non-adjacent axes
        assert_equal(
            np.sum(np.ones((2, 3, 5), dtype=np.int64), axis=(0, 2), initial=2),
            [12, 12, 12],
        )

    @xpassIfTorchDynamo  # (reason="where=... need implementing")
    def test_sum_where(self):
        # More extensive tests done in test_reduction_with_where.
        assert_equal(np.sum([[1.0, 2.0], [3.0, 4.0]], where=[True, False]), 4.0)
        assert_equal(
            np.sum([[1.0, 2.0], [3.0, 4.0]], axis=0, initial=5.0, where=[True, False]),
            [9.0, 5.0],
        )
286
287
# Axis arguments shared by the generic reduction tests: single ints (both
# signs) and tuples, including unsorted and mixed-sign tuples.
_reduction_axes = [0, 1, 2, -1, -2, (0, 1), (1, 0), (0, 1, 2), (1, -1, 0)]

# Reduction functions exercised by the generic reduction tests.
_reduction_funcs = [
    np.any,
    np.all,
    np.argmin,
    np.argmax,
    np.min,
    np.max,
    np.mean,
    np.sum,
    np.prod,
    np.std,
    np.var,
    np.count_nonzero,
]

parametrize_axis = parametrize("axis", _reduction_axes)
parametrize_func = parametrize("func", _reduction_funcs)

# Functions skipped by tests that pass a tuple ``axis``.
fails_axes_tuples = {np.any, np.all, np.argmin, np.argmax, np.prod}

# Functions skipped by tests that pass ``out=``.
fails_out_arg = {np.count_nonzero}

# Functions skipped by the ``out=`` + dtype test because it implies
# float -> int casts.
restricts_dtype_casts = {np.var, np.std}

# Functions for which ``axis=()`` is not a valid argument.
fails_empty_tuple = {np.argmin, np.argmax}
324
325
@instantiate_parametrized_tests
class TestGenericReductions(TestCase):
    """Generic checks that every function in ``parametrize_func`` behaves
    like a reduction.

    Covered: invalid, array-valued, empty-tuple and duplicated-tuple ``axis``
    arguments; ``keepdims=`` with and without an axis; and ``out=`` with and
    without an axis (also combined with ``keepdims=``). Functions lacking a
    feature are listed in the module-level ``fails_axes_tuples``,
    ``fails_out_arg``, ``fails_empty_tuple`` and ``restricts_dtype_casts``
    sets, and the corresponding tests skip them.
    """

    @parametrize_func
    def test_bad_axis(self, func):
        # Basic check of functionality
        m = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])

        assert_raises(TypeError, func, m, axis="foo")
        assert_raises(np.AxisError, func, m, axis=3)
        assert_raises(TypeError, func, m, axis=np.array([[1], [2]]))
        assert_raises(TypeError, func, m, axis=1.5)

        # TODO: add tests with np.int32(3) etc, when implemented

    @parametrize_func
    def test_array_axis(self, func):
        # A 0-d integer array is accepted as an axis; a 1-d array is not.
        a = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])
        assert_equal(func(a, axis=np.array(-1)), func(a, axis=-1))

        with assert_raises(TypeError):
            func(a, axis=np.array([1, 2]))

    @parametrize_func
    def test_axis_empty_generic(self, func):
        if func in fails_empty_tuple:
            raise SkipTest("func(..., axis=()) is not valid")

        # Reducing over no axes must match reducing a new length-1 leading axis.
        a = np.array([[0, 0, 1], [1, 0, 1]])
        assert_array_equal(func(a, axis=()), func(np.expand_dims(a, axis=0), axis=0))

    @parametrize_func
    def test_axis_bad_tuple(self, func):
        # Basic check of functionality
        m = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])

        if func in fails_axes_tuples:
            raise SkipTest(f"{func.__name__} does not allow tuple axis.")

        # A repeated axis inside the tuple is an error.
        with assert_raises(ValueError):
            func(m, axis=(1, 1))

    @parametrize_axis
    @parametrize_func
    def test_keepdims_generic(self, axis, func):
        if func in fails_axes_tuples:
            raise SkipTest(f"{func.__name__} does not allow tuple axis.")

        # keepdims=True must equal the plain reduction with the reduced axes
        # re-inserted as length-1 dimensions.
        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
        with_keepdims = func(a, axis, keepdims=True)
        expanded = np.expand_dims(func(a, axis=axis), axis=axis)
        assert_array_equal(with_keepdims, expanded)

    # Compare (major, minor) numerically: a plain string comparison of
    # version strings is lexicographic and misorders e.g. "1.9" vs "1.24".
    @skipIf(
        tuple(int(part) for part in numpy.__version__.split(".")[:2]) < (1, 24),
        reason="NP_VER: fails on CI w/old numpy",
    )
    @parametrize_func
    def test_keepdims_generic_axis_none(self, func):
        # With axis=None and keepdims=True the result has shape (1, 1, 1),
        # filled with the full scalar reduction.
        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
        with_keepdims = func(a, axis=None, keepdims=True)
        scalar = func(a, axis=None)
        expanded = np.full((1,) * a.ndim, fill_value=scalar)
        assert_array_equal(with_keepdims, expanded)

    @parametrize_func
    def test_out_scalar(self, func):
        # out no axis: scalar
        if func in fails_out_arg:
            raise SkipTest(f"{func.__name__} does not have out= arg.")

        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))

        result = func(a)
        out = np.empty_like(result)
        result_with_out = func(a, out=out)

        assert result_with_out is out
        assert_array_equal(result, result_with_out)

    def _check_out_axis(self, func, axis, dtype, keepdims):
        """Check that ``func(a, axis=..., keepdims=..., out=out)`` returns ``out``.

        ``out`` is allocated with the shape of the reduction result and the
        requested ``dtype``. The value written into it must equal the plain
        reduction result cast to ``dtype``.
        """
        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
        result = func(a, axis=axis, keepdims=keepdims).astype(dtype)

        out = np.empty_like(result, dtype=dtype)
        result_with_out = func(a, axis=axis, keepdims=keepdims, out=out)

        assert result_with_out is out
        assert result_with_out.dtype == dtype
        assert_array_equal(result, result_with_out)

        # TODO: what if result.dtype != out.dtype; does out typecast the result?

        # out of wrong shape (any/out does not broadcast)
        # np.any(m, out=np.empty_like(m)) raises a ValueError (wrong number
        # of dimensions.)
        # pytorch.any emits a warning and resizes the out array.
        # Here we follow pytorch, since the result is a superset
        # of the numpy functionality

    @parametrize("keepdims", [True, False])
    @parametrize("dtype", [bool, "int32", "float64"])
    @parametrize_func
    @parametrize_axis
    def test_out_axis(self, func, axis, dtype, keepdims):
        # out with axis
        if func in fails_out_arg:
            raise SkipTest(f"{func.__name__} does not have out= arg.")
        if func in fails_axes_tuples:
            raise SkipTest(f"{func.__name__} does not handle tuple axis.")
        if func in restricts_dtype_casts:
            raise SkipTest(f"{func.__name__}: test implies float->int casts")

        self._check_out_axis(func, axis, dtype, keepdims)

    @parametrize_func
    @parametrize_axis
    def test_keepdims_out(self, func, axis):
        if func in fails_out_arg:
            raise SkipTest(f"{func.__name__} does not have out= arg.")
        if func in fails_axes_tuples:
            raise SkipTest(f"{func.__name__} does not handle tuple axis.")

        d = np.ones((3, 5, 7, 11))
        # parametrize_axis has no None entry; the branch keeps the test valid
        # should None be added to it.
        if axis is None:
            shape_out = (1,) * d.ndim
        else:
            axis_norm = _util.normalize_axis_tuple(axis, d.ndim)
            shape_out = tuple(
                1 if i in axis_norm else d.shape[i] for i in range(d.ndim)
            )
        out = np.empty(shape_out)

        result = func(d, axis=axis, keepdims=True, out=out)
        assert result is out
        assert_equal(result.shape, shape_out)
485
486
@instantiate_parametrized_tests
class TestGenericCumSumProd(TestCase):
    """Sanity checks of ``axis=`` handling for ``np.cumsum`` and ``np.cumprod``."""

    @parametrize("func", [np.cumsum, np.cumprod])
    def test_bad_axis(self, func):
        m = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])

        # Each invalid axis paired with the exception it must raise.
        bad_axis_cases = [
            (TypeError, "foo"),
            (np.AxisError, 3),
            (TypeError, np.array([[1], [2]])),
            (TypeError, 1.5),
        ]
        for expected_exc, bad_axis in bad_axis_cases:
            assert_raises(expected_exc, func, m, axis=bad_axis)

        # TODO: add tests with np.int32(3) etc, when implemented

    @parametrize("func", [np.cumsum, np.cumprod])
    def test_array_axis(self, func):
        arr = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])
        # A 0-d array axis must behave exactly like the equivalent int.
        zero_d_axis = np.array(-1)
        assert_equal(func(arr, axis=zero_d_axis), func(arr, axis=-1))

        # A 1-d array is rejected as an axis.
        with assert_raises(TypeError):
            func(arr, axis=np.array([1, 2]))

    @parametrize("func", [np.cumsum, np.cumprod])
    def test_axis_empty_generic(self, func):
        arr = np.array([[0, 0, 1], [1, 0, 1]])
        # axis=None accumulates over the flattened input.
        flat_result = func(arr.ravel(), axis=0)
        assert_array_equal(func(arr, axis=None), flat_result)

    @parametrize("func", [np.cumsum, np.cumprod])
    def test_axis_bad_tuple(self, func):
        m = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])
        # The tuple axis (1, 1) must be rejected with TypeError.
        with assert_raises(TypeError):
            func(m, axis=(1, 1))
522
523
if __name__ == "__main__":
    # Executing this module directly hands control to PyTorch's test runner.
    run_tests()
526