xref: /aosp_15_r20/external/tensorflow/tensorflow/python/compiler/xla/jit_compile_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from tensorflow.python.client import session
17from tensorflow.python.eager import backprop
18from tensorflow.python.eager import def_function
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import errors
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import string_ops
24from tensorflow.python.platform import test
25
26
class JitCompileTest(test.TestCase):
  """Tests for `def_function.function(..., jit_compile=True)` in graph mode.

  Each test builds a `tf.Graph`, wraps a Python function with
  `jit_compile=True`, and evaluates it through a `session.Session`. The tests
  check numeric results, gradients, and that the `_XlaMustCompile` function
  attribute is set.
  """

  def _check_add_one(self, dtype):
    """Runs a jit-compiled `x + 1` on a length-5 placeholder and checks it.

    Args:
      dtype: The `dtypes` value used for the input placeholder.
    """
    with ops.Graph().as_default() as g:

      def fn(x, a):
        return x + a

      xla_func = def_function.function(fn, jit_compile=True)
      inputs = array_ops.placeholder(dtype, [5])
      x = xla_func(inputs, 1)
      with session.Session(graph=g) as sess:
        y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
        # Inspect the first function in the graph's function library
        # (presumably the traced `fn`) and require its `_XlaMustCompile`
        # flag to be true.
        self.assertTrue(x.graph.as_graph_def().library.function[0]
                        .attr["_XlaMustCompile"].b)
        self.assertAllClose([2, 3, 3, 4, 4], y)

  def testBasic(self):
    """A float32 jit-compiled add runs and is marked must-compile."""
    self._check_add_one(dtypes.float32)

  def testDerivative(self):
    """Gradients through a jit-compiled function keep `_XlaMustCompile`."""

    def fn(x, a):
      return 2 * x + a

    with ops.Graph().as_default() as g:
      xla_func = def_function.function(fn, jit_compile=True)
      with backprop.GradientTape() as tape:
        inputs = array_ops.placeholder(dtypes.float32, [5])
        tape.watch(inputs)
        outputs = xla_func(inputs, 1)
      grads = tape.gradient(outputs, inputs)

    with session.Session(graph=g) as sess:
      grads_tensor = sess.run(grads, feed_dict={inputs: [1, 2, 2, 3, 3]})
      # d(2 * x + a)/dx is 2 for every element.
      self.assertAllClose([2, 2, 2, 2, 2], grads_tensor)
      (forward, backward) = xla_func.get_concrete_function(
          inputs, 1)._delayed_rewrite_functions.forward_backward()

      # Check that the must-compile attribute gets correctly propagated to the
      # created derivatives. Read the boolean `.b` field explicitly: asserting
      # on the `AttrValue` message itself would pass regardless of the flag's
      # value, because proto messages are always truthy (and indexing a
      # message-valued map inserts a default entry for a missing key).
      self.assertTrue(forward.definition.attr["_XlaMustCompile"].b)
      self.assertTrue(backward.function_def.attr["_XlaMustCompile"].b)

  def testBasicInt32(self):
    """An int32 jit-compiled add runs and is marked must-compile."""
    self._check_add_one(dtypes.int32)

  def testUnsupportedOps(self):
    """Running an op XLA cannot compile under jit_compile=True must fail.

    Checking that we crash on an unsupported operation lets us test that the
    XLA compiler was actually invoked.
    """
    with ops.Graph().as_default() as g:

      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))

      xla_func = def_function.function(fn, jit_compile=True)
      inputs = array_ops.placeholder(dtypes.float32, [5])
      x = xla_func(inputs)
      with session.Session(graph=g) as sess:
        # Scope the expected failure to `sess.run` alone, so the assertion
        # pins the error to executing the compiled function.
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "Detected unsupported operations"):
          sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
97          sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
98
99
100if __name__ == "__main__":
101  test.main()
102