# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.function(jit_compile=True)` used inside a graph context.

Each test builds the compiled function in an explicit `ops.Graph` and runs it
through a v1 `session.Session`, checking both the numeric results and that the
`_XlaMustCompile` attribute is attached to the generated FunctionDefs.
"""

from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


class JitCompileTest(test.TestCase):
  """Verifies jit_compile=True functions executed via graph-mode sessions."""

  def testBasic(self):
    """Compiled float32 add runs in a session and is marked must-compile."""
    with ops.Graph().as_default() as g:

      def fn(x, a):
        return x + a

      xla_func = def_function.function(fn, jit_compile=True)
      inputs = array_ops.placeholder(dtypes.float32, [5])
      x = xla_func(inputs, 1)
      with session.Session(graph=g) as sess:
        y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
        # The traced FunctionDef must carry the must-compile attribute; check
        # the boolean payload (.b), since the AttrValue proto itself is
        # always truthy.
        self.assertTrue(x.graph.as_graph_def().library.function[0]
                        .attr["_XlaMustCompile"].b)
        self.assertAllClose([2, 3, 3, 4, 4], y)

  def testDerivative(self):
    """Gradients of a compiled function are correct and also must-compile."""
    def fn(x, a):
      return 2 * x + a

    with ops.Graph().as_default() as g:
      xla_func = def_function.function(fn, jit_compile=True)
      with backprop.GradientTape() as tape:
        inputs = array_ops.placeholder(dtypes.float32, [5])
        tape.watch(inputs)
        outputs = xla_func(inputs, 1)
      grads = tape.gradient(outputs, inputs)

      with session.Session(graph=g) as sess:
        grads_tensor = sess.run(grads, feed_dict={inputs: [1, 2, 2, 3, 3]})
        # d/dx (2x + a) == 2 for every element.
        self.assertAllClose([2, 2, 2, 2, 2], grads_tensor)
        (forward, backward) = xla_func.get_concrete_function(
            inputs, 1)._delayed_rewrite_functions.forward_backward()

        # Check that the must-compile attribute gets correctly propagated to
        # the created derivatives.
        # Fixed: `attr[...]` yields an AttrValue proto, which is always
        # truthy in Python, so asserting on it directly was vacuous.  Assert
        # on the boolean field `.b`, matching testBasic/testBasicInt32.
        self.assertTrue(forward.definition.attr["_XlaMustCompile"].b)
        self.assertTrue(backward.function_def.attr["_XlaMustCompile"].b)

  def testBasicInt32(self):
    """Same as testBasic but with int32 inputs."""
    with ops.Graph().as_default() as g:

      def fn(x, a):
        return x + a

      xla_func = def_function.function(fn, jit_compile=True)
      inputs = array_ops.placeholder(dtypes.int32, [5])
      x = xla_func(inputs, 1)
      with session.Session(graph=g) as sess:
        y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
        self.assertTrue(x.graph.as_graph_def().library.function[0]
                        .attr["_XlaMustCompile"].b)
        self.assertAllClose([2, 3, 3, 4, 4], y)

  # Checking that we crash on an unsupported operation lets us test that the
  # XLA compiler was actually invoked.
  def testUnsupportedOps(self):
    """String ops are not XLA-compilable; compilation must fail loudly."""
    with ops.Graph().as_default() as g:

      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))

      xla_func = def_function.function(fn, jit_compile=True)
      inputs = array_ops.placeholder(dtypes.float32, [5])
      x = xla_func(inputs)
      # The error must surface at session run time, proving XLA saw the op.
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "Detected unsupported operations"):
        with session.Session(graph=g) as sess:
          sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})


if __name__ == "__main__":
  test.main()