1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Grappler AutoMixedPrecision.""" 16 17import os 18 19from absl.testing import parameterized 20import numpy as np 21 22from tensorflow.core.framework import types_pb2 23from tensorflow.core.protobuf import config_pb2 24from tensorflow.core.protobuf import rewriter_config_pb2 25from tensorflow.python import tf2 26from tensorflow.python.client import session 27from tensorflow.python.data.ops import dataset_ops 28from tensorflow.python.framework import constant_op 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import function 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import random_seed 33from tensorflow.python.framework import test_util 34from tensorflow.python.layers import layers 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import init_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import nn 40from tensorflow.python.ops import nn_impl 41from tensorflow.python.ops import random_ops 42from tensorflow.python.ops import tensor_array_ops 43from tensorflow.python.ops import variables 44from tensorflow.python.ops.losses import losses 45from tensorflow.python.platform import sysconfig 46from tensorflow.python.platform import test 47from tensorflow.python.training import adam 48from tensorflow.python.training import gradient_descent 49from tensorflow.python.util import _pywrap_utils 50 51 52def _input(shape): 53 """Generates an input of a given shape.""" 54 return variables.Variable(random_ops.truncated_normal(shape, seed=0)) 55 56 57def _weight(shape): 58 """Generates a weight of a given shape.""" 59 # Note that the lambda is needed to allow construction inside loops. 60 return variables.Variable(lambda: init_ops.glorot_uniform_initializer(seed=0) 61 (shape)) 62 63 64def _bias(shape): 65 """Generates a bias of a given shape.""" 66 return constant_op.constant(0.1, shape=shape) 67 68 69def _conv2d(x, w): 70 """Returns a 2d convolution layer with full stride.""" 71 return nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') 72 73 74def _conv3d(x, w): 75 """Returns a 3d convolution layer with full stride.""" 76 return nn.conv3d(x, w, strides=[1, 1, 1, 1, 1], padding='SAME') 77 78 79def _max_pool_2x2(x): 80 """Downsamples a feature map by 2X.""" 81 return nn.max_pool( 82 x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 83 84 85def _fused_batchnorm(x, scale, offset): 86 """Batchnorm.""" 87 return nn_impl.fused_batch_norm( 88 x, scale=scale, offset=offset, is_training=True) 89 90 91def _conv_bn(x): 92 """Conv followed by batchnorm.""" 93 i = array_ops.reshape(x, [-1, 8, 8, 1]) 94 f = _weight([3, 3, 1, 6]) 95 x = _conv2d(i, f) 96 s = _weight([6]) 97 o = _weight([6]) 98 y, _, _ = _fused_batchnorm(x, s, o) 99 y = array_ops.identity(y) 100 return y 101 102 103def _conv3d_bn(x): 104 """Conv3D followed by batchnorm.""" 105 i = array_ops.reshape(x, [-1, 8, 8, 8, 1]) 106 f = _weight([3, 3, 3, 1, 6]) 107 x = _conv3d(i, f) 108 s = _weight([6]) 109 o = _weight([6]) 110 x = array_ops.reshape(x, [-1, 8, 8, 6]) 111 y, _, _ = _fused_batchnorm(x, s, o) 112 y = array_ops.identity(y) 113 return y 114 115 116def _matmul_act(x): 117 """Matmul followed by activation.""" 118 i = array_ops.reshape(x, [8, 8]) 119 f = _weight([8, 8]) 120 x = math_ops.matmul(i, f) 121 y = nn.relu(x) 122 return y 123 124 125def _conv_pool(x): 126 """(Conv -> bias -> relu -> max_pool) x2.""" 127 x_image = array_ops.reshape(x, [-1, 8, 8, 1]) 128 w_conv1 = _weight([3, 3, 1, 6]) 129 b_conv1 = _bias([6]) 130 h_conv1 = nn.relu(nn.bias_add(_conv2d(x_image, w_conv1), b_conv1)) 131 h_pool1 = _max_pool_2x2(h_conv1) 132 w_conv2 = _weight([3, 3, 6, 4]) 133 b_conv2 = _bias([4]) 134 h_conv2 = nn.relu(nn.bias_add(_conv2d(h_pool1, w_conv2), b_conv2)) 135 h_pool2 = _max_pool_2x2(h_conv2) 136 return h_pool2 137 138 139def _depthwise_conv2d(x, w): 140 """Returns a 2d depthwise convolution layer with full stride.""" 141 return nn.depthwise_conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') 142 143 144def _simple_loop(x, functor): 145 """Simple loop whose body is provided by the functor.""" 146 init = (constant_op.constant(0), x) 147 c = lambda i, j: i < 4 148 b = lambda i, j: (i + 1, functor(j)) 149 ij = control_flow_ops.while_loop(c, b, init) 150 return ij 151 152 153def _loop_vars_intertwined(x0, y0, functor_x, functor_y): 154 """Loop whose loop variables are intertwined.""" 155 c = lambda i, j, x, y: j < 4 156 b = lambda i, j, x, y: (j + 1, i + 1, functor_y(y), functor_x(x)) 157 init = (constant_op.constant(0), constant_op.constant(0), x0, y0) 158 ijzw = control_flow_ops.while_loop(c, b, init) 159 return ijzw 160 161 162def _lstm_cell(prev_c, prev_h, x): 163 """Create an LSTM cell.""" 164 # i: input gate 165 # f: forget gate 166 # o: output gate 167 # c: cell state 168 # x: input 169 # h: embedding 170 bias = _bias([4]) 171 w = _weight([8, 16]) 172 ifoc = math_ops.matmul(array_ops.concat([x, prev_h], axis=1), w) 173 i, f, o, c = array_ops.split(ifoc, 4, axis=1) 174 i = math_ops.sigmoid(nn.bias_add(i, bias)) 175 f = math_ops.sigmoid(nn.bias_add(f, bias)) 176 o = math_ops.sigmoid(nn.bias_add(o, bias)) 177 c = math_ops.tanh(nn.bias_add(c, bias)) 178 next_c = f * prev_c + i * c 179 next_h = o * math_ops.tanh(next_c) 180 return next_c, next_h 181 182 183def _recurrent_lstm(c, h): 184 """Dynamic single-layer LSTM with TensorArray.""" 185 186 def cond(i, c, h, ta_x): 187 del c, h, ta_x 188 return i < 4 189 190 def body(i, c, h, ta_x): 191 x = ta_x.read(i) 192 next_c, next_h = _lstm_cell(c, h, x) 193 return (i + 1, next_c, next_h, ta_x) 194 195 ta_x = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=4) 196 for i in range(0, 4): 197 ta_x = ta_x.write( 198 i, constant_op.constant(0.1, shape=[8, 4], dtype=dtypes.float32)) 199 init = (constant_op.constant(0), c, h, ta_x) 200 r = control_flow_ops.while_loop(cond, body, init) 201 return r 202 203 204def _make_node_with_color(color, input_tensor, name=None): 205 """Returns a node representative of the specified list type.""" 206 color = color.lower() 207 if color == 'w': # Allow node 208 weights = _weight(input_tensor.get_shape().as_list()) 209 return math_ops.matmul(input_tensor, weights, name=name) 210 if color == 'g': # Infer node 211 return math_ops.add(input_tensor, 0.1, name=name) 212 if color == 'c': # Clear node 213 return nn.relu(input_tensor, name=name) 214 if color == 'b': # Deny node 215 return math_ops.pow(math_ops.pow(input_tensor, 2.), 0.5, name=name) 216 raise ValueError('Invalid node color: ' + str(color)) 217 218 219def _build_simple_loop_graph(inp_colors, body_colors, out_colors): 220 """Builds a test graph with a simple loop.""" 221 a = _input([8, 8]) 222 for i, color in enumerate(inp_colors): 223 a = _make_node_with_color(color, a, 'input_%i' % i) 224 225 def body(x): 226 for i, color in enumerate(body_colors): 227 x = _make_node_with_color(color, x, 'body_%i' % i) 228 return x 229 230 _, a = _simple_loop(a, body) 231 for i, color in enumerate(out_colors): 232 a = _make_node_with_color(color, a, 'output_%i' % i) 233 a = array_ops.identity(a) 234 return a 235 236 237def _get_config(auto_mixed_precision_mode): 238 """Returns a ConfigProto with auto mixed precision enabled if appropriate.""" 239 rewrite_config = rewriter_config_pb2.RewriterConfig( 240 # do not remove duplicated nodes 241 arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, 242 # do not turn Conv2D and other nodes into _FusedConv2D 243 remapping=rewriter_config_pb2.RewriterConfig.OFF, 244 ) 245 if auto_mixed_precision_mode == 'cuda': 246 rewrite_config.auto_mixed_precision = rewriter_config_pb2.RewriterConfig.ON 247 elif auto_mixed_precision_mode == 'mkl': 248 rewrite_config.auto_mixed_precision_onednn_bfloat16 = ( 249 rewriter_config_pb2.RewriterConfig.ON) 250 else: 251 assert auto_mixed_precision_mode is None 252 rewrite_config.min_graph_nodes = -1 253 graph_options = config_pb2.GraphOptions( 254 rewrite_options=rewrite_config, build_cost_model=1) 255 config = config_pb2.ConfigProto(graph_options=graph_options) 256 config.graph_options.optimizer_options.opt_level = -1 257 return config 258 259 260def _get_device(auto_mixed_precision_mode): 261 """Returns the device to run on. If mode is mkl, run on CPU""" 262 if auto_mixed_precision_mode == 'mkl': 263 return '/cpu:0' 264 else: 265 return '' 266 267 268def _is_cast_to_fp16(node_name): 269 return node_name.endswith('-CastToFp16-AutoMixedPrecision') 270 271 272def _is_cast_to_bf16(node_name): 273 return node_name.endswith('-CastToBf16-AutoMixedPrecision') 274 275 276def _is_cast_to_fp32(node_name): 277 return node_name.endswith('-CastToFp32-AutoMixedPrecision') 278 279 280def _count_casts(mode, nodes): 281 """Counts the number of casts to f16 and fp32.""" 282 num_to_fp16 = 0 283 num_to_bf16 = 0 284 num_to_fp32 = 0 285 for node in nodes: 286 if _is_cast_to_fp16(node.name): 287 num_to_fp16 += 1 288 if _is_cast_to_bf16(node.name): 289 num_to_bf16 += 1 290 elif _is_cast_to_fp32(node.name): 291 num_to_fp32 += 1 292 if mode == 'cuda': 293 assert num_to_bf16 == 0 294 return num_to_fp16, num_to_fp32 295 else: 296 assert mode == 'mkl' 297 assert num_to_fp16 == 0 298 return num_to_bf16, num_to_fp32 299 300 301def _build_node_map(nodes): 302 node_map = {} 303 for node in nodes: 304 node_map[node.name] = node 305 return node_map 306 307 308def _example_noninlined_funcdef_shape(op): 309 return [op.inputs[0].shape] 310 311 312@function.Defun( 313 shape_func=_example_noninlined_funcdef_shape, 314 func_name='example_noninlined_funcdef_grad', 315 noinline=True) 316def _example_noninlined_funcdef_grad(features, grad): 317 """Gradient of Swish function defined below.""" 318 sigmoid_features = math_ops.sigmoid(features) 319 activation_grad = ( 320 sigmoid_features * (1.0 + features * (1.0 - sigmoid_features))) 321 return grad * activation_grad 322 323 324@function.Defun( 325 grad_func=_example_noninlined_funcdef_grad, 326 shape_func=_example_noninlined_funcdef_shape, 327 func_name='example_noninlined_funcdef', 328 noinline=True) 329def _example_noninlined_funcdef(features): 330 """Computes the Swish activation function: `x * sigmoid(x)`.""" 331 return features * math_ops.sigmoid(features) 332 333 334class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase): 335 """Tests the Grappler auto mixed precision optimizer.""" 336 IGNORE_PERF_VAR = 'TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE' 337 338 # TODO(benbarsdell): Add tests for eager mode with a tf.function. 339 340 def setUp(self): 341 super(AutoMixedPrecisionTest, self).setUp() 342 # Enable the CUDA tests to be run on pre-Volta GPUs by telling the grappler 343 # pass to ignore performance and always transform the graph. 344 self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR) 345 os.environ[self.IGNORE_PERF_VAR] = '1' 346 347 def tearDown(self): 348 if self._original_ignore_perf_value is not None: 349 os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value 350 else: 351 del os.environ[self.IGNORE_PERF_VAR] 352 super(AutoMixedPrecisionTest, self).tearDown() 353 354 def _lower_precision_dtype(self, mode): 355 return dtypes.float16 if mode == 'cuda' else dtypes.bfloat16 356 357 def _assert_output_f16(self, mode, node_map, node_name, output_port=0): 358 self.assertEqual(node_map[node_name].output_info[output_port].dtype, 359 self._lower_precision_dtype(mode).as_datatype_enum) 360 361 def _run(self, mode, fetches): 362 """Runs the graph and returns the evaluation of the fetches.""" 363 with session.Session(config=_get_config(None)) as sess: 364 sess.run(variables.global_variables_initializer()) 365 output_val_ref = self.evaluate(fetches) 366 367 with session.Session(config=_get_config(mode)) as sess: 368 sess.run(variables.global_variables_initializer()) 369 metadata = config_pb2.RunMetadata() 370 output_val = sess.run(fetches, run_metadata=metadata) 371 372 return output_val_ref, output_val, metadata.cost_graph 373 374 def _maybe_skip(self, mode): 375 if mode == 'cuda' and not test.is_gpu_available(cuda_only=True): 376 self.skipTest('No GPU is available') 377 if mode == 'mkl' and not test_util.IsMklEnabled(): 378 self.skipTest('MKL is not enabled') 379 # Test will fail on machines without AVX512f, e.g., Broadwell 380 isAVX512f = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU() 381 if mode == 'mkl' and not isAVX512f: 382 self.skipTest('Skipping test due to non-AVX512f machine') 383 384 def _run_simple_loop_test(self, mode, inp, body, out): 385 """Runs a test of a simple loop. 386 387 The loop has different node colors in different sections of the graph. The 388 arguments must be strings where each character represents the color of a 389 node in that section of the graph: w = allow, g = infer, c = clear, 390 b = deny. CAPITALIZED characters indicate that the node is expected to be 391 changed to DT_HALF during graph optimization. 392 393 inp -> loop [ body ] -> out. 394 395 Args: 396 mode: Either 'cuda' or 'mkl'. 397 inp: A string of letters indicating the colors and expected dtypes of the 398 input nodes. 399 body: A string of letters indicating the colors and expected dtypes of the 400 body nodes. 401 out: A string of letters indicating the colors and expected dtypes of the 402 output nodes. 403 """ 404 self._maybe_skip(mode) 405 with ops.device(_get_device(mode)): 406 random_seed.set_random_seed(0) 407 expected_types = [] 408 for section in [inp, body, out]: 409 section_expected_types = [] 410 for color in section: 411 if color.isupper(): 412 expected_type = self._lower_precision_dtype(mode).as_datatype_enum 413 else: 414 expected_type = types_pb2.DT_FLOAT 415 section_expected_types.append(expected_type) 416 expected_types.append(section_expected_types) 417 a = _build_simple_loop_graph(inp, body, out) 418 output_val_ref, output_val, cost_graph = self._run(mode, a) 419 node_map = _build_node_map(cost_graph.node) 420 421 section_names = ['input', 'while/body', 'output'] 422 all_types_correct = True 423 for section_name, expected_types in zip(section_names, expected_types): 424 for i, expected_type in enumerate(expected_types): 425 node_name = section_name + '_%i' % i 426 output_port = 0 427 optimized_type = node_map[node_name].output_info[output_port].dtype 428 if optimized_type != expected_type: 429 print('Expected node %s to have type %s but got type %s' % 430 (node_name, expected_type, optimized_type)) 431 all_types_correct = False 432 self.assertTrue(all_types_correct) 433 if mode == 'mkl': 434 self.assertAllClose(output_val_ref, output_val, atol=2e-2, rtol=2e-2) 435 else: 436 self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=1e-3) 437 438 @parameterized.parameters(['cuda', 'mkl']) 439 @test_util.run_deprecated_v1 440 @test_util.disable_xla('This test does not pass with XLA') 441 def test_conv_bn(self, mode): 442 """Test graph with convolution followed by batch norm.""" 443 self._maybe_skip(mode) 444 with ops.device(_get_device(mode)): 445 random_seed.set_random_seed(0) 446 x = _input([2, 8, 8, 1]) 447 x = _conv_bn(x) 448 output = _conv_bn(x) 449 450 output_val_ref, output_val, cost_graph = self._run(mode, output) 451 node_map = _build_node_map(cost_graph.node) 452 num_to_f16, num_to_fp32 = _count_casts(mode, cost_graph.node) 453 454 self._assert_output_f16(mode, node_map, 'Conv2D') 455 self._assert_output_f16(mode, node_map, 'FusedBatchNormV3') 456 self._assert_output_f16(mode, node_map, 'Conv2D_1') 457 self.assertEqual(num_to_f16, 3) # Before Conv2D:0, Conv2D:1, Conv2D_1:1 458 self.assertEqual(num_to_fp32, 1) # After FusedBatchNormV3:0 459 if mode == 'mkl': 460 tol = 1e-2 461 elif test.is_built_with_rocm(): 462 # Bump up the tolerance for the ROCm platform 463 # The default tolerance (1e-3) results in a tiny fraction (<1%) of 464 # miscompares on ROCm platform, and hence the tolerance bump 465 tol = 2e-3 466 else: 467 tol = 1e-3 468 self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) 469 470 @parameterized.parameters(['cuda', 'mkl']) 471 @test_util.run_deprecated_v1 472 @test_util.disable_xla('This test does not pass with XLA') 473 def test_conv3d_bn(self, mode): 474 """Test graph with convolution followed by batch norm.""" 475 self._maybe_skip(mode) 476 if mode == 'cuda': 477 # TODO(reedwm): enable these tests when cuDNN is upgraded to >= 7.6.2. 478 self.skipTest('Test case should be skipped when cuDNN < 7.6.2') 479 with ops.device(_get_device(mode)): 480 random_seed.set_random_seed(0) 481 x = _input([2, 8, 8, 8, 1]) 482 x = _conv3d_bn(x) 483 output = _conv3d_bn(x) 484 485 output_val_ref, output_val, cost_graph = self._run(mode, output) 486 node_map = _build_node_map(cost_graph.node) 487 num_to_fp16, num_to_fp32 = _count_casts(mode, cost_graph.node) 488 489 self._assert_output_f16(mode, node_map, 'Conv3D') 490 self._assert_output_f16(mode, node_map, 'FusedBatchNormV3') 491 self._assert_output_f16(mode, node_map, 'Conv3D_1') 492 self.assertEqual(num_to_fp16, 3) # Before Conv3D:0, Conv3D:1, Conv3D_1:1 493 self.assertEqual(num_to_fp32, 1) # After FusedBatchNormV3:0 494 self.assertAllClose(output_val_ref, output_val, atol=1e-2, rtol=1e-2) 495 496 @parameterized.parameters(['cuda', 'mkl']) 497 @test_util.run_deprecated_v1 498 @test_util.disable_xla('This test does not pass with XLA') 499 def test_conv3d(self, mode): 500 """Test grad ops with convolution3d graph.""" 501 self._maybe_skip(mode) 502 if mode == 'cuda': 503 # TODO(reedwm): enable these tests when cuDNN is upgraded to >= 7.6.2. 504 self.skipTest('Test case should be skipped when cuDNN < 7.6.2') 505 with ops.device(_get_device(mode)): 506 random_seed.set_random_seed(0) 507 x = _input([2, 8, 8, 8, 1]) 508 f = _weight([3, 3, 3, 1, 6]) 509 y = _conv3d(x, f) 510 y = array_ops.identity(y) 511 optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) 512 g = optimizer.compute_gradients(y, [x, f]) 513 output = (y, g) 514 515 output_val_ref, output_val, cost_graph = self._run(mode, output) 516 node_map = _build_node_map(cost_graph.node) 517 self._assert_output_f16(mode, node_map, 'Conv3D') 518 self._assert_output_f16(mode, node_map, 519 'gradients/Conv3D_grad/Conv3DBackpropInputV2') 520 self._assert_output_f16(mode, node_map, 521 'gradients/Conv3D_grad/Conv3DBackpropFilterV2') 522 523 output_val_ref, output_val, cost_graph = self._run(mode, output) 524 tol = 5e-2 if mode == 'mkl' else 1e-3 525 self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) 526 527 @parameterized.parameters(['cuda', 'mkl']) 528 @test_util.run_deprecated_v1 529 @test_util.disable_xla('This test does not pass with XLA') 530 def test_conv_bn_dropout(self, mode): 531 """Test dropout precision of convolution batch norm graph.""" 532 self._maybe_skip(mode) 533 with ops.device(_get_device(mode)): 534 random_seed.set_random_seed(0) 535 x = _input([2, 8, 8, 1]) 536 y = _conv_bn(x) 537 y = nn.dropout(y, rate=0.5) 538 y = math_ops.add(y, 1, name='addition') 539 y = _conv_bn(y) 540 y = array_ops.identity(y) 541 optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) 542 g = optimizer.compute_gradients(y, [x]) 543 output = (y, g) 544 545 output_val_ref, output_val, cost_graph = self._run(mode, output) 546 node_map = _build_node_map(cost_graph.node) 547 self._assert_output_f16(mode, node_map, 'Conv2D') 548 self._assert_output_f16(mode, node_map, 'FusedBatchNormV3') 549 # We do not assert dropout's dtype because we do not want to rely on the 550 # node names of dropout's internal implementation. 551 self._assert_output_f16(mode, node_map, 'addition') 552 self._assert_output_f16(mode, node_map, 'Conv2D_1') 553 554 output_val_ref, output_val, cost_graph = self._run(mode, output) 555 # Bump up the tolerance for the ROCm platform 556 # The default tolerance (1e-3) results in a tiny fraction (<1%) of 557 # miscompares on ROCm platform, and hence the tolerance bump 558 tol = 2e-3 if test.is_built_with_rocm else 1e-3 559 tol = 5e-2 if mode == 'mkl' else tol 560 self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) 561 562 # TODO(reedwm): Fix and enable this test with MKL. Currently this crashes with 563 # MKL 564 @parameterized.parameters(['cuda']) 565 @test_util.run_deprecated_v1 566 @test_util.disable_xla('This test does not pass with XLA') 567 def test_conv_pool(self, mode): 568 """Test graph with convolution followed by pooling.""" 569 self._maybe_skip(mode) 570 with ops.device(_get_device(mode)): 571 random_seed.set_random_seed(0) 572 x = _input([2, 8, 8, 1]) 573 output = _conv_pool(x) 574 575 output_val_ref, output_val, cost_graph = self._run(mode, output) 576 node_map = _build_node_map(cost_graph.node) 577 num_to_f16, num_to_fp32 = _count_casts(mode, cost_graph.node) 578 579 self._assert_output_f16(mode, node_map, 'Conv2D') 580 self._assert_output_f16(mode, node_map, 'Relu') 581 self._assert_output_f16(mode, node_map, 'MaxPool') 582 self._assert_output_f16(mode, node_map, 'Conv2D_1') 583 self.assertEqual(num_to_f16, 4) 584 self.assertEqual(num_to_fp32, 1) 585 tol = 5e-3 if mode == 'mkl' else 1e-3 586 self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) 587 588 # TODO(benbarsdell): This test has not been tried with MKL. 589 @parameterized.parameters(['cuda']) 590 @test_util.run_deprecated_v1 591 @test_util.disable_xla('This test does not pass with XLA') 592 def test_depthwise_conv2d(self, mode): 593 """Test grad ops with depthwise convolution2d graph.""" 594 self._maybe_skip(mode) 595 cudnn_version_str = sysconfig.get_build_info().get('cudnn_version', '0.0') 596 cudnn_version = tuple([int(x) for x in cudnn_version_str.split('.')]) 597 if cudnn_version < (8,): 598 # Depthwise conv2d ops are only enabled in auto_mixed_precision as of 599 # cuDNN v8. 600 self.skipTest('cuDNN version >= 8 required') 601 with ops.device(_get_device(mode)): 602 random_seed.set_random_seed(0) 603 x = _input([2, 8, 8, 1]) 604 f = _weight([3, 3, 1, 4]) 605 y = _depthwise_conv2d(x, f) 606 y = array_ops.identity(y) 607 optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) 608 g = optimizer.compute_gradients(y, [x, f]) 609 output = (y, g) 610 611 output_val_ref, output_val, cost_graph = self._run(mode, output) 612 node_map = _build_node_map(cost_graph.node) 613 self._assert_output_f16(mode, node_map, 'depthwise') 614 self._assert_output_f16( 615 mode, node_map, 616 'gradients/depthwise_grad/DepthwiseConv2dNativeBackpropInput') 617 self._assert_output_f16( 618 mode, node_map, 619 'gradients/depthwise_grad/DepthwiseConv2dNativeBackpropFilter') 620 621 output_val_ref, output_val, cost_graph = self._run(mode, output) 622 tol = 2e-3 623 self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) 624 625 @parameterized.parameters(['cuda', 'mkl']) 626 @test_util.run_v1_only('b/138749235') 627 @test_util.disable_xla('This test does not pass with XLA') 628 def test_simple_loop(self, mode): 629 """Test graph with while loop.""" 630 self._maybe_skip(mode) 631 with ops.device(_get_device(mode)): 632 random_seed.set_random_seed(0) 633 x = _input([8, 8]) 634 y = _simple_loop(x, _matmul_act)[1] 635 optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) 636 g = optimizer.compute_gradients(y, [x]) 637 output = (y, g) 638 639 output_val_ref, output_val, cost_graph = self._run(mode, output) 640 node_map = _build_node_map(cost_graph.node) 641 642 self._assert_output_f16(mode, node_map, 'while/MatMul') 643 self._assert_output_f16(mode, node_map, 'while/Relu') 644 tol = 1e-2 if mode == 'mkl' else 1e-3 645 self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) 646 647 @parameterized.parameters(['cuda', 'mkl']) 648 @test_util.run_v1_only('b/138749235') 649 @test_util.disable_xla('This test does not pass with XLA') 650 def test_loop_with_vars_intertwined(self, mode): 651 """Test graph with intertwined while loops.""" 652 self._maybe_skip(mode) 653 with ops.device(_get_device(mode)): 654 random_seed.set_random_seed(0) 655 x = _input([8, 8]) 656 _, _, k, l = _loop_vars_intertwined( 657 array_ops.ones(array_ops.shape(x)), x, _matmul_act, _matmul_act) 658 optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) 659 g = optimizer.compute_gradients(k, [x]) 660 output = (k, l, g) 661 662 output_val_ref, output_val, cost_graph = self._run(mode, output) 663 node_map = _build_node_map(cost_graph.node) 664 665 self._assert_output_f16(mode, node_map, 'while/MatMul') 666 self._assert_output_f16(mode, node_map, 'while/Relu') 667 self._assert_output_f16(mode, node_map, 'while/MatMul_1') 668 self._assert_output_f16(mode, node_map, 'while/Relu_1') 669 tol = 5e-3 if mode == 'mkl' else 1e-3 670 self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) 671 672 @parameterized.parameters(['cuda']) 673 @test_util.run_deprecated_v1 674 @test_util.disable_xla('This test does not pass with XLA') 675 def test_multi_paths(self, mode): 676 """Test graph with multiple paths.""" 677 self._maybe_skip(mode) 678 with ops.device(_get_device(mode)): 679 random_seed.set_random_seed(0) 680 x = _input([2, 8, 8, 3]) 681 x1, x2, x3 = array_ops.split(x, num_or_size_splits=3, axis=3) 682 y1 = _conv_pool(x1) 683 y2 = _conv_pool(x2) 684 y3 = _conv_pool(x3) 685 y = array_ops.concat([y1, y2, y3], axis=3) 686 y = array_ops.identity(y) 687 optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) 688 g = optimizer.compute_gradients(y, [x]) 689 output = (y, g) 690 691 output_val_ref, output_val, cost_graph = self._run(mode, output) 692 node_map = _build_node_map(cost_graph.node) 693 694 self._assert_output_f16(mode, node_map, 'split') 695 for suffix in [''] + ['_%i' % i for i in range(1, 6)]: 696 self._assert_output_f16(mode, node_map, 'Conv2D' + suffix) 697 self._assert_output_f16(mode, node_map, 'Relu' + suffix) 698 self._assert_output_f16(mode, node_map, 'MaxPool' + suffix) 699 self._assert_output_f16(mode, node_map, 'concat') 700 atol = 1e-2 if test.is_built_with_rocm() else 1e-3 701 self.assertAllClose(output_val_ref, output_val, atol=atol, rtol=1e-3) 702 703 @parameterized.parameters(['cuda', 'mkl']) 704 @test_util.run_deprecated_v1 705 @test_util.disable_xla('This test does not pass with XLA') 706 def test_multi_paths_2(self, mode): 707 """Test graph with multiple paths.""" 708 self._maybe_skip(mode) 709 with ops.device(_get_device(mode)): 710 random_seed.set_random_seed(0) 711 x = _input([8, 8]) 712 y1 = _matmul_act(x) 713 y2 = _matmul_act(x) 714 y = y1 + y2 + x 715 optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) 716 g = optimizer.compute_gradients(y, [x]) 717 output = (g, y) 718 719 output_val_ref, output_val, cost_graph = self._run(mode, output) 720 node_map = _build_node_map(cost_graph.node) 721 722 self._assert_output_f16(mode, node_map, 'MatMul') 723 self._assert_output_f16(mode, node_map, 'Relu') 724 self._assert_output_f16(mode, node_map, 'MatMul_1') 725 self._assert_output_f16(mode, node_map, 'Relu_1') 726 if mode == 'mkl': 727 tol = 2e-2 728 elif test.is_built_with_rocm(): 729 # Bump up the tolerance for the ROCm platform 730 # The default tolerance (1e-3) results in a tiny fraction (<1%) of 731 # miscompares on ROCm platform, and hence the tolerance bump 732 tol = 1e-2 733 else: 734 tol = 1e-3 735 self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) 736 737 @parameterized.parameters(['cuda']) # MKL doesn't support bf16 Sigmoid 738 @test_util.run_v1_only('b/138749235') 739 @test_util.disable_xla('This test does not pass with XLA') 740 def test_recurrent_lstm(self, mode): 741 """Test graph with recurrent lstm.""" 742 self._maybe_skip(mode) 743 with ops.device(_get_device(mode)): 744 random_seed.set_random_seed(0) 745 init_c = _input([8, 4]) 746 init_h = _input([8, 4]) 747 _, _, h, _ = _recurrent_lstm(init_c, init_h) 748 optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) 749 g = optimizer.compute_gradients(h, [init_c, init_h]) 750 output = (h, g) 751 752 output_val_ref, output_val, cost_graph = self._run(mode, output) 753 node_map = _build_node_map(cost_graph.node) 754 755 self._assert_output_f16(mode, node_map, 'while/concat') 756 self._assert_output_f16(mode, node_map, 'while/MatMul') 757 self._assert_output_f16(mode, node_map, 'while/split') 758 self._assert_output_f16(mode, node_map, 'while/Sigmoid') 759 self._assert_output_f16(mode, node_map, 'while/Sigmoid_1') 760 self._assert_output_f16(mode, node_map, 'while/Sigmoid_2') 761 self._assert_output_f16(mode, node_map, 'while/Tanh') 762 self._assert_output_f16(mode, node_map, 'while/Tanh_1') 763 self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) 764 765 @parameterized.parameters(['cuda', 'mkl']) 766 @test_util.run_v1_only('v1 loop test') 767 @test_util.disable_xla('This test does not pass with XLA') 768 def test_propagation_through_simple_loop_1(self, mode): 769 self._run_simple_loop_test(mode, 'W', 'C', 'C') 770 771 @parameterized.parameters(['cuda', 'mkl']) 772 @test_util.run_v1_only('v1 loop test') 773 @test_util.disable_xla('This test does not pass with XLA') 774 def test_propagation_through_simple_loop_2(self, mode): 775 self._run_simple_loop_test(mode, 'C', 'C', 'W') 776 777 @parameterized.parameters(['cuda', 'mkl']) 778 @test_util.run_v1_only('v1 loop test') 779 @test_util.disable_xla('This test does not pass with XLA') 780 def test_propagation_through_simple_loop_3(self, mode): 781 self._run_simple_loop_test(mode, 'W', 'G', 'W') 782 783 @parameterized.parameters(['cuda', 'mkl']) 784 @test_util.run_v1_only('v1 loop test') 785 @test_util.disable_xla('This test does not pass with XLA') 786 def test_propagation_through_simple_loop_4(self, mode): 787 self._run_simple_loop_test(mode, 'W', 'gbg', 'W') 788 789 @parameterized.parameters(['cuda', 'mkl']) 790 @test_util.run_v1_only('b/138749235') 791 @test_util.disable_xla('This test does not pass with XLA') 792 def test_propagation_through_simple_loop_5(self, mode): 793 self._run_simple_loop_test(mode, 'b', 'gWC', 'c') 794 795 @parameterized.parameters(['cuda', 'mkl']) 796 @test_util.run_v1_only('b/138749235') 797 @test_util.disable_xla('This test does not pass with XLA') 798 def test_propagation_through_simple_loop_6(self, mode): 799 self._run_simple_loop_test(mode, 'b', 'CWCG', 'C') 800 801 @parameterized.parameters(['cuda', 'mkl']) 802 @test_util.run_v1_only('b/138749235') 803 @test_util.disable_xla('This test does not pass with XLA') 804 def test_propagation_through_simple_loop_7(self, mode): 805 self._run_simple_loop_test(mode, 'C', 'GWCG', 'C') 806 807 @parameterized.parameters(['cuda', 'mkl']) 808 @test_util.run_v1_only('b/138749235') 809 @test_util.disable_xla('This test does not pass with XLA') 810 def test_propagation_through_simple_loop_8(self, mode): 811 self._run_simple_loop_test(mode, 'C', 'CgbgWC', 'g') 812 813 @parameterized.parameters(['cuda', 'mkl']) 814 @test_util.run_deprecated_v1 815 @test_util.disable_xla('This test does not pass with XLA') 816 def test_noninlined_funcdef(self, mode): 817 """Test graph with non-inlined function subgraph. 818 819 This requires the grappler pass to handle an OpDef that only appears in the 820 graph's function registry instead of the global op registry. 821 822 Args: 823 mode: Either 'cuda' or 'mkl'. 824 """ 825 self._maybe_skip(mode) 826 with ops.device(_get_device(mode)): 827 random_seed.set_random_seed(0) 828 x = _input([8, 8]) 829 y = _matmul_act(x) 830 y = _example_noninlined_funcdef(y) 831 optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) 832 g = optimizer.compute_gradients(y, [x]) 833 output = (g, y) 834 835 output_val_ref, output_val, cost_graph = self._run(mode, output) 836 node_map = _build_node_map(cost_graph.node) 837 838 self._assert_output_f16(mode, node_map, 'MatMul') 839 tol = 1e-2 if mode == 'mkl' else 1e-3 840 atol = 1e-2 if test.is_built_with_rocm() else tol 841 self.assertAllClose(output_val_ref, output_val, atol=atol, rtol=tol) 842 843 @parameterized.parameters(['cuda', 'mkl']) 844 @test_util.run_deprecated_v1 845 @test_util.disable_xla('This test does not pass with XLA') 846 def test_ingraph_train_loop(self, mode): 847 """Tests a graph containing a while loop around a training update. 848 849 This requires the grappler pass to take special care with its handling of 850 Enter ops that appear in front of reads from non-resource variables. See 851 the use of NodeImplicitlyReadsVariable in auto_mixed_precision.cc. 852 853 Args: 854 mode: Either 'cuda' or 'mkl'. 855 """ 856 self._maybe_skip(mode) 857 if tf2.enabled(): 858 # This test tests non-resource variables, which are only used in TF1. 859 self.skipTest('TensorFlow 1 required') 860 with ops.device(_get_device(mode)): 861 random_seed.set_random_seed(1234) 862 np.random.seed(1234) 863 num_iter, bs, nchan, nclass = 100, 64, 32, 100 864 865 data = np.random.normal(size=(bs * num_iter, nchan)).astype(np.float32) 866 labels = np.random.randint(nclass, size=(bs * num_iter,)) 867 ds = dataset_ops.Dataset.from_tensor_slices((data, labels)) 868 ds = ds.batch(bs).prefetch(3) 869 it = ds.make_one_shot_iterator() 870 871 def body(_, i): 872 i += 1 873 x, yt = it.get_next() 874 dense = layers.Dense(nclass) 875 y = dense(x) 876 loss = losses.sparse_softmax_cross_entropy(yt, y) 877 opt = adam.AdamOptimizer() 878 train_op = opt.minimize(loss, var_list=dense.trainable_weights) 879 with ops.control_dependencies([train_op]): 880 loss = array_ops.identity(loss) 881 return loss, i 882 883 begin, end = constant_op.constant(0), constant_op.constant(num_iter) 884 loss, _ = control_flow_ops.while_loop( 885 lambda loss, i: math_ops.less(i, end), body, [0.0, begin]) 886 887 output_val_ref, output_val, cost_graph = self._run(mode, loss) 888 node_map = _build_node_map(cost_graph.node) 889 890 self._assert_output_f16(mode, node_map, 'while/dense/MatMul') 891 self._assert_output_f16(mode, node_map, 892 'while/gradients/while/dense/MatMul_grad/MatMul_1') 893 self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) 894 895 # TODO(benbarsdell): Add tests for list ops (TensorList*) that pass through 896 # graph source/sink nodes, similar to the TensorListThroughFunction C++ test. 897 # Tests here will have the advantage of catching changes in the types of ops 898 # that are added to the graph. 899 900 901if __name__ == '__main__': 902 test.main() 903