xref: /aosp_15_r20/external/emboss/compiler/back_end/util/code_template_test.py (revision 99e0aae7469b87d12f0ad23e61142c2d74c1ef70)
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tests for code_template."""
16
17import string
18import unittest
19from compiler.back_end.util import code_template
20
def _format_template_str(template: str, **kwargs) -> str:
  """Formats a template that is given as a plain string.

  Wraps `template` in a `string.Template` and passes it, together with the
  substitution values, to `code_template.format_template`.

  Args:
    template: Template source text using `$name` / `${name}` placeholders.
    **kwargs: Substitution values forwarded unchanged to `format_template`.

  Returns:
    The result of `code_template.format_template` for the wrapped template.
  """
  compiled = string.Template(template)
  return code_template.format_template(compiled, **kwargs)
23
class FormatTest(unittest.TestCase):
  """Tests for code_template.format_template."""

  def test_no_replacement_fields(self):
    self.assertEqual("foo", _format_template_str("foo"))
    # Braces without a leading "$" are not placeholders for string.Template.
    self.assertEqual("{foo}", _format_template_str("{foo}"))
    # "$$" is string.Template's escape for a literal "$".
    self.assertEqual("${foo}", _format_template_str("$${foo}"))

  def test_one_replacement_field(self):
    self.assertEqual("foo", _format_template_str("${bar}", bar="foo"))
    self.assertEqual("bazfoo", _format_template_str("baz${bar}", bar="foo"))
    self.assertEqual("foobaz", _format_template_str("${bar}baz", bar="foo"))
    self.assertEqual("bazfooqux",
                     _format_template_str("baz${bar}qux", bar="foo"))

  def test_one_replacement_field_with_formatting(self):
    # Basic string.Templates don't support format specs such as ":.6f", so
    # the placeholder is rejected instead of being formatted.
    with self.assertRaises(ValueError):
      _format_template_str("${bar:.6f}", bar=1)

  def test_one_replacement_field_value_missing(self):
    # A placeholder with no matching keyword argument is an error.
    with self.assertRaises(KeyError):
      _format_template_str("${bar}")

  def test_multiple_replacement_fields(self):
    # Whitespace around and between placeholders must be preserved verbatim.
    self.assertEqual(" aaa  bbb   ",
                     _format_template_str(" ${bar}  ${baz}   ",
                                          bar="aaa",
                                          baz="bbb"))
54
55
class ParseTemplatesTest(unittest.TestCase):
  """Tests for code_template.parse_templates."""

  def assertTemplatesEqual(self, expected, actual):  # pylint:disable=invalid-name
    """Asserts that a parse_templates result matches an expected mapping.

    Args:
      expected: Dict mapping each template name to its expected template text.
      actual: The value returned by code_template.parse_templates. It must
        provide `_asdict()` (e.g. a namedtuple), and each of its values must
        expose the template text via a `template` attribute.
    """
    # Reduce each parsed entry to its raw template text so the result can be
    # compared against a plain dict.
    actual_templates = {
        name: parsed.template for name, parsed in actual._asdict().items()
    }
    self.assertEqual(expected, actual_templates)

  def test_handles_no_template_case(self):
    self.assertTemplatesEqual({}, code_template.parse_templates(""))
    self.assertTemplatesEqual(
        {}, code_template.parse_templates("this is not a template"))

  def test_handles_one_template_at_start(self):
    self.assertTemplatesEqual(
        {"foo": "bar"},
        code_template.parse_templates("** foo **\nbar"))

  def test_handles_one_template_after_start(self):
    # Text before the first delimiter line is not part of any template.
    self.assertTemplatesEqual(
        {"foo": "bar"},
        code_template.parse_templates("text\n** foo **\nbar"))

  def test_handles_delimiter_with_other_text(self):
    # Delimiter lines may carry extra text (such as comment markers) around
    # the "** name **" marker.
    self.assertTemplatesEqual(
        {"foo": "bar"},
        code_template.parse_templates("text\n// ** foo ** ////\nbar"))
    self.assertTemplatesEqual(
        {"foo": "bar"},
        code_template.parse_templates("text\n# ** foo ** #####\nbar"))

  def test_handles_multiple_delimiters(self):
    self.assertTemplatesEqual(
        {"foo": "bar", "baz": "qux"},
        code_template.parse_templates("** foo **\nbar\n** baz **\nqux"))

  def test_returns_object_with_attributes(self):
    # Each template is also reachable as an attribute named after it.
    templates = code_template.parse_templates("** foo **\nbar\n** baz **\nqux")
    self.assertEqual("bar", templates.foo.template)
    self.assertEqual("qux", templates.baz.template)
96        "** foo **\nbar\n** baz **\nqux").foo.template)
97
if __name__ == "__main__":
  # When executed as a script, hand control to unittest's command-line
  # runner, which loads and runs the test cases defined in this module.
  unittest.main()
100