xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/signal/util_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Utility ops shared across tf.signal."""
16
17import fractions  # gcd is here for Python versions < 3
18import math  # Get gcd here for Python versions >= 3
19import sys
20
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.ops import math_ops
26
27
def gcd(a, b, name=None):
  """Returns the greatest common divisor via Euclid's algorithm.

  If both `a` and `b` have values that are known statically (via
  `tensor_util.constant_value`), the GCD is computed in Python and returned as
  a constant `Tensor`. Otherwise a `while_loop` running Euclid's algorithm is
  added to the graph.

  Args:
    a: The dividend. A scalar integer `Tensor`.
    b: The divisor. A scalar integer `Tensor`.
    name: An optional name for the operation.

  Returns:
    A scalar `Tensor` representing the greatest common divisor between `a` and
    `b`. On both the static and the dynamic path its dtype is `a.dtype`.

  Raises:
    ValueError: If `a` or `b` are not scalar integers.
  """
  with ops.name_scope(name, 'gcd', [a, b]):
    a = ops.convert_to_tensor(a)
    b = ops.convert_to_tensor(b)

    a.shape.assert_has_rank(0)
    b.shape.assert_has_rank(0)

    if not a.dtype.is_integer:
      raise ValueError('a must be an integer type. Got: %s' % a.dtype)
    if not b.dtype.is_integer:
      raise ValueError('b must be an integer type. Got: %s' % b.dtype)

    # TPU requires static shape inference. GCD is used for subframe size
    # computation, so we should prefer static computation where possible.
    const_a = tensor_util.constant_value(a)
    const_b = tensor_util.constant_value(b)
    if const_a is not None and const_b is not None:
      if sys.version_info.major < 3:
        math_gcd = fractions.gcd
      else:
        math_gcd = math.gcd
      # Pin the dtype to `a.dtype` so this constant path yields the same dtype
      # as the while_loop path below. Without it, the Python int result is
      # always converted to int32, even when the inputs are int64.
      return ops.convert_to_tensor(
          math_gcd(int(const_a), int(const_b)), dtype=a.dtype)

    # NOTE(review): the loop stops as soon as `b <= 0`, so it assumes
    # non-negative inputs. `math.gcd` above always returns a non-negative
    # value, so the two paths may disagree for negative inputs.
    cond = lambda _, b: math_ops.greater(b, array_ops.zeros_like(b))
    body = lambda a, b: [b, math_ops.mod(a, b)]
    a, b = control_flow_ops.while_loop(cond, body, [a, b], back_prop=False)
    return a
70