1#!/usr/bin/env python3
2
3# Copyright 2023 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import argparse
18from dataclasses import dataclass, field
19import json
20from pathlib import Path
21import sys
22from textwrap import dedent
23from typing import List, Tuple, Union, Optional
24
25from pdl import ast, core
26from pdl.utils import indent, to_pascal_case
27
28
def get_cxx_scalar_type(width: int) -> str:
    """Return the cxx scalar type to be used to back a PDL type.

    The smallest of `uint8_t`, `uint16_t`, `uint32_t` and `uint64_t` able to
    hold `width` bits is selected.

    Args:
        width: size in bits of the PDL type.

    Returns:
        The name of the C++ fixed-width unsigned integer type.

    Raises:
        ValueError: if `width` is larger than 64 bits, i.e. the PDL type
            does not fit on non-extended scalar types.
    """
    for n in (8, 16, 32, 64):
        if width <= n:
            return f'uint{n}_t'
    # An explicit raise is used rather than `assert False`: assertions are
    # removed under `python -O`, which would make this function silently
    # return None and leak "None" into the generated C++ code.
    raise ValueError(f'PDL type of width {width} does not fit on non-extended scalar types')
36
37
def generate_packet_parser_test(parser_test_suite: str, packet: ast.PacketDeclaration, tests: List[object]) -> str:
    """Generate the parser unit tests for the selected packet.

    One gtest `TEST_F` case is rendered for each entry of `tests`. A case
    builds the input slice from the entry's 'packed' hex string, parses it
    through the chain of views ending at the packet named by the entry's
    'packet' key (`packet` itself when the key is absent), asserts that the
    view is valid, and then compares the members listed under the entry's
    'unpacked' key with the parsed values.

    Returns the concatenated C++ source of the generated test cases, which
    is the empty string when `tests` is empty.
    """

    def parse_packet(decl: ast.PacketDeclaration) -> str:
        # Views are created from the outermost parent down to `decl`;
        # the root view is created from the `input` slice.
        if decl.parent:
            source = parse_packet(decl.parent)
        else:
            source = "input"
        return f"{decl.id}View::Create({source})"

    def input_bytes(hex_string: str) -> List[str]:
        # One output line per group of 16 input bytes.
        raw = bytes.fromhex(hex_string)
        return [' '.join(f'0x{b:x},' for b in raw[start:start + 16]) for start in range(0, len(raw), 16)]

    def hex_rows(values, per_row: int) -> List[str]:
        # Indented initializer rows holding `per_row` hexadecimal literals each.
        return [
            '    ' + ' '.join([f"0x{v:x}," for v in values[start:start + per_row]])
            for start in range(0, len(values), per_row)
        ]

    def get_field(decl: ast.Declaration, var: str, name: str) -> str:
        # Struct members are read from `<name>_` attributes, members of any
        # other declaration through `Get<Name>()` accessors.
        if not isinstance(decl, ast.StructDeclaration):
            return f"{var}.Get{to_pascal_case(name)}()"
        return f"{var}.{name}_"

    def check_members(decl: ast.Declaration, var: str, expected: object) -> List[str]:
        # Fields of any other kind (including names for which the lookup
        # does not return one of these) produce no check.
        checkable = (ast.ScalarField, ast.TypedefField, ast.PayloadField, ast.BodyField, ast.ArrayField)
        statements = []

        for (name, wanted) in expected.items():
            member = core.get_packet_field(decl, name)
            if not isinstance(member, checkable):
                continue

            # `var` may contain array subscripts, which are not valid in a
            # C++ identifier: they are flattened to build local names.
            local = var.replace('[', '_').replace(']', '') + f'_{name}'
            getter = get_field(decl, var, name)

            if isinstance(member, ast.ScalarField):
                if not member.cond:
                    literal = wanted
                elif wanted is None:
                    literal = "std::nullopt"
                else:
                    literal = f"std::make_optional({wanted})"
                statements.append(f"ASSERT_EQ({getter}, {literal});")

            elif isinstance(member, ast.TypedefField):
                if isinstance(member.type, ast.EnumDeclaration) and member.cond:
                    if wanted is None:
                        literal = "std::nullopt"
                    else:
                        literal = f"std::make_optional({member.type_id}({wanted}))"
                    statements.append(f"ASSERT_EQ({getter}, {literal});")

                elif isinstance(member.type,
                                (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration)):
                    statements.append(f"ASSERT_EQ({getter}, {member.type_id}({wanted}));")

                elif member.cond and wanted is None:
                    statements.append(f"ASSERT_TRUE(!{getter}.has_value());")

                else:
                    # Remaining typedefs are compared member by member; an
                    # optional value is unwrapped first.
                    source = f"{getter}.value()" if member.cond else getter
                    statements.append(f"{member.type_id} const& {local} = {source};")
                    statements.extend(check_members(member.type, local, wanted))

            elif isinstance(member, (ast.PayloadField, ast.BodyField)):
                statements.append(f"std::vector<uint8_t> expected_{local} {{")
                statements.extend(hex_rows(wanted, 16))
                statements.append("};")
                statements.append(f"ASSERT_EQ({getter}, expected_{local});")

            else:
                # ast.ArrayField: arrays with a size use std::array, the
                # others std::vector. Elements are scalars when the field
                # has a width, otherwise enum or typedef values.
                if member.width:
                    element_type = get_cxx_scalar_type(member.width)
                    rows = hex_rows(wanted, int(16 * 8 / member.width))
                elif isinstance(member.type, ast.EnumDeclaration):
                    element_type = member.type_id
                    rows = [f"    {member.type_id}({v})," for v in wanted]
                else:
                    element_type = member.type_id
                    rows = None

                if member.size:
                    container = f"std::array<{element_type}, {member.size}>"
                else:
                    container = f"std::vector<{element_type}>"

                if rows is not None:
                    statements.append(f"{container} expected_{local} {{")
                    statements.extend(rows)
                    statements.append("};")
                    statements.append(f"ASSERT_EQ({getter}, expected_{local});")
                else:
                    # Elements are checked one by one, member by member.
                    statements.append(f"{container} {local} = {getter};")
                    statements.append(f"ASSERT_EQ({local}.size(), {len(wanted)});")
                    for (index, element) in enumerate(wanted):
                        statements.extend(check_members(member.type, f"{local}[{index}]", element))

        return statements

    test_template = dedent("""\

        TEST_F({parser_test_suite}, {packet_id}_Case{test_nr}) {{
            pdl::packet::slice input(std::shared_ptr<std::vector<uint8_t>>(new std::vector<uint8_t> {{
                {input_bytes}
            }}));
            {child_packet_id}View packet = {parse_packet};
            ASSERT_TRUE(packet.IsValid());
            {checks}
        }}
        """)

    def render_test(test_nr: int, test) -> str:
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]
        # NOTE(review): members are looked up on `packet` while the parsed
        # view is the one of `child_packet` -- confirm this is intended.
        return test_template.format(parser_test_suite=parser_test_suite,
                                    packet_id=packet.id,
                                    child_packet_id=child_packet_id,
                                    test_nr=test_nr,
                                    input_bytes=indent(input_bytes(test['packed']), 2),
                                    parse_packet=parse_packet(child_packet),
                                    checks=indent(check_members(packet, 'packet', test['unpacked']), 1))

    return ''.join(render_test(test_nr, test) for (test_nr, test) in enumerate(tests))
172
173
def generate_packet_serializer_test(serializer_test_suite: str, packet: ast.PacketDeclaration,
                                    tests: List[object]) -> str:
    """Generate the serializer unit tests for the selected packet.

    One gtest `TEST_F` case is rendered for each entry of `tests`. A case
    declares the expected bytes from the entry's 'packed' hex string,
    constructs the builder of the packet named by the entry's 'packet' key
    (`packet` itself when the key is absent) from the values found under the
    entry's 'unpacked' key, and asserts that `SerializeToBytes()` returns the
    expected bytes.

    Returns the concatenated C++ source of the generated test cases, which
    is the empty string when `tests` is empty.
    """

    def hex_rows(values, per_row: int) -> List[str]:
        # Indented initializer rows holding `per_row` hexadecimal literals each.
        return [
            '    ' + ' '.join([f"0x{v:x}," for v in values[start:start + per_row]])
            for start in range(0, len(values), per_row)
        ]

    def build_packet(decl: ast.Declaration, var: str, initializer: object) -> Tuple[str, List[str]]:
        # Returns the C++ constructor expression for `decl` along with the
        # variable declarations that must precede it.
        declarations = []
        arguments = []

        for member in core.get_unconstrained_parent_fields(decl) + decl.fields:
            name = getattr(member, 'id', None)
            # `var` may contain array subscripts, which are not valid in a
            # C++ identifier: they are flattened to build local names.
            local = var.replace('[', '_').replace(']', '') + f'_{name}'
            # Payload and body values are both stored under the 'payload' key.
            if isinstance(member, (ast.PayloadField, ast.BodyField)):
                value = initializer['payload']
            else:
                value = initializer.get(name, None)

            if member.cond_for:
                # No constructor argument is generated for these fields.
                continue

            if member.cond and value is None:
                arguments.append("std::nullopt")

            elif isinstance(member, ast.ScalarField):
                arguments.append(f"std::make_optional({value})" if member.cond else f"{value}")

            elif isinstance(member, ast.TypedefField):
                if isinstance(member.type, ast.EnumDeclaration):
                    enum_value = f"{member.type_id}({value})"
                    arguments.append(f"std::make_optional({enum_value})" if member.cond else enum_value)
                else:
                    (nested, nested_declarations) = build_packet(member.type, local, value)
                    declarations.extend(nested_declarations)
                    arguments.append(nested)

            elif isinstance(member, (ast.PayloadField, ast.BodyField)):
                declarations.append(f"std::vector<uint8_t> {local} {{")
                declarations.extend(hex_rows(value, 16))
                declarations.append("};")
                arguments.append(f"std::move({local})")

            elif isinstance(member, ast.ArrayField):
                # Elements are scalars when the field has a width, otherwise
                # enum values or recursively constructed typedef values.
                if member.width:
                    element_type = get_cxx_scalar_type(member.width)
                    rows = hex_rows(value, int(16 * 8 / member.width))
                elif isinstance(member.type, ast.EnumDeclaration):
                    element_type = member.type_id
                    rows = [f"    {member.type_id}({v})," for v in value]
                else:
                    element_type = member.type_id
                    rows = []
                    for (index, element_initializer) in enumerate(value):
                        (element, element_declarations) = build_packet(member.type, f'{local}_{index}',
                                                                       element_initializer)
                        # Declarations needed by an element precede the
                        # declaration of the array itself.
                        declarations.extend(element_declarations)
                        rows.append(f"    {element},")

                # Arrays with a size use std::array, the others std::vector.
                if member.size:
                    container = f"std::array<{element_type}, {member.size}>"
                else:
                    container = f"std::vector<{element_type}>"

                declarations.append(f"{container} {local} {{")
                declarations.extend(rows)
                declarations.append("};")
                arguments.append(f"std::move({local})")

        # Packets are constructed through `<Id>Builder`, other declarations
        # directly by their identifier.
        constructor = f'{decl.id}Builder' if isinstance(decl, ast.PacketDeclaration) else decl.id
        return (f"{constructor}({', '.join(arguments)})", declarations)

    def output_bytes(hex_string: str) -> List[str]:
        # One output line per group of 16 expected bytes.
        raw = bytes.fromhex(hex_string)
        return [' '.join(f'0x{b:x},' for b in raw[start:start + 16]) for start in range(0, len(raw), 16)]

    test_template = dedent("""\

        TEST_F({serializer_test_suite}, {packet_id}_Case{test_nr}) {{
            std::vector<uint8_t> expected_output {{
                {output_bytes}
            }};
            {intermediate_declarations}
            {child_packet_id}Builder packet = {built_packet};
            ASSERT_EQ(packet.SerializeToBytes(), expected_output);
        }}
        """)

    def render_test(test_nr: int, test) -> str:
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]
        (built_packet, intermediate_declarations) = build_packet(child_packet, 'packet', test['unpacked'])
        return test_template.format(serializer_test_suite=serializer_test_suite,
                                    packet_id=packet.id,
                                    child_packet_id=child_packet_id,
                                    test_nr=test_nr,
                                    output_bytes=indent(output_bytes(test['packed']), 2),
                                    built_packet=built_packet,
                                    intermediate_declarations=indent(intermediate_declarations, 1))

    return ''.join(render_test(test_nr, test) for (test_nr, test) in enumerate(tests))
312
313
def run(input: argparse.FileType, output: argparse.FileType, test_vectors: argparse.FileType, include_header: List[str],
        using_namespace: List[str], namespace: str, parser_test_suite: str, serializer_test_suite: str):
    """Write the C++ parser and serializer tests for a PDL-JSON file.

    Args:
        input: readable file holding the PDL-JSON source.
        output: writable file receiving the generated C++ code.
        test_vectors: readable file holding the JSON test vectors, a list
            of objects with a 'packet' identifier and a 'tests' list.
        include_header: headers added as `#include <...>` directives.
        using_namespace: namespaces added as `using namespace` statements.
        namespace: namespace enclosing the generated tests.
        parser_test_suite: name of the parser test fixture class.
        serializer_test_suite: name of the serializer test fixture class.
    """

    file = ast.File.from_json(json.load(input))
    tests = json.load(test_vectors)
    core.desugar(file)

    include_directives = '\n'.join(f'#include <{header}>' for header in include_header)
    using_directives = '\n'.join(f'using namespace {used};' for used in using_namespace)

    # Declarations for which no test is generated.
    skipped_tests = (
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'PartialParent5',
        'PartialParent12',
        'Packet_Array_Field_VariableElementSize_ConstantSize',
        'Packet_Array_Field_VariableElementSize_VariableSize',
        'Packet_Array_Field_VariableElementSize_VariableCount',
        'Packet_Array_Field_VariableElementSize_UnknownSize',
    )

    # The template is dedented before substitution, so that multi-line
    # values do not take part in the margin computation.
    header_template = dedent("""\
        // File generated from {input_name} and {test_vectors_name}, with the command:
        //  {input_command}
        // /!\\ Do not edit by hand

        #include <cstdint>
        #include <string>
        #include <gtest/gtest.h>
        #include <packet_runtime.h>

        {include_header}
        {using_namespace}

        namespace {namespace} {{

        class {parser_test_suite} : public testing::Test {{}};
        class {serializer_test_suite} : public testing::Test {{}};
        """)
    output.write(
        header_template.format(parser_test_suite=parser_test_suite,
                               serializer_test_suite=serializer_test_suite,
                               input_name=input.name,
                               input_command=' '.join(sys.argv),
                               test_vectors_name=test_vectors.name,
                               include_header=include_directives,
                               using_namespace=using_directives,
                               namespace=namespace))

    for decl in file.declarations:
        if decl.id in skipped_tests or not isinstance(decl, ast.PacketDeclaration):
            continue

        # Gather the test cases of every test vector entry naming this packet.
        matching_tests = [case for entry in tests if entry['packet'] == decl.id for case in entry['tests']]
        if not matching_tests:
            continue

        output.write(generate_packet_parser_test(parser_test_suite, decl, matching_tests))
        output.write(generate_packet_serializer_test(serializer_test_suite, decl, matching_tests))

    output.write(f"}}  // namespace {namespace}\n")
376
377
378def main() -> int:
379    """Generate cxx PDL backend."""
380    parser = argparse.ArgumentParser(description=__doc__)
381    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
382    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
383    parser.add_argument('--test-vectors', type=argparse.FileType('r'), required=True, help='Input PDL test file')
384    parser.add_argument('--namespace', type=str, default='pdl', help='Namespace of the generated file')
385    parser.add_argument('--parser-test-suite', type=str, default='ParserTest', help='Name of the parser test suite')
386    parser.add_argument('--serializer-test-suite',
387                        type=str,
388                        default='SerializerTest',
389                        help='Name of the serializer test suite')
390    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
391    parser.add_argument('--using-namespace',
392                        type=str,
393                        default=[],
394                        action='append',
395                        help='Added using namespace statements')
396    return run(**vars(parser.parse_args()))
397
398
if __name__ == '__main__':
    # The value returned by main() is used as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
401