# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Benchmarks for low-level eager execution primitives.

To run CPU benchmarks:
  bazel run -c opt benchmarks_test -- --benchmarks=.

To run GPU benchmarks:
  bazel run --config=cuda -c opt --copt="-mavx" benchmarks_test -- \
    --benchmarks=.

To run a subset of benchmarks, use the --benchmarks flag.
--benchmarks: the list of benchmarks to run. The specified value is interpreted
as a regular expression, and any benchmark whose name contains a partial match
to that regular expression is executed.
For example, --benchmarks=".*matmul.*" will run all matmul-related benchmarks.

"""
import time

import numpy as np

from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop  # pylint: disable=unused-import
from tensorflow.python.eager import benchmarks_test_base
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import def_function
from tensorflow.python.eager import forwardprop
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect

CPU = "/device:CPU:0"
GPU = "/device:GPU:0"
GLOBAL_TEST_VALUE = None


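# Invokes the TFE_Py_FastPathExecute C API for MatMul directly, bypassing the
# generated Python op wrapper, so only the fast-path dispatch cost is measured.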
def c_tfe_py_fastpath_execute(a,
                              b,
                              transpose_a=False,
                              transpose_b=False,
                              name=None):
  ctx = context.context()
  assert ctx.executing_eagerly(
  ), "The prototype doesn't contain C code for graph construction"
  try:
    return pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", name, a, b,
                                             "transpose_a", transpose_a,
                                             "transpose_b", transpose_b)
  except core._NotOkStatusException as e:
    if name is not None:
      e.message += " name: " + name
    raise core._status_to_exception(e) from None


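# Times `num_iters` calls of `func` after a single warmup call. In ASYNC mode
# the executor is drained before and after timing so that enqueued ops are
# actually executed within the measured interval.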
def run_benchmark(func, num_iters, execution_mode=None):
  ctx = context.context()
  with context.execution_mode(execution_mode):
    # call func to warm up
    func()
    if execution_mode == context.ASYNC:
      ctx.executor.wait()
    start = time.time()
    for _ in range(num_iters):
      func()
    if execution_mode == context.ASYNC:
      ctx.executor.wait()
    end = time.time()

    return end - start


class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):

  def __init__(self):
    # used for multiply benchmarks
    self._m_2 = random_ops.random_uniform([2])

    # used for matmul benchmarks
    self._m_2_by_2 = random_ops.random_uniform((2, 2))
    self._m_100_by_784 = random_ops.random_uniform((100, 784))

    self._num_iters_2_by_2 = 30000
    self._num_iters_100_by_784 = 30000

    # used for conv2d benchmarks
    self._m_8_28_28_3 = random_ops.random_uniform((8, 28, 28, 3))
    self._m_1_3_3_1 = random_ops.random_uniform((1, 3, 3, 1))

  def _get_benchmark_name(self):
    """Mostly copied from benchmark.py _get_name()."""
    stack = tf_inspect.stack()
    name = None
    for frame in stack[::-1]:
      f_locals = frame[0].f_locals
      f_self = f_locals.get("self", None)
      if isinstance(f_self, test.Benchmark):
        name = frame[3]  # Get the method name
        # This is a hack to get around the fact that some methods might have a
        # disable_tfrt decorator around them. In that case a function called
        # 'decorated' wraps the real called function underneath and so we
        # peek one deeper into the stack to get the real name.
        if name == "decorated":
          continue
        else:
          break
    if name is None:
      raise ValueError("Unable to determine calling Benchmark function.")
    if context.is_tfrt_enabled():
      name = name + "_tfrt"
    return name

  def _run(self, func, num_iters, execution_mode=None):
    self.run_report(run_benchmark, func, num_iters, execution_mode)

  def benchmark_create_np_array(self):
    func = lambda: np.array([3.0])
    self._run(func, 30000)

  def _benchmark_create_tensor(self, value, dtype, device):
    """Benchmark overheads of creating a Tensor object."""
    if device == GPU:
      # Warmup the GPU
      ops.EagerTensor(value, device=device)

    def func():
      ops.EagerTensor(value, device=device, dtype=dtype)

    self._run(func, 30000)

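  # The uncached variant below mutates GLOBAL_TEST_VALUE on every call so that
  # constant_op.constant sees a fresh value and cannot reuse a cached tensor.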
  def _benchmark_create_constant(self, value, dtype, cached=True):
    global GLOBAL_TEST_VALUE
    GLOBAL_TEST_VALUE = value

    def cached_func():
      constant_op.constant(value, dtype=dtype)

    def uncached_func():
      global GLOBAL_TEST_VALUE
      GLOBAL_TEST_VALUE += 1
      constant_op.constant(GLOBAL_TEST_VALUE, dtype=dtype)

    func = cached_func if cached else uncached_func

    with ops.device("GPU:0" if context.num_gpus() else "CPU:0"):
      for _ in range(1000):
        func()  # Warmup.
      self._run(func, 3000)

  def benchmark_create_float_constant(self):
    self._benchmark_create_constant(42.0, dtype=None)

  def benchmark_create_float_constant_uncached(self):
    self._benchmark_create_constant(42.0, dtype=None, cached=False)

  def benchmark_create_int32_constant(self):
    if context.num_gpus():
      return  # int32 constants are always allocated on CPU.

    self._benchmark_create_constant(42, dtype=dtypes.int32)

  def benchmark_create_int32_constant_uncached(self):
    if context.num_gpus():
      return  # int32 constants are always allocated on CPU.

    self._benchmark_create_constant(42, dtype=dtypes.int32, cached=False)

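  # Wrapping the result in memoryview forces the output buffer to be resolved
  # on the Python side, so the measurement includes producing the result
  # rather than only dispatching the op.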
  def _benchmark_add(self, a, b):

    def func():
      return memoryview(math_ops.add_v2(a, b))

    with ops.device("GPU:0" if context.num_gpus() else "CPU:0"):
      for _ in range(1000):
        func()  # Warmup.
      self._run(func, 30000)

  def _benchmark_add_operator_overload(self, a, b):

    def func():
      return memoryview(a + b)

    with ops.device("GPU:0" if context.num_gpus() else "CPU:0"):
      for _ in range(1000):
        func()  # Warmup.
      self._run(func, 30000)

  def benchmark_add_float_scalars(self):
    self._benchmark_add(42.0, 24.0)

  def benchmark_add_int32_scalars(self):
    self._benchmark_add(42, 24)

  def benchmark_add_float_scalar_tensor(self):
    tensor_a = constant_op.constant(42.0)
    tensor_b = constant_op.constant(24.0)
    self._benchmark_add(tensor_a, tensor_b)

  def benchmark_add_float_scalar_tensor_overloaded_operator(self):
    tensor_a = constant_op.constant(42.0)
    tensor_b = constant_op.constant(24.0)
    self._benchmark_add_operator_overload(tensor_a, tensor_b)

  def benchmark_add_int32_scalar_tensor(self):
    tensor_a = constant_op.constant(42)
    tensor_b = constant_op.constant(24)
    self._benchmark_add(tensor_a, tensor_b)

  def benchmark_add_float_dense_tensor(self):
    tensor_a = constant_op.constant([[42.0, 42.0], [42.0, 42.0]])
    tensor_b = constant_op.constant([[24.0, 24.0], [24.0, 24.0]])
    self._benchmark_add(tensor_a, tensor_b)

  def benchmark_add_int32_dense_tensor(self):
    tensor_a = constant_op.constant([[42, 42], [42, 42]])
    tensor_b = constant_op.constant([[24, 24], [24, 24]])
    self._benchmark_add(tensor_a, tensor_b)

  def benchmark_create_float_tensor_from_list_CPU(self):
    self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, CPU)

  def benchmark_create_float_tensor_from_np_array_CPU(self):
    self._benchmark_create_tensor(
        np.array([[3.0]], dtype=np.float32), dtypes.float32.as_datatype_enum,
        CPU)

  def benchmark_create_int32_tensor_from_list_CPU(self):
    self._benchmark_create_tensor([[3]], dtypes.int32.as_datatype_enum, CPU)

  def benchmark_create_int32_tensor_from_np_array_CPU(self):
    self._benchmark_create_tensor(
        np.array([[3]], dtype=np.int32), dtypes.int32.as_datatype_enum, CPU)

  def benchmark_create_float_tensor_from_list_GPU(self):
    if not context.num_gpus():
      return
    self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, GPU)

  def benchmark_create_float_tensor_from_np_array_GPU(self):
    if not context.num_gpus():
      return
    self._benchmark_create_tensor(
        np.array([[3.0]], dtype=np.float32), dtypes.float32.as_datatype_enum,
        GPU)

  def benchmark_create_int32_tensor_from_list_GPU(self):
    # int32's are kept on host memory even when executing on GPU.
    if not context.num_gpus():
      return
    self._benchmark_create_tensor([[3]], dtypes.int32.as_datatype_enum, GPU)

  def benchmark_create_int32_tensor_from_np_array_GPU(self):
    # int32's are kept on host memory even when executing on GPU.
    if not context.num_gpus():
      return
    self._benchmark_create_tensor(
        np.array([[3]], dtype=np.int32), dtypes.int32.as_datatype_enum, GPU)

  def benchmark_index_tensor_with_literal(self):
    func = lambda: constant_op.constant([3.0])[0]
    self._run(func, 30000)

  def benchmark_index_tensor_with_tensor(self):
    func = lambda idx=constant_op.constant(0): constant_op.constant([3.0])[idx]
    self._run(func, 30000)

  def benchmark_index_tensor_with_np_array(self):
    func = lambda idx=np.array(0): constant_op.constant([3.0])[idx]
    self._run(func, 30000)

  def _benchmark_np_multiply(self, m, num_iters):
    a = m.cpu().numpy()
    func = lambda: a * a
    self._run(func, num_iters)

  def _benchmark_tf_multiply(self, m, num_iters):
    func = lambda: m * m
    self._run(func, num_iters)

  def _benchmark_tf_conv2d(self, m1, m2, num_iters):
    func = lambda: nn_ops.conv2d(m1, m2, strides=[1, 1, 1, 1], padding="VALID")
    self._run(func, num_iters)

  def _benchmark_tf_multiply_op(self, m, num_iters):
    func = lambda: math_ops.multiply(m, m)
    self._run(func, num_iters)

  def benchmark_np_multiply(self):
    self._benchmark_np_multiply(self._m_2, 30000)

  def benchmark_tf_multiply_CPU(self):
    with context.device(CPU):
      m = self._m_2.cpu()
      self._benchmark_tf_multiply(m, 30000)

  def benchmark_tf_multiply_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2.gpu()
      self._benchmark_tf_multiply(m, 30000)

  def benchmark_tf_multiply_op_CPU(self):
    with context.device(CPU):
      m = self._m_2.cpu()
      self._benchmark_tf_multiply_op(m, 30000)

  def benchmark_tf_multiply_op_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2.gpu()
      self._benchmark_tf_multiply_op(m, 30000)

  def benchmark_tf_conv2d_CPU(self):
    with context.device(CPU):
      m1 = self._m_8_28_28_3.cpu()
      m2 = self._m_1_3_3_1.cpu()
      self._benchmark_tf_conv2d(m1, m2, 30000)

  def benchmark_tf_conv2d_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m1 = self._m_8_28_28_3.gpu()
      m2 = self._m_1_3_3_1.gpu()
      self._benchmark_tf_conv2d(m1, m2, 30000)

  def benchmark_tf_identity(self):
    m = self._m_2
    self._run(lambda: gen_array_ops.identity(m), 30000)

  def benchmark_slowpath_tf_identity(self):
    self._run(lambda: gen_array_ops.identity(1), 30000)

  def benchmark_tfe_py_execute_identity(self):
    m = self._m_2
    ctx_handle = context.context()._handle
    attrs = ("T", self._m_2.dtype.as_datatype_enum)
    inputs = [m]

    def f():
      pywrap_tfe.TFE_Py_Execute(ctx_handle, None, "Identity", inputs, attrs, 1)

    self._run(f, 30000)

  def benchmark_tf_gradient_function_identity(self):
    with context.device(CPU):
      m = gen_array_ops.identity(self._m_2)
      self._run(
          lambda: backprop.gradients_function(gen_array_ops.identity, [0])(m),
          30000)

  def benchmark_tf_gradient_forward_identity(self):
    with backprop.GradientTape() as tape:
      m = self._m_2
      tape.watch(m)
      self._run(lambda: gen_array_ops.identity(m), 30000)

  def benchmark_tf_gradient_tape_push_pop(self):

    def f():
      with backprop.GradientTape():
        pass

    self._run(f, 30000)

  def benchmark_tf_gradient_function_no_op(self):
    with context.device(CPU):
      m = gen_array_ops.identity(self._m_2)
      self._run(lambda: backprop.gradients_function(lambda x: x, [0])(m), 30000)

  def _benchmark_np_matmul(self, m, transpose_b, num_iters):
    a = m.cpu().numpy()
    b = a.T if transpose_b else a
    func = lambda: np.dot(a, b)
    self._run(func, num_iters)

  def _benchmark_tf_matmul(self,
                           m,
                           transpose_b,
                           num_iters,
                           execution_mode=None):
    func = lambda: math_ops.matmul(m, m, transpose_b=transpose_b)
    self._run(func, num_iters, execution_mode=execution_mode)

  def _benchmark_gen_math_ops_matmul(self, m, transpose_b, num_iters):

    def func():
      gen_math_ops.mat_mul(m, m, transpose_b=transpose_b)

    self._run(func, num_iters)

  def _benchmark_tfe_py_fastpath_execute_matmul(self, m, transpose_b,
                                                num_iters):

    def func():
      c_tfe_py_fastpath_execute(m, m, transpose_b=transpose_b)

    self._run(func, num_iters)

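  # Executes MatMul through TFE_Py_Execute with a pre-built attribute tuple,
  # measuring the raw eager runtime call without Python-level op dispatch.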
  def _benchmark_tfe_py_execute_matmul(self, m, transpose_b, num_iters):
    inputs = [m, m]
    # pylint: disable=protected-access
    ctx_handle = context.context()._handle
    # pylint: enable=protected-access
    device = context.context().device_name
    attrs = ("transpose_a", False, "transpose_b", transpose_b, "T",
             m.dtype.as_datatype_enum)

    def func():
      pywrap_tfe.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs, attrs, 1)

    self._run(func, num_iters)

  def _benchmark_defun_matmul(self,
                              m,
                              transpose_b,
                              num_iters,
                              execution_mode=None):
    f = function.defun(math_ops.matmul)
    func = lambda: f(m, m, transpose_b=transpose_b)
    self._run(func, num_iters, execution_mode=execution_mode)

  def _benchmark_defun_matmul_with_signature(self,
                                             m,
                                             num_iters,
                                             execution_mode=None):

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([2, 2], dtypes.float32)])
    def defun_matmul(m):
      return math_ops.matmul(m, m)

    func = lambda: defun_matmul(m)
    self._run(func, num_iters, execution_mode=execution_mode)

  def _benchmark_defun_matmul_relaxed_shape(self,
                                            m,
                                            num_iters,
                                            execution_mode=None):

    @def_function.function(reduce_retracing=True)
    def defun_matmul(m):
      return math_ops.matmul(m, m)

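    # Prime the function with a different shape first so the benchmarked calls
    # exercise the shape-relaxation (reduced retracing) cache path rather than
    # triggering a retrace per call.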
    m_3_by_3 = random_ops.random_uniform((3, 3))
    defun_matmul(m_3_by_3)
    func = lambda: defun_matmul(m)
    self._run(func, num_iters, execution_mode=execution_mode)

  def _benchmark_defun_args_matmul(self, m, num_iters, execution_mode=None):

    @def_function.function
    def defun_matmul(m):
      return math_ops.matmul(m, m)

    func = lambda: defun_matmul(m)
    self._run(func, num_iters, execution_mode=execution_mode)

  def _benchmark_nested_defun_matmul(self, m, transpose_b, num_iters):
    inner = function.defun(math_ops.matmul)

    @function.defun
    def outer(a, b, c, transpose_b):
      return math_ops.matmul(inner(a, b, transpose_b=transpose_b), c)

    func = lambda: outer(m, m, m, transpose_b=transpose_b)
    # Warmup before benchmark
    for _ in range(1000):
      func()
    self._run(func, num_iters)

  def _benchmark_defun_matmul_forward_backward(self,
                                               m,
                                               transpose_b,
                                               num_iters,
                                               execution_mode=None):
    f = def_function.function(math_ops.matmul)

    def func():
      with backprop.GradientTape() as gt:
        gt.watch(m)
        y = f(m, m, transpose_b=transpose_b)
      _ = gt.gradient(y, m)

    self._run(func, num_iters, execution_mode=execution_mode)

  def _benchmark_read_variable(self, m, num_iters):
    self._run(m.value, num_iters)

  def _benchmark_matmul_read_variable(self, m, num_iters):
    self._benchmark_gen_math_ops_matmul(
        m, transpose_b=False, num_iters=num_iters)

  def _benchmark_matmul_read_variable_with_tape(self, m, num_iters):
    with backprop.GradientTape() as tape:
      tape.watch(m)
      self._benchmark_gen_math_ops_matmul(
          m, transpose_b=False, num_iters=num_iters)

  def _benchmark_read_variable_with_tape(self, m, num_iters):
    with backprop.GradientTape() as tape:
      tape.watch(m)
      self._run(m.value, num_iters)

  # Benchmarks for A^2, A of dimension 2 by 2.
  def benchmark_np_matmul_2_by_2(self):
    self._benchmark_np_matmul(
        self._m_2_by_2, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_tf_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_tf_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_tf_matmul_2_by_2_CPU_async(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_tf_matmul(
          m,
          transpose_b=False,
          num_iters=self._num_iters_2_by_2,
          execution_mode=context.ASYNC)

  def benchmark_gen_math_ops_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_gen_math_ops_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_tfe_py_fastpath_execute_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_tfe_py_fastpath_execute_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_tfe_py_execute_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_tfe_py_execute_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_defun_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_matmul_2_by_2_with_signature_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_defun_matmul_with_signature(
          m, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_matmul_2_by_2_relaxed_shape_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_defun_matmul_relaxed_shape(
          m, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_args_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_defun_args_matmul(m, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_matmul_2_by_2_CPU_async(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_defun_matmul(
          m,
          transpose_b=False,
          num_iters=self._num_iters_2_by_2,
          execution_mode=context.ASYNC)

  def _benchmark_matmul_forward_backward_2_by_2_CPU(self, run_eager=False):
    def_function.run_functions_eagerly(run_eager)
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_defun_matmul_forward_backward(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)
    def_function.run_functions_eagerly(False)

  def _benchmark_matmul_forward_backward_2_by_2_CPU_async(
      self, run_eager=False):
    def_function.run_functions_eagerly(run_eager)
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_defun_matmul_forward_backward(
          m,
          transpose_b=False,
          num_iters=self._num_iters_2_by_2,
          execution_mode=context.ASYNC)

  def benchmark_defun_matmul_forward_backward_2_by_2_CPU(self):
    self._benchmark_matmul_forward_backward_2_by_2_CPU(False)

  def benchmark_defun_matmul_forward_backward_2_by_2_CPU_async(self):
    self._benchmark_matmul_forward_backward_2_by_2_CPU_async(False)

  def benchmark_defun_eager_matmul_forward_backward_2_by_2_CPU(self):
    self._benchmark_matmul_forward_backward_2_by_2_CPU(True)

  def benchmark_defun_eager_matmul_forward_backward_2_by_2_CPU_async(self):
    self._benchmark_matmul_forward_backward_2_by_2_CPU_async(True)

  def benchmark_tf_matmul_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_tf_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_tf_matmul_2_by_2_GPU_async(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_tf_matmul(
          m,
          transpose_b=False,
          num_iters=self._num_iters_2_by_2,
          execution_mode=context.ASYNC)

  def benchmark_gen_math_ops_matmul_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_gen_math_ops_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_tfe_py_execute_matmul_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_tfe_py_execute_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_matmul_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_defun_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_matmul_2_by_2_with_signature_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_defun_matmul_with_signature(
          m, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_matmul_2_by_2_relaxed_shape_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_defun_matmul_relaxed_shape(
          m, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_args_matmul_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_defun_args_matmul(m, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_matmul_2_by_2_GPU_async(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_defun_matmul(
          m,
          transpose_b=False,
          num_iters=self._num_iters_2_by_2,
          execution_mode=context.ASYNC)

  def benchmark_nested_defun_matmul_2_by_2(self):
    m = self._m_2_by_2.cpu()
    self._benchmark_nested_defun_matmul(
        m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  # Benchmarks for AA.T, A of dimension 100 by 784.
  def benchmark_np_matmul_100_by_784(self):
    self._benchmark_np_matmul(
        self._m_100_by_784,
        transpose_b=True,
        num_iters=self._num_iters_100_by_784)

  def benchmark_tf_matmul_100_by_784_CPU(self):
    with context.device(CPU):
      m = self._m_100_by_784.cpu()
      self._benchmark_tf_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_tf_matmul_100_by_784_CPU_async(self):
    with context.device(CPU):
      m = self._m_100_by_784.cpu()
      self._benchmark_tf_matmul(
          m,
          transpose_b=True,
          num_iters=self._num_iters_100_by_784,
          execution_mode=context.ASYNC)

  def benchmark_gen_math_ops_matmul_100_by_784_CPU(self):
    with context.device(CPU):
      m = self._m_100_by_784.cpu()
      self._benchmark_gen_math_ops_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_tfe_py_fastpath_execute_matmul_100_by_784_CPU(self):
    with context.device(CPU):
      m = self._m_100_by_784.cpu()
      self._benchmark_tfe_py_fastpath_execute_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_tfe_py_execute_matmul_100_by_784_CPU(self):
    with context.device(CPU):
      m = self._m_100_by_784.cpu()
      self._benchmark_tfe_py_execute_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_defun_matmul_100_by_784_CPU(self):
    with context.device(CPU):
      m = self._m_100_by_784.cpu()
      self._benchmark_defun_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_tf_matmul_100_by_784_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_100_by_784.gpu()
      self._benchmark_tf_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_tf_matmul_100_by_784_GPU_async(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_100_by_784.gpu()
      self._benchmark_tf_matmul(
          m,
          transpose_b=True,
          num_iters=self._num_iters_100_by_784,
          execution_mode=context.ASYNC)

  def benchmark_gen_math_ops_matmul_100_by_784_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_100_by_784.gpu()
      self._benchmark_gen_math_ops_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_tfe_py_execute_matmul_100_by_784_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_100_by_784.gpu()
      self._benchmark_tfe_py_execute_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_defun_matmul_100_by_784_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_100_by_784.gpu()
      self._benchmark_defun_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  @test_util.disable_tfrt(
      "b/169371527: Support inserting transfer op in lowering.")
  def benchmark_nested_defun_matmul_100_by_784_GPU(self):
    m = self._m_100_by_784.gpu()
    self._benchmark_nested_defun_matmul(
        m, transpose_b=True, num_iters=self._num_iters_100_by_784)

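  # The forwardprop benchmarks measure the overhead of computing a JVP with
  # ForwardAccumulator around a matmul, both eagerly and under tf.function.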
  def _benchmark_forwardprop_matmul_CPU(self, shape):
    with ops.device(CPU):
      m = random_ops.random_uniform(shape).cpu()
      tangent = random_ops.random_uniform(shape).cpu()

      def func():
        with forwardprop.ForwardAccumulator(m, tangent) as acc:
          result = math_ops.matmul(m, m, transpose_b=True)
        return result, acc.jvp(result)

      # Warmup before benchmark
      for _ in range(100):
        func()
      self._run(func, 3000)

  def _benchmark_forwardprop_in_defun_matmul_CPU(self, shape):
    with ops.device(CPU):

      @def_function.function
      def compiled_function(x, tangent):
        with forwardprop.ForwardAccumulator(x, tangent) as acc:
          result = math_ops.matmul(x, x, transpose_b=True)
        return result, acc.jvp(result)

      m = random_ops.random_uniform(shape).cpu()
      tangent = random_ops.random_uniform(shape).cpu()
      func = lambda: compiled_function(m, tangent)

      # Warmup before benchmark
      for _ in range(100):
        func()
      self._run(func, 3000)

  def _benchmark_forwardprop_in_defun_of_defun_matmul_CPU(self, shape):
    with ops.device(CPU):
      matmul = def_function.function(math_ops.matmul)

      @def_function.function()
      def compiled_function(x, tangent):
        with forwardprop.ForwardAccumulator(x, tangent) as acc:
          result = matmul(x, x, transpose_b=True)
        return result, acc.jvp(result)

      m = random_ops.random_uniform(shape).cpu()
      tangent = random_ops.random_uniform(shape).cpu()
      func = lambda: compiled_function(m, tangent)

      # Warmup before benchmark
      for _ in range(100):
        func()
      self._run(func, 3000)

  def _benchmark_forwardprop_of_defun_matmul_CPU(self, shape):
    with ops.device(CPU):
      m = random_ops.random_uniform(shape).cpu()
      tangent = random_ops.random_uniform(shape).cpu()
      matmul = def_function.function(math_ops.matmul)

      def func():
        with forwardprop.ForwardAccumulator(m, tangent) as acc:
          result = matmul(m, m, transpose_b=True)
        return result, acc.jvp(result)

      # Warmup before benchmark
      for _ in range(100):
        func()
      self._run(func, 3000)

  def benchmark_forwardprop_matmul_256_by_2096_CPU(self):
    self._benchmark_forwardprop_matmul_CPU(shape=(256, 2096))

  def benchmark_forwardprop_in_defun_matmul_256_by_2096_CPU(self):
    self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(256, 2096))

  def benchmark_forwardprop_in_defun_of_defun_matmul_256_by_2096_CPU(self):
    self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(256, 2096))

  def benchmark_forwardprop_of_defun_matmul_256_by_2096_CPU(self):
    self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(256, 2096))

  def benchmark_forwardprop_matmul_100_by_784_CPU(self):
    self._benchmark_forwardprop_matmul_CPU(shape=(100, 784))

  def benchmark_forwardprop_in_defun_matmul_100_by_784_CPU(self):
    self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(100, 784))

  def benchmark_forwardprop_in_defun_of_defun_matmul_100_by_784_CPU(self):
    self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(100, 784))

  def benchmark_forwardprop_of_defun_matmul_100_by_784_CPU(self):
    self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(100, 784))

  def _benchmark_tf_reduce_logsumexp(self,
                                     device=CPU,
                                     execution_mode=None,
                                     defunc=False,
                                     xla_compile=False):
    with context.device(device):
      x = constant_op.constant([[1, 0.], [0., 0.]])
      if defunc:
        reduce_func = def_function.function(
            math_ops.reduce_logsumexp, jit_compile=xla_compile)
        func = lambda: reduce_func(x)
      else:
        func = lambda: math_ops.reduce_logsumexp(x)
      self._run(func, 3000, execution_mode=execution_mode)

  def benchmark_tf_reduce_logsumexp_CPU(self):
    self._benchmark_tf_reduce_logsumexp()

  def benchmark_tf_reduce_logsumexp_CPU_async(self):
    self._benchmark_tf_reduce_logsumexp(execution_mode=context.ASYNC)

  def benchmark_tf_reduce_logsumexp_GPU(self):
    self._benchmark_tf_reduce_logsumexp(device=GPU)

  def benchmark_tf_reduce_logsumexp_GPU_async(self):
    self._benchmark_tf_reduce_logsumexp(
        device=GPU, execution_mode=context.ASYNC)

  @test_util.disable_tfrt(
      "b/169371527: Support inserting transfer op in lowering.")
  def benchmark_tf_reduce_logsumexp_CPU_defunc(self):
    self._benchmark_tf_reduce_logsumexp(defunc=True)

  @test_util.disable_tfrt(
      "b/169371527: Support inserting transfer op in lowering.")
  def benchmark_tf_reduce_logsumexp_CPU_async_defun(self):
    self._benchmark_tf_reduce_logsumexp(
        execution_mode=context.ASYNC, defunc=True)

  def benchmark_tf_reduce_logsumexp_GPU_defun(self):
    self._benchmark_tf_reduce_logsumexp(device=GPU, defunc=True)

  def benchmark_tf_reduce_logsumexp_GPU_async_defun(self):
    self._benchmark_tf_reduce_logsumexp(
        device=GPU, execution_mode=context.ASYNC, defunc=True)

  def benchmark_tf_reduce_logsumexp_GPU_defun_compile(self):
    self._benchmark_tf_reduce_logsumexp(
        device=GPU, defunc=True, xla_compile=True)

  def benchmark_tf_reduce_logsumexp_GPU_async_defun_compile(self):
    self._benchmark_tf_reduce_logsumexp(
        device=GPU, execution_mode=context.ASYNC, defunc=True, xla_compile=True)

  def _benchmark_tf_tensordot(self, device=CPU, execution_mode=None):
    with context.device(device):
      a = array_ops.ones((2, 2))
      b = array_ops.ones((2, 2))
      func = lambda: math_ops.tensordot(a, b, [[1], [0]])
      self._run(func, 30000, execution_mode=execution_mode)

  def benchmark_tf_tensordot_CPU(self):
    self._benchmark_tf_tensordot()

  def benchmark_tf_tensordot_CPU_async(self):
    self._benchmark_tf_tensordot(execution_mode=context.ASYNC)

  def benchmark_tf_tensordot_GPU(self):
    self._benchmark_tf_tensordot(device=GPU)

  def benchmark_tf_tensordot_GPU_async(self):
    self._benchmark_tf_tensordot(device=GPU, execution_mode=context.ASYNC)

  def _benchmark_tf_zeros(self, shape, dtype, device=CPU):
    with context.device(device):
      func = lambda: array_ops.zeros(shape, dtype)
      self._run(func, 3000)

  def benchmark_tf_zeros_2_by_2_float32_CPU(self):
    self._benchmark_tf_zeros((2, 2), dtypes.float32)

  def benchmark_tf_zeros_2_by_2_bool_CPU(self):
    self._benchmark_tf_zeros((2, 2), dtypes.bool)

  def benchmark_tf_zeros_2_by_2_string_CPU(self):
    self._benchmark_tf_zeros((2, 2), dtypes.string)

  def benchmark_tf_zeros_2_by_2_float32_GPU(self):
    self._benchmark_tf_zeros((2, 2), dtypes.float32, device=GPU)

  def benchmark_tf_zeros_2_by_2_bool_GPU(self):
    self._benchmark_tf_zeros((2, 2), dtypes.bool, device=GPU)

  def benchmark_tf_zeros_30_by_30_float32_CPU(self):
    self._benchmark_tf_zeros((30, 30), dtypes.float32)

  def benchmark_tf_zeros_30_by_30_bool_CPU(self):
    self._benchmark_tf_zeros((30, 30), dtypes.bool)

  def benchmark_tf_zeros_30_by_30_string_CPU(self):
    self._benchmark_tf_zeros((30, 30), dtypes.string)

  def benchmark_tf_zeros_30_by_30_float32_GPU(self):
    self._benchmark_tf_zeros((30, 30), dtypes.float32, device=GPU)

  def benchmark_tf_zeros_30_by_30_bool_GPU(self):
    self._benchmark_tf_zeros((30, 30), dtypes.bool, device=GPU)

  def benchmark_tf_zeros_100_by_100_float32_CPU(self):
    self._benchmark_tf_zeros((100, 100), dtypes.float32)

  def benchmark_tf_zeros_100_by_100_bool_CPU(self):
    self._benchmark_tf_zeros((100, 100), dtypes.bool)

  def benchmark_tf_zeros_100_by_100_string_CPU(self):
    self._benchmark_tf_zeros((100, 100), dtypes.string)

  def benchmark_tf_zeros_100_by_100_float32_GPU(self):
    self._benchmark_tf_zeros((100, 100), dtypes.float32, device=GPU)

  def benchmark_tf_zeros_100_by_100_bool_GPU(self):
    self._benchmark_tf_zeros((100, 100), dtypes.bool, device=GPU)

  def _benchmark_tf_zeros_like(self, m, device=CPU):
    with context.device(device):
      func = lambda: array_ops.zeros_like(m)
      self._run(func, 3000)

  def benchmark_tf_zeros_like_CPU(self):
    self._benchmark_tf_zeros_like(self._m_2_by_2)

  def benchmark_tf_zeros_like_GPU(self):
    self._benchmark_tf_zeros_like(self._m_2_by_2, device=GPU)

  def benchmark_tf_zeros_like_variable_CPU(self):
    m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
    self._benchmark_tf_zeros_like(m)

  def benchmark_tf_zeros_like_variable_GPU(self):
    m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
    self._benchmark_tf_zeros_like(m, device=GPU)

  def _benchmark_tf_random_uniform_2_by_2(self,
                                          shape=(2, 2),
                                          dtype=dtypes.int32,
                                          device=CPU):
    with context.device(device):

      def func():
        return random_ops.random_uniform(shape, maxval=3, dtype=dtype)

      self._run(func, num_iters=self._num_iters_2_by_2)

  def benchmark_tf_random_uniform_2_by_2_integer_CPU(self):
    self._benchmark_tf_random_uniform_2_by_2()

  def benchmark_tf_random_uniform_2_by_2_integer_GPU(self):
    self._benchmark_tf_random_uniform_2_by_2(device=GPU)

  def benchmark_tf_random_uniform_2_by_2_float_CPU(self):
    self._benchmark_tf_random_uniform_2_by_2(dtype=dtypes.float32)

  def benchmark_tf_random_uniform_2_by_2_float_GPU(self):
    self._benchmark_tf_random_uniform_2_by_2(dtype=dtypes.float32, device=GPU)

  def benchmark_tf_random_uniform_2_by_2_default_setting_CPU(self):
    with context.device(CPU):
      func = lambda: random_ops.random_uniform((2, 2))
      self._run(func, num_iters=self._num_iters_2_by_2)

  def benchmark_tf_random_uniform_2_by_2_default_setting_GPU(self):
    with context.device(GPU):
      func = lambda: random_ops.random_uniform((2, 2))
      self._run(func, num_iters=self._num_iters_2_by_2)

  def _benchmark_tf_dropout_2_by_2(self,
                                   rate=0.5,
                                   is_rate_tensor=True,
                                   noise_shape=None,
                                   device=CPU):
    if is_rate_tensor:
      rate = constant_op.constant(rate, dtype=dtypes.float32)
    with context.device(device):

      def func():
        return nn_ops.dropout(
            self._m_2_by_2, rate=rate, noise_shape=noise_shape)

      self._run(func, num_iters=self._num_iters_2_by_2)

  def benchmark_tf_dropout_scalar_rate_2_by_2_CPU(self):
    self._benchmark_tf_dropout_2_by_2(is_rate_tensor=False)

  def benchmark_tf_dropout_scalar_rate_2_by_2_GPU(self):
    self._benchmark_tf_dropout_2_by_2(is_rate_tensor=False, device=GPU)

  def benchmark_tf_dropout_2_by_2_CPU(self):
    self._benchmark_tf_dropout_2_by_2()

  def benchmark_tf_dropout_2_by_2_GPU(self):
    self._benchmark_tf_dropout_2_by_2(device=GPU)

  def benchmark_tf_dropout_scalar_rate_2_by_2_CPU_rate_0(self):
    self._benchmark_tf_dropout_2_by_2(rate=0, is_rate_tensor=False)

  def benchmark_tf_dropout_scalar_rate_2_by_2_GPU_rate_0(self):
    self._benchmark_tf_dropout_2_by_2(
        rate=0.0, is_rate_tensor=False, device=GPU)

  def benchmark_tf_dropout_2_by_2_CPU_rate_0(self):
    self._benchmark_tf_dropout_2_by_2(rate=0.0)

  def benchmark_tf_dropout_2_by_2_GPU_rate_0(self):
    self._benchmark_tf_dropout_2_by_2(rate=0, device=GPU)

  def _benchmark_transpose(self,
                           m,
                           num_iters,
                           perm=None,
                           conjugate=False,
                           execution_mode=None):
    func = lambda: array_ops.transpose(m, perm, conjugate)
    self._run(func, num_iters, execution_mode=execution_mode)

  def benchmark_tf_transpose_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2)

  def benchmark_tf_transpose_2_by_2_GPU(self):
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2)

  def benchmark_tf_transpose_variable_2_by_2_CPU(self):
    with context.device(CPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
      self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2)

  def benchmark_tf_transpose_variable_2_by_2_GPU(self):
    with context.device(GPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
      self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2)

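  # The defun benchmarks below measure per-call overhead (cache lookup and
  # argument canonicalization) of an already-traced function.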
  def benchmark_defun_without_signature(self):

    def func(t1, t2, t3, t4, t5, t6, t7, t8):
      del t1, t2, t3, t4, t5, t6, t7, t8
      return None

    defined = function.defun(func)
    t = constant_op.constant(0.0)
    cache_computation = lambda: defined(t, t, t, t, t, t, t, t)
    self._run(cache_computation, 30000)

  def benchmark_defun_without_signature_and_with_kwargs(self):

    def func(t1, t2, t3, t4, t5, t6, t7, t8):
      del t1, t2, t3, t4, t5, t6, t7, t8
      return None

    defined = function.defun(func)
    t = constant_op.constant(0.0)

    def cache_computation():
      return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t)

    self._run(cache_computation, 30000)

  def benchmark_defun_with_signature(self):

    def func(t1, t2, t3, t4, t5, t6, t7, t8):
      del t1, t2, t3, t4, t5, t6, t7, t8
      return None

    defined = function.defun(
        func, input_signature=[tensor_spec.TensorSpec([], dtypes.float32)] * 8)
    t = constant_op.constant(0.0)
    signature_computation = lambda: defined(t, t, t, t, t, t, t, t)
    self._run(signature_computation, 30000)

  def benchmark_defun_with_signature_and_kwargs(self):

    def func(t1, t2, t3, t4, t5, t6, t7, t8):
      del t1, t2, t3, t4, t5, t6, t7, t8
      return None

    defined = function.defun(
        func, input_signature=[tensor_spec.TensorSpec([], dtypes.float32)] * 8)
    t = constant_op.constant(0.0)

    def signature_computation():
      return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t)

    self._run(signature_computation, 30000)

  def benchmark_matmul_read_variable_op_2_by_2_CPU(self):
    with context.device(CPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
      self._benchmark_matmul_read_variable(m, num_iters=self._num_iters_2_by_2)

  def benchmark_matmul_read_variable_op_with_tape_2_by_2_CPU(self):
    with context.device(CPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
      self._benchmark_matmul_read_variable_with_tape(
          m, num_iters=self._num_iters_2_by_2)

  def benchmark_read_variable_op_2_by_2_CPU(self):
    with context.device(CPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
      self._benchmark_read_variable(m, num_iters=self._num_iters_2_by_2)

  def benchmark_read_variable_op_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2.gpu())
      self._benchmark_read_variable(m, num_iters=self._num_iters_2_by_2)

  def benchmark_read_variable_op_with_tape_2_by_2_CPU(self):
    with context.device(CPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
      self._benchmark_read_variable_with_tape(
          m, num_iters=self._num_iters_2_by_2)

  def benchmark_read_variable_op_with_tape_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2.gpu())
      self._benchmark_read_variable_with_tape(
          m, num_iters=self._num_iters_2_by_2)

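  # With parallel_iterations=1 the 1600-element scan runs serially, so the
  # measurement is dominated by per-iteration loop overhead.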
  def benchmarkScan(self):
    elems = math_ops.range(1600)

    def scan():
      return functional_ops.scan(
          lambda a, x: a + x, elems, parallel_iterations=1)

    self._run(scan, 100)

  @test_util.disable_tfrt("tf.While not supported RTFB tensor. b/169374895")
  def benchmarkScanDefun(self):
    elems = math_ops.range(1600)

    @function.defun
    def scan():
      return functional_ops.scan(
          lambda a, x: a + x, elems, parallel_iterations=1)

    self._run(scan, 100)

  def benchmark_fastpath_conversion_type_inference(self):
    c = constant_op.constant(1., dtype=dtypes.float32)

    def fn():
      return gen_math_ops.add(c, 1)

    self._run(fn, 10000)

  def benchmark_convert_tensor(self):
    value = ops.convert_to_tensor(42)

    def fn():
      return ops.convert_to_tensor(value)

    self._run(fn, 10000)

  def _benchmark_convert_constant(self, value, cached):
    global GLOBAL_TEST_VALUE
    GLOBAL_TEST_VALUE = value

    def cached_func():
      ops.convert_to_tensor(value)

    def uncached_func():
      global GLOBAL_TEST_VALUE
      GLOBAL_TEST_VALUE += 1
      ops.convert_to_tensor(GLOBAL_TEST_VALUE)

    func = cached_func if cached else uncached_func

    self._run(func, 10000)

  def benchmark_convert_python_int(self):
    self._benchmark_convert_constant(42, cached=True)

  def benchmark_convert_python_int_uncached(self):
    self._benchmark_convert_constant(42, cached=False)

  def benchmark_convert_python_float(self):
    self._benchmark_convert_constant(42.0, cached=True)

  def benchmark_convert_python_float_uncached(self):
    self._benchmark_convert_constant(42.0, cached=False)

  def benchmark_convert_numpy_int(self):
    self._benchmark_convert_constant(np.array(42), cached=True)

  def benchmark_convert_numpy_int_uncached(self):
    self._benchmark_convert_constant(np.array(42), cached=False)

  def benchmark_convert_numpy_float(self):
    self._benchmark_convert_constant(np.array(42.0), cached=True)

  def benchmark_convert_numpy_float_uncached(self):
    self._benchmark_convert_constant(np.array(42.0), cached=False)

  def benchmark_convert_3x_list_to_tensor(self):
    xs = [1, 2, 3]
    self._run(lambda: ops.convert_to_tensor(xs), 1000)

  def benchmark_convert_3x_array_to_tensor(self):
    xs = np.array([1, 2, 3], dtype=np.int32)
    self._run(lambda: ops.convert_to_tensor(xs), 1000)

  def benchmark_constant_40x2_list_to_tensor(self):
    xs = [[0] * 2] * 40
    self._run(lambda: constant_op.constant(xs), 1000)

  def benchmark_constant_40x2_array_to_tensor(self):
    xs = np.array([[0] * 2] * 40, dtype=np.int32)
    self._run(lambda: constant_op.constant(xs), 1000)

  def benchmark_constant_40x_list_of_2x_arrays_to_tensor(self):
    xs = [np.array([0] * 2, dtype=np.int32)] * 40
    self._run(lambda: constant_op.constant(xs), 1000)

  def benchmark_constant_20x20x20_double_list_to_float32_tensor(self):
    xs = [[[np.linspace(0, 1, 21).tolist()] * 20] * 20]
    self._run(lambda: constant_op.constant(xs, dtype=dtypes.float32), 10000)

  def benchmark_constant_20x20x20_double_list_to_float64_tensor(self):
    xs = [[[np.linspace(0, 1, 21).tolist()] * 20] * 20]
    self._run(lambda: constant_op.constant(xs, dtype=dtypes.float64), 10000)

  def benchmark_list_of_zeros_to_np_array(self):
    values = []
    for _ in range(1000):
      values.append(array_ops.zeros(shape=(1000,)))
    self._run(lambda: np.array([x.numpy() for x in values]), 1000)

  def benchmark_function_trace(self):

    def func(x):
      return x

    self._run(lambda: (def_function.function(func)(x) for x in range(1000)),
              30000)

  def _benchmarkFunctionWithResourceInputs(self, num_resources, num_iters):

    @def_function.function
    def add_all(*args):
      return math_ops.add_n(*args)

    with context.device(CPU):
      resources = []
      for _ in range(num_resources):
        resources.append(resource_variable_ops.ResourceVariable(self._m_2))
      self._run(lambda: add_all(resources), num_iters)

  def benchmarkFunctionWithFiveResourceInputs(self):
    self._benchmarkFunctionWithResourceInputs(5, 1000)

  def benchmarkFunctionWithFiveHundredResourceInputs(self):
    self._benchmarkFunctionWithResourceInputs(500, 100)

  def _benchmarkResourceReadsInCondInInnerFunc(self, var_count):
    rvars = []
    for _ in range(var_count):
      rvars.append(resource_variable_ops.ResourceVariable(1.0))

    # Note: We want to benchmark the graph building time so we intentionally
    # add this outer function so that the tf.function gets retraced every time.
    def benchmark_fn():

      @def_function.function
      def fn_with_many_reads():

        @def_function.function
        def fn_with_many_reads_inner():

          def then_branch():
            return math_ops.add_n(rvars)

          def else_branch():
            return 0.

          return control_flow_ops.cond(
              constant_op.constant(True), then_branch, else_branch)

        return fn_with_many_reads_inner()

      return fn_with_many_reads()

    with context.device(CPU):
      self._run(benchmark_fn, 10)

  def benchmarkTenThousandResourceReadsInCondInInnerFunc(self):
    self._benchmarkResourceReadsInCondInInnerFunc(10000)

  def benchmarkHundredResourceReadsInCondInInnerFunc(self):
    self._benchmarkResourceReadsInCondInInnerFunc(100)

  def benchmarkTenResourceReadsInCondInInnerFunc(self):
    self._benchmarkResourceReadsInCondInInnerFunc(10)

  def benchmark_tf_name_scope(self):

    def fn():
      with ops.name_scope_v2("name"):
        pass

    self._run(fn, 10000)

  def benchmark_tf_nest_map_structure(self):
    nested = {"a": [1, 2, 3], "b": (4, 5, 6)}

    def fn():
      nest.map_structure(lambda x: x, nested)

    self._run(fn, 10000)

  def benchmark_tf_nest_pack_sequence_as(self):
    nested = {"a": [1, 2, 3], "b": (4, 5, 6)}
    flat = nest.flatten(nested)

    def fn():
      nest.pack_sequence_as(nested, flat)

    self._run(fn, 10000)

  def benchmark_tf_nest_flatten_none(self):

    def fn():
      nest.flatten(None)

    self._run(fn, 100000)

  def benchmark_tf_nest_flatten(self):
    nested = {"a": [1, 2, 3], "b": (4, 5, 6)}

    def fn():
      nest.flatten(nested)

    self._run(fn, 100000)

  def benchmark_tf_flatten_dict_items(self):
    nested = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}

    def fn():
      nest.flatten_dict_items(nested)

    self._run(fn, 100000)

  def benchmark_tf_nn_convolution_overhead(self):
    inputs = array_ops.ones((1, 1, 1, 1))
    filters = array_ops.ones((1, 1, 1, 1))

    def fn():
      nn_ops.convolution_v2(inputs, filters)

    self._run(fn, 10000)

  def benchmark_tf_tensor_shape_creation_overhead(self):
    # A `TensorShape` is created the first time `EagerTensor.shape` is
    # called, which puts `TensorShape.__init__` on the hotpath. The
    # `TensorShape` is created from `EagerTensor._shape_tuple`.

    x = array_ops.ones((1, 1))
    shape_tuple = x._shape_tuple()

    def fn():
      tensor_shape.TensorShape(shape_tuple)

    self._run(fn, 100000)

  def _boolean_mask_input(self):
    n = 3000
    return (array_ops.ones([n, n]), array_ops.fill([n, n], True))

  def _boolean_mask_fn(self, input_tensor, mask):
    return array_ops.boolean_mask(input_tensor, mask)

  def benchmark_tf_boolean_mask_eager(self):
    input_tensor, mask = self._boolean_mask_input()

    self._run(lambda: self._boolean_mask_fn(input_tensor, mask), 10000)

  def benchmark_tf_boolean_mask_graph(self):
    input_tensor, mask = self._boolean_mask_input()
    compiled_fn = def_function.function(self._boolean_mask_fn)

    self._run(lambda: compiled_fn(input_tensor, mask), 10000)

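  # The tf.range benchmarks time a Python for-loop over math_ops.range inside
  # a tf.function (autograph converts it to a while loop); the variable and
  # constant variants differ in whether the loop body updates a resource
  # variable or a loop-carried constant.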
  def _benchmark_tf_range_var(self,
                              limit=100,
                              dtype=dtypes.int32,
                              range_dtype=dtypes.int32,
                              device=CPU,
                              num_iters=1000):

    def func(v, lim):
      for _ in math_ops.range(lim, dtype=range_dtype):
        v.assign_add(constant_op.constant(1, dtype=dtype))
      return v

    compiled_func = def_function.function(func)

    with context.device(CPU):
      m = resource_variable_ops.ResourceVariable(
          constant_op.constant(1, dtype=dtype), dtype=dtype)
      limit_t = constant_op.constant(limit, dtype=dtype)

    with context.device(device):
      compiled_func(m, limit_t)
      self._run(lambda: compiled_func(m, limit_t), num_iters=num_iters)

  def benchmark_tf_range_var_int32_CPU(self):
    self._benchmark_tf_range_var()

  def benchmark_tf_range_var_int64_CPU(self):
    self._benchmark_tf_range_var(dtype=dtypes.int64, range_dtype=dtypes.int64)

  def benchmark_tf_range_var_int32_GPU(self):
    self._benchmark_tf_range_var(device=GPU)

  def benchmark_tf_range_var_int64_GPU(self):
    self._benchmark_tf_range_var(
        dtype=dtypes.int64, range_dtype=dtypes.int64, device=GPU)

  def _benchmark_tf_range_const(self,
                                limit=100,
                                dtype=dtypes.int32,
                                range_dtype=dtypes.int32,
                                device=CPU,
                                num_iters=1000):

    def func(c, lim):
      for _ in math_ops.range(lim, dtype=range_dtype):
        c += 1
      return c

    compiled_func = def_function.function(func)

    with context.device(CPU):
      input_c = constant_op.constant(1, dtype=dtype)
      limit_t = constant_op.constant(limit, dtype=dtype)

    with context.device(device):
      compiled_func(input_c, limit_t)
      self._run(lambda: compiled_func(input_c, limit_t), num_iters=num_iters)

  # int32 constant, int32 range, CPU
  def benchmark_tf_range_const_int32_int32_CPU(self):
    self._benchmark_tf_range_const()

  # int32 constant, int64 range, CPU
  def benchmark_tf_range_const_int32_int64_CPU(self):
    self._benchmark_tf_range_const(range_dtype=dtypes.int64)

  # int64 constant, int32 range, CPU
  def benchmark_tf_range_const_int64_int32_CPU(self):
    self._benchmark_tf_range_const(dtype=dtypes.int64)

  # int64 constant, int64 range, CPU
  def benchmark_tf_range_const_int64_int64_CPU(self):
    self._benchmark_tf_range_const(dtype=dtypes.int64, range_dtype=dtypes.int64)

  # int32 constant, int32 range, GPU
  def benchmark_tf_range_const_int32_int32_GPU(self):
    self._benchmark_tf_range_const(device=GPU)

  # int32 constant, int64 range, GPU
  def benchmark_tf_range_const_int32_int64_GPU(self):
    self._benchmark_tf_range_const(range_dtype=dtypes.int64, device=GPU)

  # int64 constant, int32 range, GPU
  def benchmark_tf_range_const_int64_int32_GPU(self):
    self._benchmark_tf_range_const(dtype=dtypes.int64, device=GPU)

  # int64 constant, int64 range, GPU
  def benchmark_tf_range_const_int64_int64_GPU(self):
    self._benchmark_tf_range_const(
        dtype=dtypes.int64, range_dtype=dtypes.int64, device=GPU)

  def _benchmark_tf_range_return(self,
                                 limit=100000,
                                 dtype=dtypes.int32,
                                 device=CPU,
                                 num_iters=100000):

    def func(lim):
      return math_ops.range(lim, dtype=dtype)

    compiled_func = def_function.function(func)

    with context.device(device):
      limit_t = constant_op.constant(limit, dtype=dtype)
      compiled_func(limit_t)
      self._run(lambda: compiled_func(limit_t), num_iters=num_iters)

  def benchmark_tf_range_return_int32_CPU(self):
    self._benchmark_tf_range_return()

  def benchmark_tf_range_return_int64_CPU(self):
    self._benchmark_tf_range_return(dtype=dtypes.int64)

  def benchmark_tf_range_return_int32_GPU(self):
    self._benchmark_tf_range_return(device=GPU)

  def benchmark_tf_range_return_int64_GPU(self):
    self._benchmark_tf_range_return(dtype=dtypes.int64, device=GPU)


if __name__ == "__main__":
  test.main()