# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Special Math Ops."""

import collections
import importlib

import numpy as np

from tensorflow.python.eager import backprop as tfe_backprop
from tensorflow.python.eager import context as tfe_context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops.distributions import special_math
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


special = try_import("scipy.special")
stats = try_import("scipy.stats")
sm = special_math
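# scipy is an optional dependency here: when it is not installed, `special`
# and `stats` are None and the scipy cross-checks below are skipped (those
# test bodies simply return early).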


def _check_strictly_increasing(array_1d):
  diff = np.diff(array_1d)
  np.testing.assert_array_less(0, diff)


def _make_grid(dtype, grid_spec):
  """Returns a uniform grid + noise, reshaped to shape argument."""
  rng = np.random.RandomState(0)
  num_points = np.prod(grid_spec.shape)
  grid = np.linspace(grid_spec.min, grid_spec.max, num=num_points).astype(dtype)
  grid_spacing = (grid_spec.max - grid_spec.min) / num_points
  grid += 0.1 * grid_spacing * rng.randn(*grid.shape)  # pylint: disable=not-an-iterable
  # More useful if it's sorted (e.g. for testing monotonicity, or debugging).
  grid = np.sort(grid)
  return np.reshape(grid, grid_spec.shape)


def _value_and_gradient(fn, *args):
  """Calls `fn` and computes the gradient of the result wrt `args`."""
  if tfe_context.executing_eagerly():
    v, g = tfe_backprop.val_and_grad_function(fn)(args)
  else:
    v = fn(*args)
    g = gradients_impl.gradients(v, args)
  return v, g


GridSpec = collections.namedtuple("GridSpec", ["min", "max", "shape"])

ErrorSpec = collections.namedtuple("ErrorSpec", ["rtol", "atol"])
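# GridSpec describes the evaluation grid handed to `_make_grid`, and ErrorSpec
# the tolerances handed to `assertAllClose`.  For illustration (this exact call
# is not used by the tests below), something like
#   _make_grid(np.float32, GridSpec(min=-5., max=5., shape=[10]))
# yields 10 roughly uniform, lightly jittered float32 points spanning [-5, 5]
# (the jitter can push points slightly past the endpoints), sorted ascending.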


class NdtriTest(test.TestCase):

  def assertAllFinite(self, x):
    is_finite = np.isfinite(x)
    all_true = np.ones_like(is_finite, dtype=np.bool_)
    self.assertAllEqual(all_true, is_finite)

  @test_util.run_in_graph_and_eager_modes
  def testNdtri(self):
    """Verifies that ndtri computation is correct."""
    if not special:
      return

    p = np.linspace(0., 1.0, 50).astype(np.float64)
    # The quantile function uses a piecewise rational approximation, so add
    # some special input values to make sure we hit all the pieces.
    p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), np.exp(-2),
                   1. - np.exp(-2)))
    expected_x = special.ndtri(p)
    x = special_math.ndtri(p)
    self.assertAllClose(expected_x, self.evaluate(x), atol=0.)

  @test_util.run_deprecated_v1
  def testNdtriDynamicShape(self):
105    """Verifies that ndtri computation is correct."""
    with self.cached_session() as sess:
      if not special:
        return

      p = array_ops.placeholder(np.float32)
      p_ = np.linspace(0., 1.0, 50).astype(np.float32)

      x = special_math.ndtri(p)
      x_ = sess.run(x, feed_dict={p: p_})

      expected_x_ = special.ndtri(p_)
      self.assertAllClose(expected_x_, x_, atol=0.)

  def _baseNdtriFiniteGradientTest(self, dtype):
    """Verifies that ndtri has finite gradients at interesting points."""
    # Tests gradients at 0, 1, and piece-wise boundaries.
    p = constant_op.constant(
        np.array([
            0.,
            np.exp(-32.),
            np.exp(-2.),
            1. - np.exp(-2.),
            1. - np.exp(-32.),
            1.,
        ]).astype(dtype))
    # Not having the lambda sanitizer means we'd get an `IndexError` whenever
    # the user supplied function has default args.
    _, grads = _value_and_gradient(
        lambda x: special_math.ndtri(x), p)  # pylint: disable=unnecessary-lambda
    self.assertAllFinite(self.evaluate(grads[0]))

  @test_util.run_in_graph_and_eager_modes
  def testNdtriFiniteGradientFloat32(self):
    self._baseNdtriFiniteGradientTest(np.float32)

  @test_util.run_in_graph_and_eager_modes
  def testNdtriFiniteGradientFloat64(self):
    self._baseNdtriFiniteGradientTest(np.float64)


@test_util.run_all_in_graph_and_eager_modes
class NdtrTest(test.TestCase):
  _use_log = False
  # Grid min/max chosen to ensure 0 < cdf(x) < 1.
  _grid32 = GridSpec(min=-12.9, max=5., shape=[100])
  _grid64 = GridSpec(min=-37.5, max=8., shape=[100])
  _error32 = ErrorSpec(rtol=1e-4, atol=0.)
  _error64 = ErrorSpec(rtol=1e-6, atol=0.)
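  # The LogNdtrTest* subclasses below override _use_log and the grid/error
  # specs above so that these same test bodies also exercise log_ndtr.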

  def _test_grid(self, dtype, grid_spec, error_spec):
    if self._use_log:
      self._test_grid_log(dtype, grid_spec, error_spec)
    else:
      self._test_grid_no_log(dtype, grid_spec, error_spec)

  def _test_grid_log(self, dtype, grid_spec, error_spec):
    if not special:
      return

    grid = _make_grid(dtype, grid_spec)
    actual = self.evaluate(sm.log_ndtr(grid))

    # Basic tests.
    # isfinite checks for NaN and Inf.
    self.assertTrue(np.isfinite(actual).all())
    # On the grid, -inf < log_cdf(x) < 0.  In this case, we should be able
    # to use a huge grid because we have used tricks to escape numerical
    # difficulties.
    self.assertTrue((actual < 0).all())
    _check_strictly_increasing(actual)

    # Versus scipy.
    expected = special.log_ndtr(grid)
    # Scipy prematurely goes to zero at some places that we don't.  So don't
    # include these in the comparison.
    self.assertAllClose(
        expected.astype(np.float64)[expected < 0],
        actual.astype(np.float64)[expected < 0],
        rtol=error_spec.rtol,
        atol=error_spec.atol)

  def _test_grid_no_log(self, dtype, grid_spec, error_spec):
    if not special:
      return

    grid = _make_grid(dtype, grid_spec)
    actual = self.evaluate(sm.ndtr(grid))

    # Basic tests.
    # isfinite checks for NaN and Inf.
    self.assertTrue(np.isfinite(actual).all())
    # On the grid, 0 < cdf(x) < 1.  The grid cannot contain everything due
    # to numerical limitations of cdf.
    self.assertTrue((actual > 0).all())
    self.assertTrue((actual < 1).all())
    _check_strictly_increasing(actual)

    # Versus scipy.
    expected = special.ndtr(grid)
    # Scipy can prematurely underflow to zero at some places that we don't, so
    # only compare at points where the scipy value is nonzero.
    self.assertAllClose(
        expected.astype(np.float64)[expected > 0],
        actual.astype(np.float64)[expected > 0],
        rtol=error_spec.rtol,
        atol=error_spec.atol)

  @test_util.run_deprecated_v1
  def test_float32(self):
    self._test_grid(np.float32, self._grid32, self._error32)

  @test_util.run_deprecated_v1
  def test_float64(self):
    self._test_grid(np.float64, self._grid64, self._error64)


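# The LogNdtrTest* classes split the real line at sm.LOGNDTR_FLOAT{32,64}_LOWER
# and sm.LOGNDTR_FLOAT{32,64}_UPPER.  These are the thresholds at which (in
# special_math) log_ndtr switches between its lower asymptotic-series branch,
# its middle log(ndtr(x)) branch, and its upper branch that approximates
# log(1 - z) by -z, so each branch gets its own grid and tolerances.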
class LogNdtrTestLower(NdtrTest):
  _use_log = True
  _grid32 = GridSpec(min=-100., max=sm.LOGNDTR_FLOAT32_LOWER, shape=[100])
  _grid64 = GridSpec(min=-100., max=sm.LOGNDTR_FLOAT64_LOWER, shape=[100])
  _error32 = ErrorSpec(rtol=1e-4, atol=0.)
  _error64 = ErrorSpec(rtol=1e-4, atol=0.)


# The errors are quite large when the input is > 6 or so.  Also,
# scipy.special.log_ndtr becomes zero very early, before x = 10 (because ndtr
# rounds to 1 there: 1 - ndtr(10) is roughly 1e-23, far below float64 eps).
# We approximate log(1 + epsilon) as epsilon and avoid this issue.
class LogNdtrTestMid(NdtrTest):
  _use_log = True
  _grid32 = GridSpec(
      min=sm.LOGNDTR_FLOAT32_LOWER, max=sm.LOGNDTR_FLOAT32_UPPER, shape=[100])
  _grid64 = GridSpec(
      min=sm.LOGNDTR_FLOAT64_LOWER, max=sm.LOGNDTR_FLOAT64_UPPER, shape=[100])
  # Differences show up as soon as we're in the tail, so add some atol.
  _error32 = ErrorSpec(rtol=0.1, atol=1e-7)
  _error64 = ErrorSpec(rtol=0.1, atol=1e-7)


class LogNdtrTestUpper(NdtrTest):
  _use_log = True
  _grid32 = GridSpec(
      min=sm.LOGNDTR_FLOAT32_UPPER,
      max=12.,  # Beyond this, log_cdf(x) may be zero.
      shape=[100])
  _grid64 = GridSpec(
      min=sm.LOGNDTR_FLOAT64_UPPER,
      max=35.,  # Beyond this, log_cdf(x) may be zero.
      shape=[100])
  _error32 = ErrorSpec(rtol=1e-6, atol=1e-14)
  _error64 = ErrorSpec(rtol=1e-6, atol=1e-14)


class NdtrGradientTest(test.TestCase):
  _use_log = False
  _grid = GridSpec(min=-100., max=100., shape=[1, 2, 3, 8])
  _error32 = ErrorSpec(rtol=1e-4, atol=0)
  _error64 = ErrorSpec(rtol=1e-7, atol=0)

  def assert_all_true(self, v):
    self.assertAllEqual(np.ones_like(v, dtype=np.bool_), v)

  def assert_all_false(self, v):
    self.assertAllEqual(np.zeros_like(v, dtype=np.bool_), v)

  def _test_grad_finite(self, dtype):
    x = constant_op.constant([-100., 0., 100.], dtype=dtype)
    fn = sm.log_ndtr if self._use_log else sm.ndtr
    # Not having the lambda sanitizer means we'd get an `IndexError` whenever
    # the user supplied function has default args.
    output, grad_output = _value_and_gradient(
        lambda x_: fn(x_), x)  # pylint: disable=unnecessary-lambda
    # isfinite checks for NaN and Inf.
    output_, grad_output_ = self.evaluate([output, grad_output])
    self.assert_all_true(np.isfinite(output_))
    self.assert_all_true(np.isfinite(grad_output_[0]))

  def _test_grad_accuracy(self, dtype, grid_spec, error_spec):
    raw_grid = _make_grid(dtype, grid_spec)
    grid = ops.convert_to_tensor(raw_grid)
    with self.cached_session():
      fn = sm.log_ndtr if self._use_log else sm.ndtr

      # If there are N points in the grid,
      # grad_eval.shape = (N, N), with grad_eval[i, j] the partial derivative of
      # the ith output point w.r.t. the jth grid point.  We only expect the
      # diagonal to be nonzero.
      # TODO(b/31131137): Replace tf.compat.v1.test.compute_gradient with our
      # own custom gradient evaluation to ensure we correctly handle small
      # function delta.
      grad_eval, _ = gradient_checker.compute_gradient(grid, grid_spec.shape,
                                                       fn(grid),
                                                       grid_spec.shape)
      grad_eval = np.diag(grad_eval)

      # Check for NaN separately in order to get informative failures.
      self.assert_all_false(np.isnan(grad_eval))
      self.assert_all_true(grad_eval > 0.)
      # isfinite checks for NaN and Inf.
      self.assert_all_true(np.isfinite(grad_eval))

      # Do the same checks but explicitly compute the gradient.
      # (We did this because we're not sure if we trust
      # tf.test.compute_gradient.)
      grad_eval = gradients_impl.gradients(fn(grid), grid)[0].eval()
      self.assert_all_false(np.isnan(grad_eval))
      if self._use_log:
        g = np.reshape(grad_eval, [-1])
        half = np.ceil(len(g) / 2)
        self.assert_all_true(g[:int(half)] > 0.)
        self.assert_all_true(g[int(half):] >= 0.)
      else:
        # The ndtr gradient will only be non-zero in the range [-14, 14] for
        # float32 and [-38, 38] for float64.
        self.assert_all_true(grad_eval >= 0.)
      # isfinite checks for NaN and Inf.
      self.assert_all_true(np.isfinite(grad_eval))

      # Versus scipy.
      if not (special and stats):
        return

      expected = stats.norm.pdf(raw_grid)
      if self._use_log:
        expected /= special.ndtr(raw_grid)
        expected[np.isnan(expected)] = 0.
      # Scipy can prematurely underflow to zero at some places that we don't
      # (and, at the most extreme grid points, our gradients underflow too),
      # so only compare where both values are nonzero.
      nonzero = (expected > 0) & (grad_eval > 0)
      self.assertAllClose(
          expected.astype(np.float64)[nonzero],
          grad_eval.astype(np.float64)[nonzero],
          rtol=error_spec.rtol,
          atol=error_spec.atol)

  @test_util.run_deprecated_v1
  def test_float32(self):
    self._test_grad_accuracy(np.float32, self._grid, self._error32)
    self._test_grad_finite(np.float32)

  @test_util.run_deprecated_v1
  def test_float64(self):
    self._test_grad_accuracy(np.float64, self._grid, self._error64)
    self._test_grad_finite(np.float64)


class LogNdtrGradientTest(NdtrGradientTest):
  _use_log = True


class ErfInvTest(test.TestCase):

  def testErfInvValues(self):
    with self.cached_session():
      if not special:
        return

      x = np.linspace(0., 1.0, 50).astype(np.float64)

      expected_x = special.erfinv(x)
      x = special_math.erfinv(x)
      self.assertAllClose(expected_x, self.evaluate(x), atol=0.)

  def testErfInvIntegerInput(self):
    with self.cached_session():

      with self.assertRaises(TypeError):
        x = np.array([1, 2, 3]).astype(np.int32)
        special_math.erfinv(x)

      with self.assertRaises(TypeError):
        x = np.array([1, 2, 3]).astype(np.int64)
        special_math.erfinv(x)


class LogCDFLaplaceTest(test.TestCase):
  # Note that scipy.stats.laplace does not have a stable Log CDF, so we cannot
  # rely on scipy to cross check the extreme values.

  # Testing is done differently over different ranges.  These are the values
  # of x above which the naive (scipy) implementation runs into numerical
  # issues.
  #
  # If x = log(1 / (2 * eps)), then 0.5 * exp{-x} = eps.
  # Inserting eps = np.finfo(dtype).eps, we see that log(1 / (2 * eps)) is the
  # value of x beyond which 1 - 0.5 * exp{-x} rounds to 1, so a naive
  # log(cdf(x)) collapses to 0 even though the true log CDF is still negative.
  # We therefore choose these as our cutoffs for testing.
  CUTOFF_FLOAT64_UPPER = np.log(1. / (2. * np.finfo(np.float64).eps)) - 1.
  CUTOFF_FLOAT32_UPPER = np.log(1. / (2. * np.finfo(np.float32).eps)) - 1.
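  # As a sanity check on the formula above: eps is 2**-52 for float64 and
  # 2**-23 for float32, so these cutoffs are log(2**51) - 1, about 34.3, and
  # log(2**22) - 1, about 14.2, respectively.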

  def assertAllTrue(self, x):
    self.assertAllEqual(np.ones_like(x, dtype=np.bool_), x)

  def _test_grid_log(self, dtype, scipy_dtype, grid_spec, error_spec):
    with self.cached_session():
      grid = _make_grid(dtype, grid_spec)
      actual = sm.log_cdf_laplace(grid).eval()

      # Basic tests.
      # isfinite checks for NaN and Inf.
      self.assertAllTrue(np.isfinite(actual))
      self.assertAllTrue((actual < 0))
      _check_strictly_increasing(actual)

      # Versus scipy.
      if not stats:
        return

      scipy_dist = stats.laplace(loc=0., scale=1.)
      expected = scipy_dist.logcdf(grid.astype(scipy_dtype))
      self.assertAllClose(
          expected.astype(np.float64),
          actual.astype(np.float64),
          rtol=error_spec.rtol,
          atol=error_spec.atol)

  @test_util.run_deprecated_v1
  def test_float32_lower_and_mid_segment_scipy_float32_ok(self):
    # Choose values mild enough that we can use scipy in float32, which will
    # allow for a high accuracy match to scipy (since we both use float32).
    self._test_grid_log(
        np.float32,  # dtype
        np.float32,  # scipy_dtype
        GridSpec(min=-10, max=self.CUTOFF_FLOAT32_UPPER - 5, shape=[100]),
        ErrorSpec(rtol=5e-4, atol=0))

  @test_util.run_deprecated_v1
  def test_float32_all_segments_with_scipy_float64_ok(self):
    # Choose values outside the range where scipy float32 works.
    # Let scipy use float64.  This means we
    # won't be exactly the same since we are in float32.
    self._test_grid_log(
        np.float32,  # dtype
        np.float64,  # scipy_dtype
        GridSpec(min=-50, max=self.CUTOFF_FLOAT32_UPPER + 5, shape=[100]),
        ErrorSpec(rtol=0.05, atol=0))

  @test_util.run_deprecated_v1
  def test_float32_extreme_values_result_and_gradient_finite_and_nonzero(self):
    with self.cached_session():
      # On the lower branch, log_cdf_laplace(x) = x, so we know this will be
      # fine, but test to -200 anyway.
      grid = _make_grid(
          np.float32, GridSpec(min=-200, max=80, shape=[20, 100]))
      grid = ops.convert_to_tensor(grid)

      actual = sm.log_cdf_laplace(grid)
      grad = gradients_impl.gradients(actual, grid)[0]

      actual_, grad_ = self.evaluate([actual, grad])

      # isfinite checks for NaN and Inf.
      self.assertAllTrue(np.isfinite(actual_))
      self.assertAllTrue(np.isfinite(grad_))
      self.assertFalse(np.any(actual_ == 0))
      self.assertFalse(np.any(grad_ == 0))

  @test_util.run_deprecated_v1
  def test_float64_extreme_values_result_and_gradient_finite_and_nonzero(self):
    with self.cached_session():
      # On the lower branch, log_cdf_laplace(x) = x, so we know this will be
      # fine, but test to -200 anyway.
      grid = _make_grid(
          np.float64, GridSpec(min=-200, max=700, shape=[20, 100]))
      grid = ops.convert_to_tensor(grid)

      actual = sm.log_cdf_laplace(grid)
      grad = gradients_impl.gradients(actual, grid)[0]

      actual_, grad_ = self.evaluate([actual, grad])

      # isfinite checks for NaN and Inf.
      self.assertAllTrue(np.isfinite(actual_))
      self.assertAllTrue(np.isfinite(grad_))
      self.assertFalse(np.any(actual_ == 0))
      self.assertFalse(np.any(grad_ == 0))


if __name__ == "__main__":
  test.main()
487