# xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/tests/BUILD (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
2load("//tensorflow:tensorflow.bzl", "tf_cc_test")
3
# License type declared for the targets in this package.
licenses(["notice"])

# Targets in this package are private by default; a target that must be
# visible elsewhere sets its own `visibility` attribute.
package(default_visibility = ["//visibility:private"])
7
# Test-only helper library built from auto_clustering_test_helper.{h,cc}.
cc_library(
    name = "auto_clustering_test_helper",
    # Overrides this package's private default so any package may depend on it.
    visibility = ["//visibility:public"],
    testonly = True,
    hdrs = ["auto_clustering_test_helper.h"],
    srcs = ["auto_clustering_test_helper.cc"],
    # NOTE: entries are kept in their original order; cc link order follows it.
    deps = [
        # //tensorflow/compiler/jit targets.
        "//tensorflow/compiler/jit:compilation_passes",
        "//tensorflow/compiler/jit:jit_compilation_passes",
        "//tensorflow/compiler/jit:xla_cluster_util",
        "//tensorflow/compiler/jit:xla_cpu_jit",
        "//tensorflow/compiler/jit:xla_gpu_jit",
        # //tensorflow/compiler/xla utility targets.
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        # //tensorflow/core libraries, including test support.
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        # Optimization pass runner and string utilities.
        "//tensorflow/tools/optimization:optimization_pass_runner_lib",
        "@com_google_absl//absl/strings",
    ],
)
33
# Test built from auto_clustering_test.cc. Each input graph (.pbtxt or
# .pbtxt.gz) ships alongside a .golden_summary file of the same base name
# (presumably the expected clustering summary; confirm in the .cc).
tf_cc_test(
    name = "auto_clustering_test",
    tags = ["config-cuda-only"],
    srcs = ["auto_clustering_test.cc"],
    data = [
        "keras_imagenet_main.golden_summary",
        "keras_imagenet_main.pbtxt",
    ] + [
        "keras_imagenet_main_graph_mode.golden_summary",
        "keras_imagenet_main_graph_mode.pbtxt",
    ] + [
        "opens2s_gnmt_mixed_precision.golden_summary",
        "opens2s_gnmt_mixed_precision.pbtxt.gz",
    ],
    deps = [
        ":auto_clustering_test_helper",
        "//tensorflow/core:test",
        "@com_google_absl//absl/strings",
    ],
)
52
# Test-only helper library built from xla_compilation_cache_test_helper.{h,cc}.
# No `visibility` is set, so it inherits this package's private default.
cc_library(
    name = "xla_compilation_cache_test_helper",
    testonly = True,
    hdrs = ["xla_compilation_cache_test_helper.h"],
    srcs = ["xla_compilation_cache_test_helper.cc"],
    # NOTE: entries are kept in their original order; cc link order follows it.
    deps = [
        # //tensorflow/compiler/jit targets.
        "//tensorflow/compiler/jit:xla_activity_listener",
        "//tensorflow/compiler/jit:xla_compilation_cache_proto_cc",
        "//tensorflow/compiler/jit:xla_cpu_jit",
        "//tensorflow/compiler/jit:xla_gpu_device",
        "//tensorflow/compiler/jit:xla_gpu_jit",
        # XLA HLO proto.
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        # //tensorflow/core libraries, including test support.
        "//tensorflow/core:all_kernels",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:direct_session",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:ops",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        # Path and string utilities.
        "//tensorflow/core/platform:path",
        "@com_google_absl//absl/strings",
    ],
)
79
# XLA compilation cache serialization tests. Both targets share the same
# environment, tags and deps; each is named after its single source file
# (<name>.cc). Targets are declared in the order listed below.
[
    tf_cc_test(
        name = test_name,
        srcs = [test_name + ".cc"],
        env = {
            "XLA_FLAGS": "--xla_gpu_jitrt_executable",
        },
        tags = [
            "config-cuda-only",
            "no_oss",  # Excluded from OSS builds: this test only runs with GPU.
            "requires-gpu-nvidia",
            "xla",
        ],
        deps = [
            ":xla_compilation_cache_test_helper",
            "//tensorflow/compiler/jit:compilation_passes",
            "//tensorflow/compiler/jit:flags",
            "//tensorflow/core:test",
        ],
    )
    for test_name in [
        "xla_compilation_cache_serialize_test",
        "xla_compilation_cache_serialize_options_test",
    ]
]
122)
123