# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Special Math Ops."""

import collections
import importlib

import numpy as np

from tensorflow.python.eager import backprop as tfe_backprop
from tensorflow.python.eager import context as tfe_context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops.distributions import special_math
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


special = try_import("scipy.special")
stats = try_import("scipy.stats")
sm = special_math


def _check_strictly_increasing(array_1d):
  diff = np.diff(array_1d)
  np.testing.assert_array_less(0, diff)


def _make_grid(dtype, grid_spec):
  """Returns a uniform grid + noise, reshaped to shape argument."""
  rng = np.random.RandomState(0)
  num_points = np.prod(grid_spec.shape)
  grid = np.linspace(grid_spec.min, grid_spec.max,
                     num=num_points).astype(dtype)
  grid_spacing = (grid_spec.max - grid_spec.min) / num_points
  grid += 0.1 * grid_spacing * rng.randn(*grid.shape)  # pylint: disable=not-an-iterable
  # More useful if it's sorted (e.g. for testing monotonicity, or debugging).
  grid = np.sort(grid)
  return np.reshape(grid, grid_spec.shape)


def _value_and_gradient(fn, *args):
  """Calls `fn` and computes the gradient of the result wrt `args`."""
  if tfe_context.executing_eagerly():
    v, g = tfe_backprop.val_and_grad_function(fn)(args)
  else:
    v = fn(*args)
    g = gradients_impl.gradients(v, args)
  return v, g


GridSpec = collections.namedtuple("GridSpec", ["min", "max", "shape"])

ErrorSpec = collections.namedtuple("ErrorSpec", ["rtol", "atol"])
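# The GridSpec/ErrorSpec tuples above parameterize the tests below: each test
# draws a noisy, sorted grid of inputs from a GridSpec via `_make_grid` and
# compares against scipy within the `rtol`/`atol` tolerances of an ErrorSpec.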
class NdtriTest(test.TestCase):

  def assertAllFinite(self, x):
    is_finite = np.isfinite(x)
    all_true = np.ones_like(is_finite, dtype=np.bool_)
    self.assertAllEqual(all_true, is_finite)

  @test_util.run_in_graph_and_eager_modes
  def testNdtri(self):
    """Verifies that ndtri computation is correct."""
    if not special:
      return

    p = np.linspace(0., 1.0, 50).astype(np.float64)
    # Quantile is a piecewise rational approximation, so add some special
    # input values to make sure we hit all the pieces.
    p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), np.exp(-2),
                   1. - np.exp(-2)))
    expected_x = special.ndtri(p)
    x = special_math.ndtri(p)
    self.assertAllClose(expected_x, self.evaluate(x), atol=0.)

  @test_util.run_deprecated_v1
  def testNdtriDynamicShape(self):
    """Verifies that ndtri computation is correct with a dynamic shape."""
    with self.cached_session() as sess:
      if not special:
        return

      p = array_ops.placeholder(np.float32)
      p_ = np.linspace(0., 1.0, 50).astype(np.float32)

      x = special_math.ndtri(p)
      x_ = sess.run(x, feed_dict={p: p_})

      expected_x_ = special.ndtri(p_)
      self.assertAllClose(expected_x_, x_, atol=0.)

  def _baseNdtriFiniteGradientTest(self, dtype):
    """Verifies that ndtri has finite gradients at interesting points."""
    # Tests gradients at 0, 1, and piece-wise boundaries.
    p = constant_op.constant(
        np.array([
            0.,
            np.exp(-32.),
            np.exp(-2.),
            1. - np.exp(-2.),
            1. - np.exp(-32.),
            1.,
        ]).astype(dtype))
    # Not having the lambda sanitizer means we'd get an `IndexError` whenever
    # the user-supplied function has default args.
    _, grads = _value_and_gradient(
        lambda x: special_math.ndtri(x), p)  # pylint: disable=unnecessary-lambda
    self.assertAllFinite(self.evaluate(grads[0]))

  @test_util.run_in_graph_and_eager_modes
  def testNdtriFiniteGradientFloat32(self):
    self._baseNdtriFiniteGradientTest(np.float32)

  @test_util.run_in_graph_and_eager_modes
  def testNdtriFiniteGradientFloat64(self):
    self._baseNdtriFiniteGradientTest(np.float64)


@test_util.run_all_in_graph_and_eager_modes
class NdtrTest(test.TestCase):
  _use_log = False
  # Grid min/max chosen to ensure 0 < cdf(x) < 1.
  _grid32 = GridSpec(min=-12.9, max=5., shape=[100])
  _grid64 = GridSpec(min=-37.5, max=8., shape=[100])
  _error32 = ErrorSpec(rtol=1e-4, atol=0.)
  _error64 = ErrorSpec(rtol=1e-6, atol=0.)

  def _test_grid(self, dtype, grid_spec, error_spec):
    if self._use_log:
      self._test_grid_log(dtype, grid_spec, error_spec)
    else:
      self._test_grid_no_log(dtype, grid_spec, error_spec)

  def _test_grid_log(self, dtype, grid_spec, error_spec):
    if not special:
      return

    grid = _make_grid(dtype, grid_spec)
    actual = self.evaluate(sm.log_ndtr(grid))

    # Basic tests.
    # isfinite checks for NaN and Inf.
    self.assertTrue(np.isfinite(actual).all())
    # On the grid, -inf < log_cdf(x) < 0.  In this case, we should be able
    # to use a huge grid because we have used tricks to escape numerical
    # difficulties.
    self.assertTrue((actual < 0).all())
    _check_strictly_increasing(actual)

    # Versus scipy.
    expected = special.log_ndtr(grid)
    # Scipy prematurely goes to zero at some places that we don't.  So don't
    # include these in the comparison.
    self.assertAllClose(
        expected.astype(np.float64)[expected < 0],
        actual.astype(np.float64)[expected < 0],
        rtol=error_spec.rtol,
        atol=error_spec.atol)

  def _test_grid_no_log(self, dtype, grid_spec, error_spec):
    if not special:
      return

    grid = _make_grid(dtype, grid_spec)
    actual = self.evaluate(sm.ndtr(grid))

    # Basic tests.
    # isfinite checks for NaN and Inf.
    self.assertTrue(np.isfinite(actual).all())
    # On the grid, 0 < cdf(x) < 1.  The grid cannot contain everything due
    # to numerical limitations of cdf.
    self.assertTrue((actual > 0).all())
    self.assertTrue((actual < 1).all())
    _check_strictly_increasing(actual)

    # Versus scipy.
    expected = special.ndtr(grid)
    # Scipy prematurely goes to zero at some places that we don't.  So don't
    # include these in the comparison.
    self.assertAllClose(
        expected.astype(np.float64)[expected < 0],
        actual.astype(np.float64)[expected < 0],
        rtol=error_spec.rtol,
        atol=error_spec.atol)

  @test_util.run_deprecated_v1
  def test_float32(self):
    self._test_grid(np.float32, self._grid32, self._error32)

  @test_util.run_deprecated_v1
  def test_float64(self):
    self._test_grid(np.float64, self._grid64, self._error64)


class LogNdtrTestLower(NdtrTest):
  _use_log = True
  _grid32 = GridSpec(min=-100., max=sm.LOGNDTR_FLOAT32_LOWER, shape=[100])
  _grid64 = GridSpec(min=-100., max=sm.LOGNDTR_FLOAT64_LOWER, shape=[100])
  _error32 = ErrorSpec(rtol=1e-4, atol=0.)
  _error64 = ErrorSpec(rtol=1e-4, atol=0.)


# The errors are quite large when the input is > 6 or so.  Also,
# scipy.special.log_ndtr becomes zero very early, before x = 10 (due to ndtr
# becoming 1).  We approximate log(1 + epsilon) as epsilon and avoid this
# issue.
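# For example (rough numbers): in float64, ndtr(9.) differs from 1. by only
# ~1e-19, which is below machine epsilon (~2.2e-16), so a naive log(ndtr(9.))
# rounds to exactly 0., whereas the log(1 + epsilon) ~= epsilon approximation
# with epsilon = -ndtr(-9.) keeps the tiny negative value.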
class LogNdtrTestMid(NdtrTest):
  _use_log = True
  _grid32 = GridSpec(
      min=sm.LOGNDTR_FLOAT32_LOWER, max=sm.LOGNDTR_FLOAT32_UPPER, shape=[100])
  _grid64 = GridSpec(
      min=sm.LOGNDTR_FLOAT64_LOWER, max=sm.LOGNDTR_FLOAT64_UPPER, shape=[100])
  # Differences show up as soon as we're in the tail, so add some atol.
  _error32 = ErrorSpec(rtol=0.1, atol=1e-7)
  _error64 = ErrorSpec(rtol=0.1, atol=1e-7)


class LogNdtrTestUpper(NdtrTest):
  _use_log = True
  _grid32 = GridSpec(
      min=sm.LOGNDTR_FLOAT32_UPPER,
      max=12.,  # Beyond this, log_cdf(x) may be zero.
      shape=[100])
  _grid64 = GridSpec(
      min=sm.LOGNDTR_FLOAT64_UPPER,
      max=35.,  # Beyond this, log_cdf(x) may be zero.
      shape=[100])
  _error32 = ErrorSpec(rtol=1e-6, atol=1e-14)
  _error64 = ErrorSpec(rtol=1e-6, atol=1e-14)
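# The gradient tests below check d/dx ndtr(x) (the standard normal pdf) and
# d/dx log_ndtr(x) (pdf / cdf) for finiteness and the expected sign, and
# compare against scipy both via gradient_checker.compute_gradient and via an
# explicit gradients_impl.gradients call.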
class NdtrGradientTest(test.TestCase):
  _use_log = False
  _grid = GridSpec(min=-100., max=100., shape=[1, 2, 3, 8])
  _error32 = ErrorSpec(rtol=1e-4, atol=0)
  _error64 = ErrorSpec(rtol=1e-7, atol=0)

  def assert_all_true(self, v):
    self.assertAllEqual(np.ones_like(v, dtype=np.bool_), v)

  def assert_all_false(self, v):
    self.assertAllEqual(np.zeros_like(v, dtype=np.bool_), v)

  def _test_grad_finite(self, dtype):
    x = constant_op.constant([-100., 0., 100.], dtype=dtype)
    fn = sm.log_ndtr if self._use_log else sm.ndtr
    # Not having the lambda sanitizer means we'd get an `IndexError` whenever
    # the user-supplied function has default args.
    output, grad_output = _value_and_gradient(
        lambda x_: fn(x_), x)  # pylint: disable=unnecessary-lambda
    # isfinite checks for NaN and Inf.
    output_, grad_output_ = self.evaluate([output, grad_output])
    self.assert_all_true(np.isfinite(output_))
    self.assert_all_true(np.isfinite(grad_output_[0]))

  def _test_grad_accuracy(self, dtype, grid_spec, error_spec):
    raw_grid = _make_grid(dtype, grid_spec)
    grid = ops.convert_to_tensor(raw_grid)
    with self.cached_session():
      fn = sm.log_ndtr if self._use_log else sm.ndtr

      # If there are N points in the grid,
      # grad_eval.shape = (N, N), with grad_eval[i, j] the partial derivative
      # of the ith output point w.r.t. the jth grid point.  We only expect the
      # diagonal to be nonzero.
      # TODO(b/31131137): Replace tf.compat.v1.test.compute_gradient with our
      # own custom gradient evaluation to ensure we correctly handle small
      # function delta.
      grad_eval, _ = gradient_checker.compute_gradient(grid, grid_spec.shape,
                                                       fn(grid),
                                                       grid_spec.shape)
      grad_eval = np.diag(grad_eval)

      # Check for NaN separately in order to get informative failures.
      self.assert_all_false(np.isnan(grad_eval))
      self.assert_all_true(grad_eval > 0.)
      # isfinite checks for NaN and Inf.
      self.assert_all_true(np.isfinite(grad_eval))

      # Do the same checks but explicitly compute the gradient.
      # (We did this because we're not sure if we trust
      # tf.test.compute_gradient.)
      grad_eval = gradients_impl.gradients(fn(grid), grid)[0].eval()
      self.assert_all_false(np.isnan(grad_eval))
      if self._use_log:
        g = np.reshape(grad_eval, [-1])
        half = np.ceil(len(g) / 2)
        self.assert_all_true(g[:int(half)] > 0.)
        self.assert_all_true(g[int(half):] >= 0.)
      else:
        # The ndtr gradient will only be non-zero in the range [-14, 14] for
        # float32 and [-38, 38] for float64.
        self.assert_all_true(grad_eval >= 0.)
      # isfinite checks for NaN and Inf.
      self.assert_all_true(np.isfinite(grad_eval))

      # Versus scipy.
      if not (special and stats):
        return

      expected = stats.norm.pdf(raw_grid)
      if self._use_log:
        expected /= special.ndtr(raw_grid)
        expected[np.isnan(expected)] = 0.
      # Scipy prematurely goes to zero at some places that we don't.  So don't
      # include these in the comparison.
      self.assertAllClose(
          expected.astype(np.float64)[expected < 0],
          grad_eval.astype(np.float64)[expected < 0],
          rtol=error_spec.rtol,
          atol=error_spec.atol)

  @test_util.run_deprecated_v1
  def test_float32(self):
    self._test_grad_accuracy(np.float32, self._grid, self._error32)
    self._test_grad_finite(np.float32)

  @test_util.run_deprecated_v1
  def test_float64(self):
    self._test_grad_accuracy(np.float64, self._grid, self._error64)
    self._test_grad_finite(np.float64)


class LogNdtrGradientTest(NdtrGradientTest):
  _use_log = True


class ErfInvTest(test.TestCase):

  def testErfInvValues(self):
    with self.cached_session():
      if not special:
        return

      x = np.linspace(0., 1.0, 50).astype(np.float64)

      expected_x = special.erfinv(x)
      x = special_math.erfinv(x)
      self.assertAllClose(expected_x, self.evaluate(x), atol=0.)

  def testErfInvIntegerInput(self):
    with self.cached_session():

      with self.assertRaises(TypeError):
        x = np.array([1, 2, 3]).astype(np.int32)
        special_math.erfinv(x)

      with self.assertRaises(TypeError):
        x = np.array([1, 2, 3]).astype(np.int64)
        special_math.erfinv(x)


class LogCDFLaplaceTest(test.TestCase):
  # Note that scipy.stats.laplace does not have a stable Log CDF, so we cannot
  # rely on scipy to cross-check the extreme values.

  # Tests will be done differently over different ranges.  These are the
  # values such that, when exceeded by x, the naive (scipy) implementation
  # runs into numerical issues.
  #
  # If x = log(1 / (2 * eps)), then 0.5 * exp{-x} = eps.
  # Substituting eps = np.finfo(dtype).eps, we see that log(1 / (2 * eps)) is
  # the value of x such that any larger value will result in
  # 1 - 0.5 * exp{-x} = 0, which will cause the log_cdf_laplace code to take a
  # log of zero.  We therefore choose these as our cutoffs for testing.
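  # For reference (rough numbers): log(1 / (2 * eps)) is ~35.35 for float64
  # (eps ~2.2e-16) and ~15.25 for float32 (eps ~1.2e-7), so the cutoffs below
  # are roughly 34.35 and 14.25.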
  CUTOFF_FLOAT64_UPPER = np.log(1. / (2. * np.finfo(np.float64).eps)) - 1.
  CUTOFF_FLOAT32_UPPER = np.log(1. / (2. * np.finfo(np.float32).eps)) - 1.

  def assertAllTrue(self, x):
    self.assertAllEqual(np.ones_like(x, dtype=np.bool_), x)

  def _test_grid_log(self, dtype, scipy_dtype, grid_spec, error_spec):
    with self.cached_session():
      grid = _make_grid(dtype, grid_spec)
      actual = sm.log_cdf_laplace(grid).eval()

      # Basic tests.
      # isfinite checks for NaN and Inf.
      self.assertAllTrue(np.isfinite(actual))
      self.assertAllTrue((actual < 0))
      _check_strictly_increasing(actual)

      # Versus scipy.
      if not stats:
        return

      scipy_dist = stats.laplace(loc=0., scale=1.)
      expected = scipy_dist.logcdf(grid.astype(scipy_dtype))
      self.assertAllClose(
          expected.astype(np.float64),
          actual.astype(np.float64),
          rtol=error_spec.rtol,
          atol=error_spec.atol)

  @test_util.run_deprecated_v1
  def test_float32_lower_and_mid_segment_scipy_float32_ok(self):
    # Choose values mild enough that we can use scipy in float32, which will
    # allow for a high-accuracy match to scipy (since we both use float32).
    self._test_grid_log(
        np.float32,  # dtype
        np.float32,  # scipy_dtype
        GridSpec(min=-10, max=self.CUTOFF_FLOAT32_UPPER - 5, shape=[100]),
        ErrorSpec(rtol=5e-4, atol=0))

  @test_util.run_deprecated_v1
  def test_float32_all_segments_with_scipy_float64_ok(self):
    # Choose values outside the range where scipy float32 works.
    # Let scipy use float64.  This means we won't be exactly the same since
    # we are in float32.
    self._test_grid_log(
        np.float32,  # dtype
        np.float64,  # scipy_dtype
        GridSpec(min=-50, max=self.CUTOFF_FLOAT32_UPPER + 5, shape=[100]),
        ErrorSpec(rtol=0.05, atol=0))

  @test_util.run_deprecated_v1
  def test_float32_extreme_values_result_and_gradient_finite_and_nonzero(self):
    with self.cached_session():
      # On the lower branch, log_cdf_laplace(x) = x, so we know this will be
      # fine, but test to -200 anyway.
      grid = _make_grid(
          np.float32, GridSpec(min=-200, max=80, shape=[20, 100]))
      grid = ops.convert_to_tensor(grid)

      actual = sm.log_cdf_laplace(grid)
      grad = gradients_impl.gradients(actual, grid)[0]

      actual_, grad_ = self.evaluate([actual, grad])

      # isfinite checks for NaN and Inf.
      self.assertAllTrue(np.isfinite(actual_))
      self.assertAllTrue(np.isfinite(grad_))
      self.assertFalse(np.any(actual_ == 0))
      self.assertFalse(np.any(grad_ == 0))

  @test_util.run_deprecated_v1
  def test_float64_extreme_values_result_and_gradient_finite_and_nonzero(self):
    with self.cached_session():
      # On the lower branch, log_cdf_laplace(x) = x, so we know this will be
      # fine, but test to -200 anyway.
      grid = _make_grid(
          np.float64, GridSpec(min=-200, max=700, shape=[20, 100]))
      grid = ops.convert_to_tensor(grid)

      actual = sm.log_cdf_laplace(grid)
      grad = gradients_impl.gradients(actual, grid)[0]

      actual_, grad_ = self.evaluate([actual, grad])

      # isfinite checks for NaN and Inf.
      self.assertAllTrue(np.isfinite(actual_))
      self.assertAllTrue(np.isfinite(grad_))
      self.assertFalse(np.any(actual_ == 0))
      self.assertFalse(np.any(grad_ == 0))


if __name__ == "__main__":
  test.main()