# xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/benchmarks/keras_examples_benchmarks/BUILD (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Description:
#   Bazel targets for the Keras example-model benchmarks.

load("//tensorflow:tensorflow.bzl", "cuda_py_test")

# Targets in this package are publicly visible unless a target overrides it.
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# Every Python source file in this package, exposed only to the
# private_tf_api_test package named in `visibility`.
filegroup(
    name = "all_py_srcs",
    srcs = glob(["*.py"]),
    visibility = ["//tensorflow/python/keras/google/private_tf_api_test:__pkg__"],
)

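# To run CPU benchmarks (`benchmarks_test` is a placeholder: substitute one of
# the *_benchmark_test targets defined below, e.g. bidirectional_lstm_benchmark_test):
#   bazel run -c opt benchmarks_test -- --benchmarks=.
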
# To run GPU benchmarks (again substituting a *_benchmark_test target for
# `benchmarks_test`):
#   bazel run --config=cuda -c opt --copt="-mavx" benchmarks_test -- \
#     --benchmarks=.

# To run only a subset of benchmarks, use the --benchmarks flag.
# --benchmarks: the benchmarks to run. The value is interpreted as a regular
# expression, and every benchmark whose name contains a partial match to that
# regular expression is executed.
# e.g. --benchmarks=".*lstm.*" runs all LSTM-layer-related benchmarks.

# Tags applied to every benchmark test in this package; each tag is annotated
# with its associated bug ID.
COMMON_TAGS = [
    "no_pip",  # b/161253163
    "no_windows",  # b/160628318
]

# Benchmarks defined in bidirectional_lstm_benchmark_test.py.
cuda_py_test(
    name = "bidirectional_lstm_benchmark_test",
    srcs = ["bidirectional_lstm_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "//tensorflow/python/keras/benchmarks:benchmark_util",
        "//tensorflow/python/keras/benchmarks:profiler_lib",
    ],
)

# Benchmarks defined in text_classification_transformer_benchmark_test.py.
cuda_py_test(
    name = "text_classification_transformer_benchmark_test",
    srcs = ["text_classification_transformer_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "//tensorflow/python/keras/benchmarks:benchmark_util",
        "//tensorflow/python/keras/benchmarks:profiler_lib",
    ],
)

# Benchmarks defined in antirectifier_benchmark_test.py.
cuda_py_test(
    name = "antirectifier_benchmark_test",
    srcs = ["antirectifier_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "//tensorflow/python/keras/benchmarks:benchmark_util",
        "//tensorflow/python/keras/benchmarks:profiler_lib",
    ],
)

# Benchmarks defined in mnist_conv_benchmark_test.py; unlike most targets here,
# it also depends on numpy.
cuda_py_test(
    name = "mnist_conv_benchmark_test",
    srcs = ["mnist_conv_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "//tensorflow/python/keras/benchmarks:benchmark_util",
        "//tensorflow/python/keras/benchmarks:profiler_lib",
        "//third_party/py/numpy",
    ],
)

# Benchmarks defined in mnist_hierarchical_rnn_benchmark_test.py.
cuda_py_test(
    name = "mnist_hierarchical_rnn_benchmark_test",
    srcs = ["mnist_hierarchical_rnn_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "//tensorflow/python/keras/benchmarks:benchmark_util",
        "//tensorflow/python/keras/benchmarks:profiler_lib",
    ],
)

# Benchmarks defined in mnist_irnn_benchmark_test.py.
cuda_py_test(
    name = "mnist_irnn_benchmark_test",
    srcs = ["mnist_irnn_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "//tensorflow/python/keras/benchmarks:benchmark_util",
        "//tensorflow/python/keras/benchmarks:profiler_lib",
    ],
)

# Benchmarks defined in reuters_mlp_benchmark_test.py; unlike most targets
# here, it also depends on numpy.
cuda_py_test(
    name = "reuters_mlp_benchmark_test",
    srcs = ["reuters_mlp_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "//tensorflow/python/keras/benchmarks:benchmark_util",
        "//tensorflow/python/keras/benchmarks:profiler_lib",
        "//third_party/py/numpy",
    ],
)

# Benchmarks defined in cifar10_cnn_benchmark_test.py.
cuda_py_test(
    name = "cifar10_cnn_benchmark_test",
    srcs = ["cifar10_cnn_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "//tensorflow/python/keras/benchmarks:benchmark_util",
        "//tensorflow/python/keras/benchmarks:profiler_lib",
    ],
)

# Benchmarks defined in mnist_conv_custom_training_benchmark_test.py; unlike the
# other targets here, it also depends on distribution_util.
cuda_py_test(
    name = "mnist_conv_custom_training_benchmark_test",
    srcs = ["mnist_conv_custom_training_benchmark_test.py"],
    python_version = "PY3",
    tags = COMMON_TAGS,
    deps = [
        "//tensorflow:tensorflow_py_no_contrib",
        "//tensorflow/python/keras/benchmarks:benchmark_util",
        "//tensorflow/python/keras/benchmarks:distribution_util",
        "//tensorflow/python/keras/benchmarks:profiler_lib",
    ],
)
