# Description:
#   Implementation of Keras benchmarks.

load("//tensorflow:tensorflow.bzl", "cuda_py_test")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

filegroup(
    name = "all_py_srcs",
    srcs = glob(["*.py"]),
    visibility = ["//tensorflow/python/keras/google/private_tf_api_test:__pkg__"],
)

# Each benchmark below is its own test target; substitute the one you want for
# <benchmark_test_target>, e.g. :antirectifier_benchmark_test.
#
# To run CPU benchmarks:
#   bazel run -c opt :<benchmark_test_target> -- --benchmarks=.
#
# To run GPU benchmarks:
#   bazel run --config=cuda -c opt --copt="-mavx" :<benchmark_test_target> -- \
#     --benchmarks=.
#
# To run only a subset of benchmarks, use the --benchmarks flag.
# --benchmarks: the benchmarks to run. The value is interpreted as a regular
# expression, and every benchmark whose name contains a partial match to it is
# executed. For example, --benchmarks=".*lstm.*" runs all lstm layer related
# benchmarks.

# Tags applied to every benchmark test in this package.
COMMON_TAGS = [
    "no_pip",  # b/161253163
    "no_windows",  # b/160628318
]

# Dependencies shared by every benchmark test in this package. Targets that
# need additional deps build a new list with `COMMON_DEPS + [...]`.
COMMON_DEPS = [
    "//tensorflow:tensorflow_py_no_contrib",
    "//tensorflow/python/keras/benchmarks:benchmark_util",
    "//tensorflow/python/keras/benchmarks:profiler_lib",
]

cuda_py_test(
    name = "bidirectional_lstm_benchmark_test",
    srcs = ["bidirectional_lstm_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = COMMON_DEPS,
)

cuda_py_test(
    name = "text_classification_transformer_benchmark_test",
    srcs = ["text_classification_transformer_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = COMMON_DEPS,
)

cuda_py_test(
    name = "antirectifier_benchmark_test",
    srcs = ["antirectifier_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = COMMON_DEPS,
)

cuda_py_test(
    name = "mnist_conv_benchmark_test",
    srcs = ["mnist_conv_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = COMMON_DEPS + [
        "//third_party/py/numpy",
    ],
)

cuda_py_test(
    name = "mnist_hierarchical_rnn_benchmark_test",
    srcs = ["mnist_hierarchical_rnn_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = COMMON_DEPS,
)

cuda_py_test(
    name = "mnist_irnn_benchmark_test",
    srcs = ["mnist_irnn_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = COMMON_DEPS,
)

cuda_py_test(
    name = "reuters_mlp_benchmark_test",
    srcs = ["reuters_mlp_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = COMMON_DEPS + [
        "//third_party/py/numpy",
    ],
)

cuda_py_test(
    name = "cifar10_cnn_benchmark_test",
    srcs = ["cifar10_cnn_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = COMMON_DEPS,
)

cuda_py_test(
    name = "mnist_conv_custom_training_benchmark_test",
    srcs = ["mnist_conv_custom_training_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = COMMON_DEPS + [
        "//tensorflow/python/keras/benchmarks:distribution_util",
    ],
)