# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for code_template."""

import string
import unittest

from compiler.back_end.util import code_template


def _format_template_str(template: str, **kwargs) -> str:
    """Formats a raw template string via code_template.format_template.

    Args:
        template: Template text; it is wrapped in a `string.Template` first.
        **kwargs: Substitution values forwarded to `format_template`.

    Returns:
        Whatever `code_template.format_template` returns for the wrapped
        template and the given values.
    """
    wrapped = string.Template(template)
    return code_template.format_template(wrapped, **kwargs)


class FormatTest(unittest.TestCase):
    """Tests for code_template.format_template."""

    def test_no_replacement_fields(self):
        """Templates without substitutable fields are returned as expected."""
        # Each pair is (expected output, template text).  The last case checks
        # that a doubled `$$` comes out as a single literal `$`.
        cases = (
            ("foo", "foo"),
            ("{foo}", "{foo}"),
            ("${foo}", "$${foo}"),
        )
        for expected, template in cases:
            self.assertEqual(expected, _format_template_str(template))

    def test_one_replacement_field(self):
        """A single `${bar}` field is replaced wherever it appears."""
        # Each pair is (expected output, template text); `bar` is always "foo".
        cases = (
            ("foo", "${bar}"),
            ("bazfoo", "baz${bar}"),
            ("foobaz", "${bar}baz"),
            ("bazfooqux", "baz${bar}qux"),
        )
        for expected, template in cases:
            self.assertEqual(expected,
                             _format_template_str(template, bar="foo"))

    def test_one_replacement_field_with_formatting(self):
        """A format spec inside a field is rejected with ValueError."""
        # Basic string.Templates don't support formatting values.
        with self.assertRaises(ValueError):
            _format_template_str("${bar:.6f}", bar=1)

    def test_one_replacement_field_value_missing(self):
        """A field with no supplied value raises KeyError."""
        with self.assertRaises(KeyError):
            _format_template_str("${bar}")

    def test_multiple_replacement_fields(self):
        """Several distinct fields in one template are all replaced."""
        rendered = _format_template_str(" ${bar} ${baz} ",
                                        bar="aaa",
                                        baz="bbb")
        self.assertEqual(" aaa bbb ", rendered)


class ParseTemplatesTest(unittest.TestCase):
    """Tests for code_template.parse_templates."""

    def assertTemplatesEqual(self, expected, actual):  # pylint:disable=invalid-name
        """Asserts that a parse_templates result matches a plain dict.

        Args:
            expected: Dict mapping template name to the expected template text.
            actual: The object returned by `parse_templates`; it must provide
                `_asdict()`, and each value must expose a `.template`
                attribute.
        """
        # Reduce the result object to {name: template text} for comparison.
        text_by_name = {
            name: parsed.template
            for name, parsed in actual._asdict().items()
        }
        self.assertEqual(expected, text_by_name)

    def test_handles_no_template_case(self):
        """Input without any delimiter line yields no templates."""
        for text in ("", "this is not a template"):
            self.assertTemplatesEqual({}, code_template.parse_templates(text))

    def test_handles_one_template_at_start(self):
        """A delimiter on the very first line starts a template."""
        parsed = code_template.parse_templates("** foo **\nbar")
        self.assertTemplatesEqual({"foo": "bar"}, parsed)

    def test_handles_one_template_after_start(self):
        """Text before the first delimiter is not part of any template."""
        parsed = code_template.parse_templates("text\n** foo **\nbar")
        self.assertTemplatesEqual({"foo": "bar"}, parsed)

    def test_handles_delimiter_with_other_text(self):
        """Delimiters are recognized when wrapped in comment characters."""
        for text in ("text\n// ** foo ** ////\nbar",
                     "text\n# ** foo ** #####\nbar"):
            self.assertTemplatesEqual({"foo": "bar"},
                                      code_template.parse_templates(text))

    def test_handles_multiple_delimiters(self):
        """Each delimiter introduces its own named template."""
        parsed = code_template.parse_templates(
            "** foo **\nbar\n** baz **\nqux")
        self.assertTemplatesEqual({"foo": "bar", "baz": "qux"}, parsed)

    def test_returns_object_with_attributes(self):
        """Parsed templates are reachable as attributes of the result."""
        parsed = code_template.parse_templates(
            "** foo **\nbar\n** baz **\nqux")
        self.assertEqual("bar", parsed.foo.template)


if __name__ == "__main__":
    unittest.main()