#!/usr/bin/env python3

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from dataclasses import dataclass, field
import json
from pathlib import Path
import sys
from textwrap import dedent
from typing import List, Tuple, Union, Optional

from pdl import ast, core
from pdl.utils import indent, to_pascal_case


def get_cxx_scalar_type(width: int) -> str:
    """Return the cxx scalar type to be used to back a PDL type.

    Args:
        width: Bit width of the PDL scalar value.

    Returns:
        The smallest of `uint8_t`, `uint16_t`, `uint32_t`, `uint64_t` able to
        hold `width` bits.

    Raises:
        ValueError: if `width` is larger than 64 bits.
    """
    for n in (8, 16, 32, 64):
        if width <= n:
            return f'uint{n}_t'
    # PDL type does not fit on non-extended scalar types. Raise explicitly:
    # an `assert False` is stripped under `python -O`, which would make this
    # function silently return None and emit `None` into the generated C++.
    raise ValueError(f'PDL scalar width {width} does not fit in a 64-bit integer type')


def _hex_rows(values: Union[bytes, List[int]], per_row: int = 16) -> List[str]:
    """Format integers as rows of comma-terminated C++ hex literals.

    Each returned string holds at most `per_row` literals (`0x..,`) separated
    by single spaces, e.g. `[1, 255]` -> `['0x1, 0xff,']`.
    """
    return [' '.join(f'0x{v:x},' for v in values[i:i + per_row]) for i in range(0, len(values), per_row)]


def _values_per_row(width: int) -> int:
    """Return how many `width`-bit values make up one 16-byte literal row."""
    return max(1, (16 * 8) // width)


def _cxx_container(element_type: str, size: Optional[int]) -> str:
    """Return the C++ container type for a PDL array of `element_type`.

    Arrays with a (non-zero) static `size` map to `std::array`, all others to
    `std::vector`.
    """
    return f'std::array<{element_type}, {size}>' if size else f'std::vector<{element_type}>'


def _initializer_list(declaration: str, rows: List[str]) -> List[str]:
    """Return the lines of a brace-initialized C++ variable declaration.

    `declaration` is the text preceding the opening brace (type and name);
    every entry of `rows` becomes one indented line between the braces.
    """
    return [f'{declaration} {{'] + [f'    {row}' for row in rows] + ['};']


def generate_packet_parser_test(parser_test_suite: str, packet: ast.PacketDeclaration, tests: List[object]) -> str:
    """Generate the implementation of parser unit tests for the selected packet.

    One gtest case is emitted per test vector. It wraps the `packed` bytes in a
    `pdl::packet::slice`, creates the view of the packet named by the vector's
    optional `packet` key (defaulting to `packet`), asserts that the view is
    valid, then asserts the field values listed under `unpacked`.

    Args:
        parser_test_suite: Name of the gtest fixture the cases belong to.
        packet: Declaration the test vectors are grouped under.
        tests: Test vectors: dicts with `packed` (hex string), `unpacked`
            (field values) and an optional `packet` identifier.

    Returns:
        The generated C++ test cases, concatenated.
    """

    def parse_packet(decl: ast.PacketDeclaration) -> str:
        """Return the C++ expression creating the view of `decl` from `input`.

        Ancestor views are created first, so the result has the form
        `Child::Create(Parent::Create(input))`.
        """
        parent = parse_packet(decl.parent) if decl.parent else "input"
        return f"{decl.id}View::Create({parent})"

    def get_field(decl: ast.Declaration, var: str, field_id: str) -> str:
        """Return the C++ expression reading field `field_id` of `var`.

        Struct members are read directly (`var.id_`), packet view fields
        through their generated getter (`var.GetId()`).
        """
        if isinstance(decl, ast.StructDeclaration):
            return f"{var}.{field_id}_"
        return f"{var}.Get{to_pascal_case(field_id)}()"

    def check_members(decl: ast.Declaration, var: str, expected: dict) -> List[str]:
        """Return C++ statements asserting that the fields of `var` match `expected`.

        `expected` maps field identifiers of `decl` to test vector values.
        Nested structs and arrays of structs are checked recursively. Fields
        that cannot be resolved, or whose kind is not handled, are skipped.
        """
        checks = []
        for (field_id, value) in expected.items():
            field = core.get_packet_field(decl, field_id)
            getter = get_field(decl, var, field_id)
            # Array element accessors (`x[0]`) are not valid C++ identifiers.
            sanitized_var = var.replace('[', '_').replace(']', '')
            field_var = f'{sanitized_var}_{field_id}'

            if isinstance(field, ast.ScalarField) and field.cond:
                expected_value = f"std::make_optional({value})" if value is not None else "std::nullopt"
                checks.append(f"ASSERT_EQ({getter}, {expected_value});")

            elif isinstance(field, ast.ScalarField):
                checks.append(f"ASSERT_EQ({getter}, {value});")

            elif (isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration) and
                  field.cond):
                expected_value = (f"std::make_optional({field.type_id}({value}))"
                                  if value is not None else "std::nullopt")
                checks.append(f"ASSERT_EQ({getter}, {expected_value});")

            elif (isinstance(field, ast.TypedefField) and
                  isinstance(field.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration))):
                checks.append(f"ASSERT_EQ({getter}, {field.type_id}({value}));")

            elif isinstance(field, ast.TypedefField) and field.cond and value is None:
                checks.append(f"ASSERT_TRUE(!{getter}.has_value());")

            elif isinstance(field, ast.TypedefField) and field.cond:
                # Present optional struct: bind the contained value, then recurse.
                checks.append(f"{field.type_id} const& {field_var} = {getter}.value();")
                checks.extend(check_members(field.type, field_var, value))

            elif isinstance(field, ast.TypedefField):
                checks.append(f"{field.type_id} const& {field_var} = {getter};")
                checks.extend(check_members(field.type, field_var, value))

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                checks.extend(_initializer_list(f"std::vector<uint8_t> expected_{field_var}", _hex_rows(value)))
                checks.append(f"ASSERT_EQ({getter}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField) and field.width:
                container = _cxx_container(get_cxx_scalar_type(field.width), field.size)
                rows = _hex_rows(value, _values_per_row(field.width))
                checks.extend(_initializer_list(f"{container} expected_{field_var}", rows))
                checks.append(f"ASSERT_EQ({getter}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration):
                container = _cxx_container(field.type_id, field.size)
                rows = [f"{field.type_id}({v})," for v in value]
                checks.extend(_initializer_list(f"{container} expected_{field_var}", rows))
                checks.append(f"ASSERT_EQ({getter}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField):
                # Array of structs: check the length, then each element recursively.
                checks.append(f"{_cxx_container(field.type_id, field.size)} {field_var} = {getter};")
                checks.append(f"ASSERT_EQ({field_var}.size(), {len(value)});")
                for (n, element) in enumerate(value):
                    checks.extend(check_members(field.type, f"{field_var}[{n}]", element))

            # Any other case (unresolved identifier, unhandled field kind)
            # generates no check.

        return checks

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        # NOTE(review): fields are resolved against `packet`, not `child_packet`.
        # core.get_packet_field may not find fields declared only on
        # `child_packet`, in which case those values are not checked. Confirm
        # whether this is intended.
        checks = check_members(packet, 'packet', test['unpacked'])

        generated_tests.append(
            dedent("""\

            TEST_F({parser_test_suite}, {packet_id}_Case{test_nr}) {{
                pdl::packet::slice input(std::shared_ptr<std::vector<uint8_t>>(new std::vector<uint8_t> {{
                    {input_bytes}
                }}));
                {child_packet_id}View packet = {parse_packet};
                ASSERT_TRUE(packet.IsValid());
                {checks}
            }}
            """).format(parser_test_suite=parser_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        input_bytes=indent(_hex_rows(bytes.fromhex(test['packed'])), 2),
                        parse_packet=parse_packet(child_packet),
                        checks=indent(checks, 1)))

    return ''.join(generated_tests)


def generate_packet_serializer_test(serializer_test_suite: str, packet: ast.PacketDeclaration,
                                    tests: List[object]) -> str:
    """Generate the implementation of serializer unit tests for the selected packet.

    One gtest case is emitted per test vector. It constructs the packet named
    by the vector's optional `packet` key (defaulting to `packet`) through its
    `Builder`, using the `unpacked` values, and asserts that
    `SerializeToBytes()` equals the `packed` bytes.

    Args:
        serializer_test_suite: Name of the gtest fixture the cases belong to.
        packet: Declaration the test vectors are grouped under.
        tests: Test vectors: dicts with `packed` (hex string), `unpacked`
            (field values) and an optional `packet` identifier.

    Returns:
        The generated C++ test cases, concatenated.
    """

    def build_packet(decl: ast.Declaration, var: str, initializer: dict) -> Tuple[str, List[str]]:
        """Return a C++ constructor expression for `decl` built from `initializer`.

        Returns:
            `(expression, declarations)`: `declarations` are C++ statements that
            declare the intermediate variables (payload bytes, arrays) which
            `expression` moves from, and must be emitted before it.
        """
        fields = core.get_unconstrained_parent_fields(decl) + decl.fields
        declarations = []
        parameters = []
        # Array element names (`x[0]`) are not valid C++ identifiers.
        sanitized_var = var.replace('[', '_').replace(']', '')
        for field in fields:
            field_id = getattr(field, 'id', None)
            field_var = f'{sanitized_var}_{field_id}'
            if isinstance(field, (ast.PayloadField, ast.BodyField)):
                value = initializer['payload']
            else:
                value = initializer.get(field_id)

            if field.cond_for:
                # Condition selector fields are not passed to the constructor;
                # presumably derived from the optional field they control.
                pass

            elif field.cond and value is None:
                parameters.append("std::nullopt")

            elif isinstance(field, ast.ScalarField) and field.cond:
                parameters.append(f"std::make_optional({value})")

            elif isinstance(field, ast.ScalarField):
                parameters.append(f"{value}")

            elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration) and field.cond:
                parameters.append(f"std::make_optional({field.type_id}({value}))")

            elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
                parameters.append(f"{field.type_id}({value})")

            elif isinstance(field, ast.TypedefField):
                (element, nested_declarations) = build_packet(field.type, field_var, value)
                declarations.extend(nested_declarations)
                parameters.append(element)

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                declarations.extend(_initializer_list(f"std::vector<uint8_t> {field_var}", _hex_rows(value)))
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and field.width:
                container = _cxx_container(get_cxx_scalar_type(field.width), field.size)
                rows = _hex_rows(value, _values_per_row(field.width))
                declarations.extend(_initializer_list(f"{container} {field_var}", rows))
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration):
                container = _cxx_container(field.type_id, field.size)
                rows = [f"{field.type_id}({v})," for v in value]
                declarations.extend(_initializer_list(f"{container} {field_var}", rows))
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField):
                # Array of structs: build every element first. Their own
                # intermediate declarations must precede the array declaration.
                rows = []
                for (n, element_value) in enumerate(value):
                    (element, nested_declarations) = build_packet(field.type, f'{field_var}_{n}', element_value)
                    rows.append(f"{element},")
                    declarations.extend(nested_declarations)
                container = _cxx_container(field.type_id, field.size)
                declarations.extend(_initializer_list(f"{container} {field_var}", rows))
                parameters.append(f"std::move({field_var})")

            # Any other field kind contributes no constructor parameter.

        constructor_name = f'{decl.id}Builder' if isinstance(decl, ast.PacketDeclaration) else decl.id
        return (f"{constructor_name}({', '.join(parameters)})", declarations)

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        (built_packet, intermediate_declarations) = build_packet(child_packet, 'packet', test['unpacked'])
        generated_tests.append(
            dedent("""\

            TEST_F({serializer_test_suite}, {packet_id}_Case{test_nr}) {{
                std::vector<uint8_t> expected_output {{
                    {output_bytes}
                }};
                {intermediate_declarations}
                {child_packet_id}Builder packet = {built_packet};
                ASSERT_EQ(packet.SerializeToBytes(), expected_output);
            }}
            """).format(serializer_test_suite=serializer_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        output_bytes=indent(_hex_rows(bytes.fromhex(test['packed'])), 2),
                        built_packet=built_packet,
                        intermediate_declarations=indent(intermediate_declarations, 1)))

    return ''.join(generated_tests)
def run(input: argparse.FileType, output: argparse.FileType, test_vectors: argparse.FileType, include_header: List[str],
        using_namespace: List[str], namespace: str, parser_test_suite: str, serializer_test_suite: str) -> int:
    """Write C++ parser and serializer unit tests for a PDL file to `output`.

    Args:
        input: Open file holding the PDL-JSON source.
        output: Open file the generated C++ test source is written to.
        test_vectors: Open JSON file listing test vector groups; each group
            has a `packet` identifier and a list of `tests`.
        include_header: Headers emitted as `#include <...>` directives.
        using_namespace: Namespaces emitted as `using namespace ...;` statements.
        namespace: C++ namespace enclosing the generated test suites.
        parser_test_suite: Name of the generated parser test fixture.
        serializer_test_suite: Name of the generated serializer test fixture.

    Returns:
        0, used as the process exit status by `main`.
    """
    file = ast.File.from_json(json.load(input))
    tests = json.load(test_vectors)
    core.desugar(file)

    include_directives = '\n'.join(f'#include <{header}>' for header in include_header)
    using_directives = '\n'.join(f'using namespace {name};' for name in using_namespace)

    # Declarations for which no test case is generated.
    skipped_tests = {
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'PartialParent5',
        'PartialParent12',
        'Packet_Array_Field_VariableElementSize_ConstantSize',
        'Packet_Array_Field_VariableElementSize_VariableSize',
        'Packet_Array_Field_VariableElementSize_VariableCount',
        'Packet_Array_Field_VariableElementSize_UnknownSize',
    }

    output.write(
        dedent("""\
            // File generated from {input_name} and {test_vectors_name}, with the command:
            // {input_command}
            // /!\\ Do not edit by hand

            #include <cstdint>
            #include <string>
            #include <gtest/gtest.h>
            #include <packet_runtime.h>

            {include_header}
            {using_namespace}

            namespace {namespace} {{

            class {parser_test_suite} : public testing::Test {{}};
            class {serializer_test_suite} : public testing::Test {{}};
            """).format(parser_test_suite=parser_test_suite,
                        serializer_test_suite=serializer_test_suite,
                        input_name=input.name,
                        input_command=' '.join(sys.argv),
                        test_vectors_name=test_vectors.name,
                        include_header=include_directives,
                        using_namespace=using_directives,
                        namespace=namespace))

    # Group the test cases by packet identifier once, keeping file order, so
    # that each declaration does not rescan every test vector group.
    tests_by_packet = {}
    for group in tests:
        tests_by_packet.setdefault(group['packet'], []).extend(group['tests'])

    for decl in file.declarations:
        if decl.id in skipped_tests or not isinstance(decl, ast.PacketDeclaration):
            continue
        matching_tests = tests_by_packet.get(decl.id)
        if matching_tests:
            output.write(generate_packet_parser_test(parser_test_suite, decl, matching_tests))
            output.write(generate_packet_serializer_test(serializer_test_suite, decl, matching_tests))

    output.write(f"}} // namespace {namespace}\n")
    return 0


def main() -> int:
    """Generate C++ unit tests for the cxx PDL backend from PDL test vectors."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--test-vectors', type=argparse.FileType('r'), required=True, help='Input PDL test file')
    parser.add_argument('--namespace', type=str, default='pdl', help='Namespace of the generated file')
    parser.add_argument('--parser-test-suite', type=str, default='ParserTest', help='Name of the parser test suite')
    parser.add_argument('--serializer-test-suite',
                        type=str,
                        default='SerializerTest',
                        help='Name of the serializer test suite')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    return run(**vars(parser.parse_args()))


if __name__ == '__main__':
    sys.exit(main())