# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Grappler AutoMixedPrecision."""

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.layers import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import sysconfig
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import _pywrap_utils


def _input(shape):
  """Generates an input of a given shape."""
  return variables.Variable(random_ops.truncated_normal(shape, seed=0))


def _weight(shape):
  """Generates a weight of a given shape."""
  # Note that the lambda is needed to allow construction inside loops.
  return variables.Variable(lambda: init_ops.glorot_uniform_initializer(seed=0)
                            (shape))


def _bias(shape):
  """Generates a bias of a given shape."""
  return constant_op.constant(0.1, shape=shape)


def _conv2d(x, w):
  """Returns a 2d convolution layer with full stride."""
  return nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')


def _conv3d(x, w):
  """Returns a 3d convolution layer with full stride."""
  return nn.conv3d(x, w, strides=[1, 1, 1, 1, 1], padding='SAME')


def _max_pool_2x2(x):
  """Downsamples a feature map by 2X."""
  return nn.max_pool(
      x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


def _fused_batchnorm(x, scale, offset):
  """Batchnorm."""
  return nn_impl.fused_batch_norm(
      x, scale=scale, offset=offset, is_training=True)


def _conv_bn(x):
  """Conv followed by batchnorm."""
  i = array_ops.reshape(x, [-1, 8, 8, 1])
  f = _weight([3, 3, 1, 6])
  x = _conv2d(i, f)
  s = _weight([6])
  o = _weight([6])
  y, _, _ = _fused_batchnorm(x, s, o)
  y = array_ops.identity(y)
  return y


def _conv3d_bn(x):
  """Conv3D followed by batchnorm."""
  i = array_ops.reshape(x, [-1, 8, 8, 8, 1])
  f = _weight([3, 3, 3, 1, 6])
  x = _conv3d(i, f)
  s = _weight([6])
  o = _weight([6])
  x = array_ops.reshape(x, [-1, 8, 8, 6])
  y, _, _ = _fused_batchnorm(x, s, o)
  y = array_ops.identity(y)
  return y


def _matmul_act(x):
  """Matmul followed by activation."""
  i = array_ops.reshape(x, [8, 8])
  f = _weight([8, 8])
  x = math_ops.matmul(i, f)
  y = nn.relu(x)
  return y


def _conv_pool(x):
  """(Conv -> bias -> relu -> max_pool) x2."""
  x_image = array_ops.reshape(x, [-1, 8, 8, 1])
  w_conv1 = _weight([3, 3, 1, 6])
  b_conv1 = _bias([6])
  h_conv1 = nn.relu(nn.bias_add(_conv2d(x_image, w_conv1), b_conv1))
  h_pool1 = _max_pool_2x2(h_conv1)
  w_conv2 = _weight([3, 3, 6, 4])
  b_conv2 = _bias([4])
  h_conv2 = nn.relu(nn.bias_add(_conv2d(h_pool1, w_conv2), b_conv2))
  h_pool2 = _max_pool_2x2(h_conv2)
  return h_pool2


def _depthwise_conv2d(x, w):
  """Returns a 2d depthwise convolution layer with full stride."""
  return nn.depthwise_conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')


def _simple_loop(x, functor):
  """Simple loop whose body is provided by the functor."""
  init = (constant_op.constant(0), x)
  c = lambda i, j: i < 4
  b = lambda i, j: (i + 1, functor(j))
  ij = control_flow_ops.while_loop(c, b, init)
  return ij


def _loop_vars_intertwined(x0, y0, functor_x, functor_y):
  """Loop whose loop variables are intertwined."""
  c = lambda i, j, x, y: j < 4
  b = lambda i, j, x, y: (j + 1, i + 1, functor_y(y), functor_x(x))
  init = (constant_op.constant(0), constant_op.constant(0), x0, y0)
  ijzw = control_flow_ops.while_loop(c, b, init)
  return ijzw


def _lstm_cell(prev_c, prev_h, x):
  """Create an LSTM cell."""
  # i: input gate
  # f: forget gate
  # o: output gate
  # c: cell state
  # x: input
  # h: embedding
  bias = _bias([4])
  w = _weight([8, 16])
  ifoc = math_ops.matmul(array_ops.concat([x, prev_h], axis=1), w)
  i, f, o, c = array_ops.split(ifoc, 4, axis=1)
  i = math_ops.sigmoid(nn.bias_add(i, bias))
  f = math_ops.sigmoid(nn.bias_add(f, bias))
  o = math_ops.sigmoid(nn.bias_add(o, bias))
  c = math_ops.tanh(nn.bias_add(c, bias))
  next_c = f * prev_c + i * c
  next_h = o * math_ops.tanh(next_c)
  return next_c, next_h


def _recurrent_lstm(c, h):
  """Dynamic single-layer LSTM with TensorArray."""

  def cond(i, c, h, ta_x):
    del c, h, ta_x
    return i < 4

  def body(i, c, h, ta_x):
    x = ta_x.read(i)
    next_c, next_h = _lstm_cell(c, h, x)
    return (i + 1, next_c, next_h, ta_x)

  ta_x = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=4)
  for i in range(0, 4):
    ta_x = ta_x.write(
        i, constant_op.constant(0.1, shape=[8, 4], dtype=dtypes.float32))
  init = (constant_op.constant(0), c, h, ta_x)
  r = control_flow_ops.while_loop(cond, body, init)
  return r


def _make_node_with_color(color, input_tensor, name=None):
  """Returns a node representative of the specified list type."""
  color = color.lower()
  if color == 'w':  # Allow node
    weights = _weight(input_tensor.get_shape().as_list())
    return math_ops.matmul(input_tensor, weights, name=name)
  if color == 'g':  # Infer node
    return math_ops.add(input_tensor, 0.1, name=name)
  if color == 'c':  # Clear node
    return nn.relu(input_tensor, name=name)
  if color == 'b':  # Deny node
    return math_ops.pow(math_ops.pow(input_tensor, 2.), 0.5, name=name)
  raise ValueError('Invalid node color: ' + str(color))


def _build_simple_loop_graph(inp_colors, body_colors, out_colors):
  """Builds a test graph with a simple loop."""
  a = _input([8, 8])
  for i, color in enumerate(inp_colors):
    a = _make_node_with_color(color, a, 'input_%i' % i)

  def body(x):
    for i, color in enumerate(body_colors):
      x = _make_node_with_color(color, x, 'body_%i' % i)
    return x

  _, a = _simple_loop(a, body)
  for i, color in enumerate(out_colors):
    a = _make_node_with_color(color, a, 'output_%i' % i)
  a = array_ops.identity(a)
  return a


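# Illustrative helper, not used by any test below: a minimal sketch showing how
# the single-letter color codes map onto concrete ops when chained together.
# The node names 'example_*' are arbitrary.
def _example_colored_chain(x):
  """Builds the chain 'w' -> 'g' -> 'c', i.e. MatMul -> Add -> Relu."""
  x = _make_node_with_color('w', x, 'example_allow')  # allow-list op (MatMul)
  x = _make_node_with_color('g', x, 'example_infer')  # infer-list op (Add)
  x = _make_node_with_color('c', x, 'example_clear')  # clear-list op (Relu)
  return array_ops.identity(x)

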
def _get_config(auto_mixed_precision_mode):
  """Returns a ConfigProto with auto mixed precision enabled if appropriate."""
  rewrite_config = rewriter_config_pb2.RewriterConfig(
      # do not remove duplicated nodes
      arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      # do not turn Conv2D and other nodes into _FusedConv2D
      remapping=rewriter_config_pb2.RewriterConfig.OFF,
  )
  if auto_mixed_precision_mode == 'cuda':
    rewrite_config.auto_mixed_precision = rewriter_config_pb2.RewriterConfig.ON
  elif auto_mixed_precision_mode == 'mkl':
    rewrite_config.auto_mixed_precision_onednn_bfloat16 = (
        rewriter_config_pb2.RewriterConfig.ON)
  else:
    assert auto_mixed_precision_mode is None
  rewrite_config.min_graph_nodes = -1
  graph_options = config_pb2.GraphOptions(
      rewrite_options=rewrite_config, build_cost_model=1)
  config = config_pb2.ConfigProto(graph_options=graph_options)
  config.graph_options.optimizer_options.opt_level = -1
  return config
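
# A minimal usage sketch, not relied on by these tests: outside of a hand-built
# ConfigProto, the CUDA rewrite enabled above is normally switched on through
# the public API (with `import tensorflow as tf`), e.g.
#
#   tf.config.optimizer.set_experimental_options({'auto_mixed_precision': True})
#
# The tests build the ConfigProto directly so that they can also disable the
# arithmetic and remapping optimizers.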


def _get_device(auto_mixed_precision_mode):
  """Returns the device to run on. If the mode is 'mkl', runs on CPU."""
  if auto_mixed_precision_mode == 'mkl':
    return '/cpu:0'
  else:
    return ''


def _is_cast_to_fp16(node_name):
  return node_name.endswith('-CastToFp16-AutoMixedPrecision')


def _is_cast_to_bf16(node_name):
  return node_name.endswith('-CastToBf16-AutoMixedPrecision')


def _is_cast_to_fp32(node_name):
  return node_name.endswith('-CastToFp32-AutoMixedPrecision')


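# The grappler pass names the Cast nodes it inserts with a
# '-CastTo<Type>-AutoMixedPrecision' suffix, which the predicates above match.
# The prefix (typically the producer node and output port) is treated as an
# implementation detail here, which is why only the suffix is checked.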
def _count_casts(mode, nodes):
  """Counts the number of casts to f16 and fp32."""
  num_to_fp16 = 0
  num_to_bf16 = 0
  num_to_fp32 = 0
  for node in nodes:
    if _is_cast_to_fp16(node.name):
      num_to_fp16 += 1
    if _is_cast_to_bf16(node.name):
      num_to_bf16 += 1
    elif _is_cast_to_fp32(node.name):
      num_to_fp32 += 1
  if mode == 'cuda':
    assert num_to_bf16 == 0
    return num_to_fp16, num_to_fp32
  else:
    assert mode == 'mkl'
    assert num_to_fp16 == 0
    return num_to_bf16, num_to_fp32


def _build_node_map(nodes):
  node_map = {}
  for node in nodes:
    node_map[node.name] = node
  return node_map


def _example_noninlined_funcdef_shape(op):
  return [op.inputs[0].shape]


@function.Defun(
    shape_func=_example_noninlined_funcdef_shape,
    func_name='example_noninlined_funcdef_grad',
    noinline=True)
def _example_noninlined_funcdef_grad(features, grad):
  """Gradient of Swish function defined below."""
  sigmoid_features = math_ops.sigmoid(features)
  activation_grad = (
      sigmoid_features * (1.0 + features * (1.0 - sigmoid_features)))
  return grad * activation_grad


@function.Defun(
    grad_func=_example_noninlined_funcdef_grad,
    shape_func=_example_noninlined_funcdef_shape,
    func_name='example_noninlined_funcdef',
    noinline=True)
def _example_noninlined_funcdef(features):
  """Computes the Swish activation function: `x * sigmoid(x)`."""
  return features * math_ops.sigmoid(features)


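# Illustrative usage, not invoked at module level: once called on a tensor,
# e.g.
#
#   y = _example_noninlined_funcdef(constant_op.constant([1.0, 2.0]))
#
# the Defun above appears in the GraphDef as a single function-call node
# because it is built with noinline=True, which is what exercises the
# function-registry handling described in test_noninlined_funcdef below.

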
class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
  """Tests the Grappler auto mixed precision optimizer."""
  IGNORE_PERF_VAR = 'TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'

  # TODO(benbarsdell): Add tests for eager mode with a tf.function.

  def setUp(self):
    super(AutoMixedPrecisionTest, self).setUp()
    # Enable the CUDA tests to be run on pre-Volta GPUs by telling the grappler
    # pass to ignore performance and always transform the graph.
    self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR)
    os.environ[self.IGNORE_PERF_VAR] = '1'

  def tearDown(self):
    if self._original_ignore_perf_value is not None:
      os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value
    else:
      del os.environ[self.IGNORE_PERF_VAR]
    super(AutoMixedPrecisionTest, self).tearDown()

  def _lower_precision_dtype(self, mode):
    return dtypes.float16 if mode == 'cuda' else dtypes.bfloat16

  def _assert_output_f16(self, mode, node_map, node_name, output_port=0):
    self.assertEqual(node_map[node_name].output_info[output_port].dtype,
                     self._lower_precision_dtype(mode).as_datatype_enum)

  def _run(self, mode, fetches):
    """Runs the graph and returns the evaluation of the fetches."""
    with session.Session(config=_get_config(None)) as sess:
      sess.run(variables.global_variables_initializer())
      output_val_ref = self.evaluate(fetches)

    with session.Session(config=_get_config(mode)) as sess:
      sess.run(variables.global_variables_initializer())
      metadata = config_pb2.RunMetadata()
      output_val = sess.run(fetches, run_metadata=metadata)

    return output_val_ref, output_val, metadata.cost_graph

  def _maybe_skip(self, mode):
    if mode == 'cuda' and not test.is_gpu_available(cuda_only=True):
      self.skipTest('No GPU is available')
    if mode == 'mkl' and not test_util.IsMklEnabled():
      self.skipTest('MKL is not enabled')
    # Test will fail on machines without AVX512f, e.g., Broadwell
    isAVX512f = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()
    if mode == 'mkl' and not isAVX512f:
      self.skipTest('Skipping test due to non-AVX512f machine')

  def _run_simple_loop_test(self, mode, inp, body, out):
    """Runs a test of a simple loop.

    The loop has different node colors in different sections of the graph. The
    arguments must be strings where each character represents the color of a
    node in that section of the graph: w = allow, g = infer, c = clear,
    b = deny. CAPITALIZED characters indicate that the node is expected to be
    changed to the reduced-precision type (DT_HALF on CUDA, DT_BFLOAT16 with
    oneDNN) during graph optimization.

    inp -> loop [ body ] -> out.

    Args:
      mode: Either 'cuda' or 'mkl'.
      inp: A string of letters indicating the colors and expected dtypes of the
        input nodes.
      body: A string of letters indicating the colors and expected dtypes of the
        body nodes.
      out: A string of letters indicating the colors and expected dtypes of the
        output nodes.
    """
    self._maybe_skip(mode)
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      expected_types = []
      for section in [inp, body, out]:
        section_expected_types = []
        for color in section:
          if color.isupper():
            expected_type = self._lower_precision_dtype(mode).as_datatype_enum
          else:
            expected_type = types_pb2.DT_FLOAT
          section_expected_types.append(expected_type)
        expected_types.append(section_expected_types)
      a = _build_simple_loop_graph(inp, body, out)
    output_val_ref, output_val, cost_graph = self._run(mode, a)
    node_map = _build_node_map(cost_graph.node)

    section_names = ['input', 'while/body', 'output']
    all_types_correct = True
    for section_name, expected_types in zip(section_names, expected_types):
      for i, expected_type in enumerate(expected_types):
        node_name = section_name + '_%i' % i
        output_port = 0
        optimized_type = node_map[node_name].output_info[output_port].dtype
        if optimized_type != expected_type:
          print('Expected node %s to have type %s but got type %s' %
                (node_name, expected_type, optimized_type))
          all_types_correct = False
    self.assertTrue(all_types_correct)
    if mode == 'mkl':
      self.assertAllClose(output_val_ref, output_val, atol=2e-2, rtol=2e-2)
    else:
      self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=1e-3)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_conv_bn(self, mode):
    """Test graph with convolution followed by batch norm."""
    self._maybe_skip(mode)
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 1])
      x = _conv_bn(x)
      output = _conv_bn(x)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)
    num_to_f16, num_to_fp32 = _count_casts(mode, cost_graph.node)

    self._assert_output_f16(mode, node_map, 'Conv2D')
    self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
    self._assert_output_f16(mode, node_map, 'Conv2D_1')
    self.assertEqual(num_to_f16, 3)  # Before Conv2D:0, Conv2D:1, Conv2D_1:1
    self.assertEqual(num_to_fp32, 1)  # After FusedBatchNormV3:0
    if mode == 'mkl':
      tol = 1e-2
    elif test.is_built_with_rocm():
      # Bump up the tolerance for the ROCm platform
      # The default tolerance (1e-3) results in a tiny fraction (<1%) of
      # miscompares on ROCm platform, and hence the tolerance bump
      tol = 2e-3
    else:
      tol = 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_conv3d_bn(self, mode):
474    """Test graph with convolution followed by batch norm."""
    self._maybe_skip(mode)
    if mode == 'cuda':
      # TODO(reedwm): enable these tests when cuDNN is upgraded to >= 7.6.2.
      self.skipTest('Test case should be skipped when cuDNN < 7.6.2')
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 8, 1])
      x = _conv3d_bn(x)
      output = _conv3d_bn(x)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)
    num_to_fp16, num_to_fp32 = _count_casts(mode, cost_graph.node)

    self._assert_output_f16(mode, node_map, 'Conv3D')
    self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
    self._assert_output_f16(mode, node_map, 'Conv3D_1')
    self.assertEqual(num_to_fp16, 3)  # Before Conv3D:0, Conv3D:1, Conv3D_1:1
    self.assertEqual(num_to_fp32, 1)  # After FusedBatchNormV3:0
    self.assertAllClose(output_val_ref, output_val, atol=1e-2, rtol=1e-2)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_conv3d(self, mode):
500    """Test grad ops with convolution3d graph."""
    self._maybe_skip(mode)
    if mode == 'cuda':
      # TODO(reedwm): enable these tests when cuDNN is upgraded to >= 7.6.2.
      self.skipTest('Test case should be skipped when cuDNN < 7.6.2')
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 8, 1])
      f = _weight([3, 3, 3, 1, 6])
      y = _conv3d(x, f)
      y = array_ops.identity(y)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x, f])
      output = (y, g)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)
    self._assert_output_f16(mode, node_map, 'Conv3D')
    self._assert_output_f16(mode, node_map,
                            'gradients/Conv3D_grad/Conv3DBackpropInputV2')
    self._assert_output_f16(mode, node_map,
                            'gradients/Conv3D_grad/Conv3DBackpropFilterV2')

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    tol = 5e-2 if mode == 'mkl' else 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_conv_bn_dropout(self, mode):
    """Test dropout precision of convolution batch norm graph."""
    self._maybe_skip(mode)
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 1])
      y = _conv_bn(x)
      y = nn.dropout(y, rate=0.5)
      y = math_ops.add(y, 1, name='addition')
      y = _conv_bn(y)
      y = array_ops.identity(y)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x])
      output = (y, g)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)
    self._assert_output_f16(mode, node_map, 'Conv2D')
    self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
    # We do not assert dropout's dtype because we do not want to rely on the
    # node names of dropout's internal implementation.
    self._assert_output_f16(mode, node_map, 'addition')
    self._assert_output_f16(mode, node_map, 'Conv2D_1')

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    # Bump up the tolerance for the ROCm platform
    # The default tolerance (1e-3) results in a tiny fraction (<1%) of
    # miscompares on ROCm platform, and hence the tolerance bump
    tol = 2e-3 if test.is_built_with_rocm() else 1e-3
    tol = 5e-2 if mode == 'mkl' else tol
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  # TODO(reedwm): Fix and enable this test with MKL. Currently this crashes
  # with MKL.
  @parameterized.parameters(['cuda'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_conv_pool(self, mode):
    """Test graph with convolution followed by pooling."""
    self._maybe_skip(mode)
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 1])
      output = _conv_pool(x)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)
    num_to_f16, num_to_fp32 = _count_casts(mode, cost_graph.node)

    self._assert_output_f16(mode, node_map, 'Conv2D')
    self._assert_output_f16(mode, node_map, 'Relu')
    self._assert_output_f16(mode, node_map, 'MaxPool')
    self._assert_output_f16(mode, node_map, 'Conv2D_1')
    self.assertEqual(num_to_f16, 4)
    self.assertEqual(num_to_fp32, 1)
    tol = 5e-3 if mode == 'mkl' else 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  # TODO(benbarsdell): This test has not been tried with MKL.
  @parameterized.parameters(['cuda'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_depthwise_conv2d(self, mode):
593    """Test grad ops with depthwise convolution2d graph."""
    self._maybe_skip(mode)
    cudnn_version_str = sysconfig.get_build_info().get('cudnn_version', '0.0')
    cudnn_version = tuple([int(x) for x in cudnn_version_str.split('.')])
    if cudnn_version < (8,):
      # Depthwise conv2d ops are only enabled in auto_mixed_precision as of
      # cuDNN v8.
      self.skipTest('cuDNN version >= 8 required')
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 1])
      f = _weight([3, 3, 1, 4])
      y = _depthwise_conv2d(x, f)
      y = array_ops.identity(y)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x, f])
      output = (y, g)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)
    self._assert_output_f16(mode, node_map, 'depthwise')
    self._assert_output_f16(
        mode, node_map,
        'gradients/depthwise_grad/DepthwiseConv2dNativeBackpropInput')
    self._assert_output_f16(
        mode, node_map,
        'gradients/depthwise_grad/DepthwiseConv2dNativeBackpropFilter')

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    tol = 2e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_simple_loop(self, mode):
    """Test graph with while loop."""
    self._maybe_skip(mode)
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      x = _input([8, 8])
      y = _simple_loop(x, _matmul_act)[1]
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x])
      output = (y, g)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

    self._assert_output_f16(mode, node_map, 'while/MatMul')
    self._assert_output_f16(mode, node_map, 'while/Relu')
    tol = 1e-2 if mode == 'mkl' else 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_loop_with_vars_intertwined(self, mode):
    """Test graph with intertwined while loops."""
    self._maybe_skip(mode)
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      x = _input([8, 8])
      _, _, k, l = _loop_vars_intertwined(
          array_ops.ones(array_ops.shape(x)), x, _matmul_act, _matmul_act)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(k, [x])
      output = (k, l, g)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

    self._assert_output_f16(mode, node_map, 'while/MatMul')
    self._assert_output_f16(mode, node_map, 'while/Relu')
    self._assert_output_f16(mode, node_map, 'while/MatMul_1')
    self._assert_output_f16(mode, node_map, 'while/Relu_1')
    tol = 5e-3 if mode == 'mkl' else 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  @parameterized.parameters(['cuda'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_multi_paths(self, mode):
    """Test graph with multiple paths."""
    self._maybe_skip(mode)
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 3])
      x1, x2, x3 = array_ops.split(x, num_or_size_splits=3, axis=3)
      y1 = _conv_pool(x1)
      y2 = _conv_pool(x2)
      y3 = _conv_pool(x3)
      y = array_ops.concat([y1, y2, y3], axis=3)
      y = array_ops.identity(y)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x])
      output = (y, g)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

    self._assert_output_f16(mode, node_map, 'split')
    for suffix in [''] + ['_%i' % i for i in range(1, 6)]:
      self._assert_output_f16(mode, node_map, 'Conv2D' + suffix)
      self._assert_output_f16(mode, node_map, 'Relu' + suffix)
      self._assert_output_f16(mode, node_map, 'MaxPool' + suffix)
    self._assert_output_f16(mode, node_map, 'concat')
    atol = 1e-2 if test.is_built_with_rocm() else 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=atol, rtol=1e-3)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_multi_paths_2(self, mode):
    """Test graph with multiple paths."""
    self._maybe_skip(mode)
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      x = _input([8, 8])
      y1 = _matmul_act(x)
      y2 = _matmul_act(x)
      y = y1 + y2 + x
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x])
      output = (g, y)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

    self._assert_output_f16(mode, node_map, 'MatMul')
    self._assert_output_f16(mode, node_map, 'Relu')
    self._assert_output_f16(mode, node_map, 'MatMul_1')
    self._assert_output_f16(mode, node_map, 'Relu_1')
    if mode == 'mkl':
      tol = 2e-2
    elif test.is_built_with_rocm():
      # Bump up the tolerance for the ROCm platform
      # The default tolerance (1e-3) results in a tiny fraction (<1%) of
      # miscompares on ROCm platform, and hence the tolerance bump
      tol = 1e-2
    else:
      tol = 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  @parameterized.parameters(['cuda'])  # MKL doesn't support bf16 Sigmoid
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_recurrent_lstm(self, mode):
    """Test graph with recurrent lstm."""
    self._maybe_skip(mode)
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      init_c = _input([8, 4])
      init_h = _input([8, 4])
      _, _, h, _ = _recurrent_lstm(init_c, init_h)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(h, [init_c, init_h])
      output = (h, g)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

    self._assert_output_f16(mode, node_map, 'while/concat')
    self._assert_output_f16(mode, node_map, 'while/MatMul')
    self._assert_output_f16(mode, node_map, 'while/split')
    self._assert_output_f16(mode, node_map, 'while/Sigmoid')
    self._assert_output_f16(mode, node_map, 'while/Sigmoid_1')
    self._assert_output_f16(mode, node_map, 'while/Sigmoid_2')
    self._assert_output_f16(mode, node_map, 'while/Tanh')
    self._assert_output_f16(mode, node_map, 'while/Tanh_1')
    self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)

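  # The tests below exercise dtype propagation through while loops using the
  # color-string encoding documented in _run_simple_loop_test. For example,
  # ('W', 'gbg', 'W') builds a MatMul input node, an Add, a Pow-based deny
  # node, and another Add in the loop body, and a MatMul output node, and
  # expects only the two MatMul nodes (the capitalized letters) to end up in
  # the reduced-precision type.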
  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('v1 loop test')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_1(self, mode):
    self._run_simple_loop_test(mode, 'W', 'C', 'C')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('v1 loop test')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_2(self, mode):
    self._run_simple_loop_test(mode, 'C', 'C', 'W')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('v1 loop test')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_3(self, mode):
    self._run_simple_loop_test(mode, 'W', 'G', 'W')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('v1 loop test')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_4(self, mode):
    self._run_simple_loop_test(mode, 'W', 'gbg', 'W')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_5(self, mode):
    self._run_simple_loop_test(mode, 'b', 'gWC', 'c')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_6(self, mode):
    self._run_simple_loop_test(mode, 'b', 'CWCG', 'C')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_7(self, mode):
    self._run_simple_loop_test(mode, 'C', 'GWCG', 'C')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_8(self, mode):
    self._run_simple_loop_test(mode, 'C', 'CgbgWC', 'g')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_noninlined_funcdef(self, mode):
    """Test graph with non-inlined function subgraph.

    This requires the grappler pass to handle an OpDef that only appears in the
    graph's function registry instead of the global op registry.

    Args:
      mode: Either 'cuda' or 'mkl'.
    """
    self._maybe_skip(mode)
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(0)
      x = _input([8, 8])
      y = _matmul_act(x)
      y = _example_noninlined_funcdef(y)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x])
      output = (g, y)

    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

    self._assert_output_f16(mode, node_map, 'MatMul')
    tol = 1e-2 if mode == 'mkl' else 1e-3
    atol = 1e-2 if test.is_built_with_rocm() else tol
    self.assertAllClose(output_val_ref, output_val, atol=atol, rtol=tol)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_ingraph_train_loop(self, mode):
    """Tests a graph containing a while loop around a training update.

    This requires the grappler pass to take special care with its handling of
    Enter ops that appear in front of reads from non-resource variables. See
    the use of NodeImplicitlyReadsVariable in auto_mixed_precision.cc.

    Args:
      mode: Either 'cuda' or 'mkl'.
    """
    self._maybe_skip(mode)
    if tf2.enabled():
      # This test tests non-resource variables, which are only used in TF1.
      self.skipTest('TensorFlow 1 required')
    with ops.device(_get_device(mode)):
      random_seed.set_random_seed(1234)
      np.random.seed(1234)
      num_iter, bs, nchan, nclass = 100, 64, 32, 100

      data = np.random.normal(size=(bs * num_iter, nchan)).astype(np.float32)
      labels = np.random.randint(nclass, size=(bs * num_iter,))
      ds = dataset_ops.Dataset.from_tensor_slices((data, labels))
      ds = ds.batch(bs).prefetch(3)
      it = ds.make_one_shot_iterator()

      def body(_, i):
        i += 1
        x, yt = it.get_next()
        dense = layers.Dense(nclass)
        y = dense(x)
        loss = losses.sparse_softmax_cross_entropy(yt, y)
        opt = adam.AdamOptimizer()
        train_op = opt.minimize(loss, var_list=dense.trainable_weights)
        with ops.control_dependencies([train_op]):
          loss = array_ops.identity(loss)
        return loss, i

      begin, end = constant_op.constant(0), constant_op.constant(num_iter)
      loss, _ = control_flow_ops.while_loop(
          lambda loss, i: math_ops.less(i, end), body, [0.0, begin])

    output_val_ref, output_val, cost_graph = self._run(mode, loss)
    node_map = _build_node_map(cost_graph.node)

    self._assert_output_f16(mode, node_map, 'while/dense/MatMul')
    self._assert_output_f16(mode, node_map,
                            'while/gradients/while/dense/MatMul_grad/MatMul_1')
    self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)

  # TODO(benbarsdell): Add tests for list ops (TensorList*) that pass through
  # graph source/sink nodes, similar to the TensorListThroughFunction C++ test.
  # Tests here will have the advantage of catching changes in the types of ops
  # that are added to the graph.


if __name__ == '__main__':
  test.main()