1#! /usr/bin/env python 2"""Generate C code from an ASDL description.""" 3 4import os 5import sys 6import textwrap 7import types 8 9from argparse import ArgumentParser 10from contextlib import contextmanager 11from pathlib import Path 12 13import asdl 14 15TABSIZE = 4 16MAX_COL = 80 17AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n" 18 19def get_c_type(name): 20 """Return a string for the C name of the type. 21 22 This function special cases the default types provided by asdl. 23 """ 24 if name in asdl.builtin_types: 25 return name 26 else: 27 return "%s_ty" % name 28 29def reflow_lines(s, depth): 30 """Reflow the line s indented depth tabs. 31 32 Return a sequence of lines where no line extends beyond MAX_COL 33 when properly indented. The first line is properly indented based 34 exclusively on depth * TABSIZE. All following lines -- these are 35 the reflowed lines generated by this function -- start at the same 36 column as the first character beyond the opening { in the first 37 line. 38 """ 39 size = MAX_COL - depth * TABSIZE 40 if len(s) < size: 41 return [s] 42 43 lines = [] 44 cur = s 45 padding = "" 46 while len(cur) > size: 47 i = cur.rfind(' ', 0, size) 48 # XXX this should be fixed for real 49 if i == -1 and 'GeneratorExp' in cur: 50 i = size + 3 51 assert i != -1, "Impossible line %d to reflow: %r" % (size, s) 52 lines.append(padding + cur[:i]) 53 if len(lines) == 1: 54 # find new size based on brace 55 j = cur.find('{', 0, i) 56 if j >= 0: 57 j += 2 # account for the brace and the space after it 58 size -= j 59 padding = " " * j 60 else: 61 j = cur.find('(', 0, i) 62 if j >= 0: 63 j += 1 # account for the paren (no space after it) 64 size -= j 65 padding = " " * j 66 cur = cur[i+1:] 67 else: 68 lines.append(padding + cur) 69 return lines 70 71def reflow_c_string(s, depth): 72 return '"%s"' % s.replace('\n', '\\n"\n%s"' % (' ' * depth * TABSIZE)) 73 74def is_simple(sum_type): 75 """Return True if a sum is a simple. 
76 77 A sum is simple if it's types have no fields and itself 78 doesn't have any attributes. Instances of these types are 79 cached at C level, and they act like singletons when propagating 80 parser generated nodes into Python level, e.g. 81 unaryop = Invert | Not | UAdd | USub 82 """ 83 84 return not ( 85 sum_type.attributes or 86 any(constructor.fields for constructor in sum_type.types) 87 ) 88 89def asdl_of(name, obj): 90 if isinstance(obj, asdl.Product) or isinstance(obj, asdl.Constructor): 91 fields = ", ".join(map(str, obj.fields)) 92 if fields: 93 fields = "({})".format(fields) 94 return "{}{}".format(name, fields) 95 else: 96 if is_simple(obj): 97 types = " | ".join(type.name for type in obj.types) 98 else: 99 sep = "\n{}| ".format(" " * (len(name) + 1)) 100 types = sep.join( 101 asdl_of(type.name, type) for type in obj.types 102 ) 103 return "{} = {}".format(name, types) 104 105class EmitVisitor(asdl.VisitorBase): 106 """Visit that emits lines""" 107 108 def __init__(self, file, metadata = None): 109 self.file = file 110 self._metadata = metadata 111 super(EmitVisitor, self).__init__() 112 113 def emit(self, s, depth, reflow=True): 114 # XXX reflow long lines? 115 if reflow: 116 lines = reflow_lines(s, depth) 117 else: 118 lines = [s] 119 for line in lines: 120 if line: 121 line = (" " * TABSIZE * depth) + line 122 self.file.write(line + "\n") 123 124 @property 125 def metadata(self): 126 if self._metadata is None: 127 raise ValueError( 128 "%s was expecting to be annnotated with metadata" 129 % type(self).__name__ 130 ) 131 return self._metadata 132 133 @metadata.setter 134 def metadata(self, value): 135 self._metadata = value 136 137class MetadataVisitor(asdl.VisitorBase): 138 ROOT_TYPE = "AST" 139 140 def __init__(self, *args, **kwargs): 141 super().__init__(*args, **kwargs) 142 143 # Metadata: 144 # - simple_sums: Tracks the list of compound type 145 # names where all the constructors 146 # belonging to that type lack of any 147 # fields. 
148 # - identifiers: All identifiers used in the AST declarations 149 # - singletons: List of all constructors that originates from 150 # simple sums. 151 # - types: List of all top level type names 152 # 153 self.metadata = types.SimpleNamespace( 154 simple_sums=set(), 155 identifiers=set(), 156 singletons=set(), 157 types={self.ROOT_TYPE}, 158 ) 159 160 def visitModule(self, mod): 161 for dfn in mod.dfns: 162 self.visit(dfn) 163 164 def visitType(self, type): 165 self.visit(type.value, type.name) 166 167 def visitSum(self, sum, name): 168 self.metadata.types.add(name) 169 170 simple_sum = is_simple(sum) 171 if simple_sum: 172 self.metadata.simple_sums.add(name) 173 174 for constructor in sum.types: 175 if simple_sum: 176 self.metadata.singletons.add(constructor.name) 177 self.visitConstructor(constructor) 178 self.visitFields(sum.attributes) 179 180 def visitConstructor(self, constructor): 181 self.metadata.types.add(constructor.name) 182 self.visitFields(constructor.fields) 183 184 def visitProduct(self, product, name): 185 self.metadata.types.add(name) 186 self.visitFields(product.attributes) 187 self.visitFields(product.fields) 188 189 def visitFields(self, fields): 190 for field in fields: 191 self.visitField(field) 192 193 def visitField(self, field): 194 self.metadata.identifiers.add(field.name) 195 196 197class TypeDefVisitor(EmitVisitor): 198 def visitModule(self, mod): 199 for dfn in mod.dfns: 200 self.visit(dfn) 201 202 def visitType(self, type, depth=0): 203 self.visit(type.value, type.name, depth) 204 205 def visitSum(self, sum, name, depth): 206 if is_simple(sum): 207 self.simple_sum(sum, name, depth) 208 else: 209 self.sum_with_constructors(sum, name, depth) 210 211 def simple_sum(self, sum, name, depth): 212 enum = [] 213 for i in range(len(sum.types)): 214 type = sum.types[i] 215 enum.append("%s=%d" % (type.name, i + 1)) 216 enums = ", ".join(enum) 217 ctype = get_c_type(name) 218 s = "typedef enum _%s { %s } %s;" % (name, enums, ctype) 219 
self.emit(s, depth) 220 self.emit("", depth) 221 222 def sum_with_constructors(self, sum, name, depth): 223 ctype = get_c_type(name) 224 s = "typedef struct _%(name)s *%(ctype)s;" % locals() 225 self.emit(s, depth) 226 self.emit("", depth) 227 228 def visitProduct(self, product, name, depth): 229 ctype = get_c_type(name) 230 s = "typedef struct _%(name)s *%(ctype)s;" % locals() 231 self.emit(s, depth) 232 self.emit("", depth) 233 234class SequenceDefVisitor(EmitVisitor): 235 def visitModule(self, mod): 236 for dfn in mod.dfns: 237 self.visit(dfn) 238 239 def visitType(self, type, depth=0): 240 self.visit(type.value, type.name, depth) 241 242 def visitSum(self, sum, name, depth): 243 if is_simple(sum): 244 return 245 self.emit_sequence_constructor(name, depth) 246 247 def emit_sequence_constructor(self, name,depth): 248 ctype = get_c_type(name) 249 self.emit("""\ 250typedef struct { 251 _ASDL_SEQ_HEAD 252 %(ctype)s typed_elements[1]; 253} asdl_%(name)s_seq;""" % locals(), reflow=False, depth=depth) 254 self.emit("", depth) 255 self.emit("asdl_%(name)s_seq *_Py_asdl_%(name)s_seq_new(Py_ssize_t size, PyArena *arena);" % locals(), depth) 256 self.emit("", depth) 257 258 def visitProduct(self, product, name, depth): 259 self.emit_sequence_constructor(name, depth) 260 261class StructVisitor(EmitVisitor): 262 """Visitor to generate typedefs for AST.""" 263 264 def visitModule(self, mod): 265 for dfn in mod.dfns: 266 self.visit(dfn) 267 268 def visitType(self, type, depth=0): 269 self.visit(type.value, type.name, depth) 270 271 def visitSum(self, sum, name, depth): 272 if not is_simple(sum): 273 self.sum_with_constructors(sum, name, depth) 274 275 def sum_with_constructors(self, sum, name, depth): 276 def emit(s, depth=depth): 277 self.emit(s % sys._getframe(1).f_locals, depth) 278 enum = [] 279 for i in range(len(sum.types)): 280 type = sum.types[i] 281 enum.append("%s_kind=%d" % (type.name, i + 1)) 282 283 emit("enum _%(name)s_kind {" + ", ".join(enum) + "};") 284 285 
emit("struct _%(name)s {") 286 emit("enum _%(name)s_kind kind;", depth + 1) 287 emit("union {", depth + 1) 288 for t in sum.types: 289 self.visit(t, depth + 2) 290 emit("} v;", depth + 1) 291 for field in sum.attributes: 292 # rudimentary attribute handling 293 type = str(field.type) 294 assert type in asdl.builtin_types, type 295 emit("%s %s;" % (type, field.name), depth + 1); 296 emit("};") 297 emit("") 298 299 def visitConstructor(self, cons, depth): 300 if cons.fields: 301 self.emit("struct {", depth) 302 for f in cons.fields: 303 self.visit(f, depth + 1) 304 self.emit("} %s;" % cons.name, depth) 305 self.emit("", depth) 306 307 def visitField(self, field, depth): 308 # XXX need to lookup field.type, because it might be something 309 # like a builtin... 310 ctype = get_c_type(field.type) 311 name = field.name 312 if field.seq: 313 if field.type in self.metadata.simple_sums: 314 self.emit("asdl_int_seq *%(name)s;" % locals(), depth) 315 else: 316 _type = field.type 317 self.emit("asdl_%(_type)s_seq *%(name)s;" % locals(), depth) 318 else: 319 self.emit("%(ctype)s %(name)s;" % locals(), depth) 320 321 def visitProduct(self, product, name, depth): 322 self.emit("struct _%(name)s {" % locals(), depth) 323 for f in product.fields: 324 self.visit(f, depth + 1) 325 for field in product.attributes: 326 # rudimentary attribute handling 327 type = str(field.type) 328 assert type in asdl.builtin_types, type 329 self.emit("%s %s;" % (type, field.name), depth + 1); 330 self.emit("};", depth) 331 self.emit("", depth) 332 333 334def ast_func_name(name): 335 return f"_PyAST_{name}" 336 337 338class PrototypeVisitor(EmitVisitor): 339 """Generate function prototypes for the .h file""" 340 341 def visitModule(self, mod): 342 for dfn in mod.dfns: 343 self.visit(dfn) 344 345 def visitType(self, type): 346 self.visit(type.value, type.name) 347 348 def visitSum(self, sum, name): 349 if is_simple(sum): 350 pass # XXX 351 else: 352 for t in sum.types: 353 self.visit(t, name, 
sum.attributes) 354 355 def get_args(self, fields): 356 """Return list of C argument into, one for each field. 357 358 Argument info is 3-tuple of a C type, variable name, and flag 359 that is true if type can be NULL. 360 """ 361 args = [] 362 unnamed = {} 363 for f in fields: 364 if f.name is None: 365 name = f.type 366 c = unnamed[name] = unnamed.get(name, 0) + 1 367 if c > 1: 368 name = "name%d" % (c - 1) 369 else: 370 name = f.name 371 # XXX should extend get_c_type() to handle this 372 if f.seq: 373 if f.type in self.metadata.simple_sums: 374 ctype = "asdl_int_seq *" 375 else: 376 ctype = f"asdl_{f.type}_seq *" 377 else: 378 ctype = get_c_type(f.type) 379 args.append((ctype, name, f.opt or f.seq)) 380 return args 381 382 def visitConstructor(self, cons, type, attrs): 383 args = self.get_args(cons.fields) 384 attrs = self.get_args(attrs) 385 ctype = get_c_type(type) 386 self.emit_function(cons.name, ctype, args, attrs) 387 388 def emit_function(self, name, ctype, args, attrs, union=True): 389 args = args + attrs 390 if args: 391 argstr = ", ".join(["%s %s" % (atype, aname) 392 for atype, aname, opt in args]) 393 argstr += ", PyArena *arena" 394 else: 395 argstr = "PyArena *arena" 396 self.emit("%s %s(%s);" % (ctype, ast_func_name(name), argstr), False) 397 398 def visitProduct(self, prod, name): 399 self.emit_function(name, get_c_type(name), 400 self.get_args(prod.fields), 401 self.get_args(prod.attributes), 402 union=False) 403 404 405class FunctionVisitor(PrototypeVisitor): 406 """Visitor to generate constructor functions for AST.""" 407 408 def emit_function(self, name, ctype, args, attrs, union=True): 409 def emit(s, depth=0, reflow=True): 410 self.emit(s, depth, reflow) 411 argstr = ", ".join(["%s %s" % (atype, aname) 412 for atype, aname, opt in args + attrs]) 413 if argstr: 414 argstr += ", PyArena *arena" 415 else: 416 argstr = "PyArena *arena" 417 self.emit("%s" % ctype, 0) 418 emit("%s(%s)" % (ast_func_name(name), argstr)) 419 emit("{") 420 emit("%s 
p;" % ctype, 1) 421 for argtype, argname, opt in args: 422 if not opt and argtype != "int": 423 emit("if (!%s) {" % argname, 1) 424 emit("PyErr_SetString(PyExc_ValueError,", 2) 425 msg = "field '%s' is required for %s" % (argname, name) 426 emit(' "%s");' % msg, 427 2, reflow=False) 428 emit('return NULL;', 2) 429 emit('}', 1) 430 431 emit("p = (%s)_PyArena_Malloc(arena, sizeof(*p));" % ctype, 1); 432 emit("if (!p)", 1) 433 emit("return NULL;", 2) 434 if union: 435 self.emit_body_union(name, args, attrs) 436 else: 437 self.emit_body_struct(name, args, attrs) 438 emit("return p;", 1) 439 emit("}") 440 emit("") 441 442 def emit_body_union(self, name, args, attrs): 443 def emit(s, depth=0, reflow=True): 444 self.emit(s, depth, reflow) 445 emit("p->kind = %s_kind;" % name, 1) 446 for argtype, argname, opt in args: 447 emit("p->v.%s.%s = %s;" % (name, argname, argname), 1) 448 for argtype, argname, opt in attrs: 449 emit("p->%s = %s;" % (argname, argname), 1) 450 451 def emit_body_struct(self, name, args, attrs): 452 def emit(s, depth=0, reflow=True): 453 self.emit(s, depth, reflow) 454 for argtype, argname, opt in args: 455 emit("p->%s = %s;" % (argname, argname), 1) 456 for argtype, argname, opt in attrs: 457 emit("p->%s = %s;" % (argname, argname), 1) 458 459 460class PickleVisitor(EmitVisitor): 461 462 def visitModule(self, mod): 463 for dfn in mod.dfns: 464 self.visit(dfn) 465 466 def visitType(self, type): 467 self.visit(type.value, type.name) 468 469 def visitSum(self, sum, name): 470 pass 471 472 def visitProduct(self, sum, name): 473 pass 474 475 def visitConstructor(self, cons, name): 476 pass 477 478 def visitField(self, sum): 479 pass 480 481 482class Obj2ModPrototypeVisitor(PickleVisitor): 483 def visitProduct(self, prod, name): 484 code = "static int obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena);" 485 self.emit(code % (name, get_c_type(name)), 0) 486 487 visitSum = visitProduct 488 489 490class Obj2ModVisitor(PickleVisitor): 
491 492 attribute_special_defaults = { 493 "end_lineno": "lineno", 494 "end_col_offset": "col_offset", 495 } 496 497 @contextmanager 498 def recursive_call(self, node, level): 499 self.emit('if (_Py_EnterRecursiveCall(" while traversing \'%s\' node")) {' % node, level, reflow=False) 500 self.emit('goto failed;', level + 1) 501 self.emit('}', level) 502 yield 503 self.emit('_Py_LeaveRecursiveCall();', level) 504 505 def funcHeader(self, name): 506 ctype = get_c_type(name) 507 self.emit("int", 0) 508 self.emit("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0) 509 self.emit("{", 0) 510 self.emit("int isinstance;", 1) 511 self.emit("", 0) 512 513 def sumTrailer(self, name, add_label=False): 514 self.emit("", 0) 515 # there's really nothing more we can do if this fails ... 516 error = "expected some sort of %s, but got %%R" % name 517 format = "PyErr_Format(PyExc_TypeError, \"%s\", obj);" 518 self.emit(format % error, 1, reflow=False) 519 if add_label: 520 self.emit("failed:", 1) 521 self.emit("Py_XDECREF(tmp);", 1) 522 self.emit("return 1;", 1) 523 self.emit("}", 0) 524 self.emit("", 0) 525 526 def simpleSum(self, sum, name): 527 self.funcHeader(name) 528 for t in sum.types: 529 line = ("isinstance = PyObject_IsInstance(obj, " 530 "state->%s_type);") 531 self.emit(line % (t.name,), 1) 532 self.emit("if (isinstance == -1) {", 1) 533 self.emit("return 1;", 2) 534 self.emit("}", 1) 535 self.emit("if (isinstance) {", 1) 536 self.emit("*out = %s;" % t.name, 2) 537 self.emit("return 0;", 2) 538 self.emit("}", 1) 539 self.sumTrailer(name) 540 541 def buildArgs(self, fields): 542 return ", ".join(fields + ["arena"]) 543 544 def complexSum(self, sum, name): 545 self.funcHeader(name) 546 self.emit("PyObject *tmp = NULL;", 1) 547 self.emit("PyObject *tp;", 1) 548 for a in sum.attributes: 549 self.visitAttributeDeclaration(a, name, sum=sum) 550 self.emit("", 0) 551 # XXX: should we only do this for 'expr'? 
552 self.emit("if (obj == Py_None) {", 1) 553 self.emit("*out = NULL;", 2) 554 self.emit("return 0;", 2) 555 self.emit("}", 1) 556 for a in sum.attributes: 557 self.visitField(a, name, sum=sum, depth=1) 558 for t in sum.types: 559 self.emit("tp = state->%s_type;" % (t.name,), 1) 560 self.emit("isinstance = PyObject_IsInstance(obj, tp);", 1) 561 self.emit("if (isinstance == -1) {", 1) 562 self.emit("return 1;", 2) 563 self.emit("}", 1) 564 self.emit("if (isinstance) {", 1) 565 for f in t.fields: 566 self.visitFieldDeclaration(f, t.name, sum=sum, depth=2) 567 self.emit("", 0) 568 for f in t.fields: 569 self.visitField(f, t.name, sum=sum, depth=2) 570 args = [f.name for f in t.fields] + [a.name for a in sum.attributes] 571 self.emit("*out = %s(%s);" % (ast_func_name(t.name), self.buildArgs(args)), 2) 572 self.emit("if (*out == NULL) goto failed;", 2) 573 self.emit("return 0;", 2) 574 self.emit("}", 1) 575 self.sumTrailer(name, True) 576 577 def visitAttributeDeclaration(self, a, name, sum=sum): 578 ctype = get_c_type(a.type) 579 self.emit("%s %s;" % (ctype, a.name), 1) 580 581 def visitSum(self, sum, name): 582 if is_simple(sum): 583 self.simpleSum(sum, name) 584 else: 585 self.complexSum(sum, name) 586 587 def visitProduct(self, prod, name): 588 ctype = get_c_type(name) 589 self.emit("int", 0) 590 self.emit("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0) 591 self.emit("{", 0) 592 self.emit("PyObject* tmp = NULL;", 1) 593 for f in prod.fields: 594 self.visitFieldDeclaration(f, name, prod=prod, depth=1) 595 for a in prod.attributes: 596 self.visitFieldDeclaration(a, name, prod=prod, depth=1) 597 self.emit("", 0) 598 for f in prod.fields: 599 self.visitField(f, name, prod=prod, depth=1) 600 for a in prod.attributes: 601 self.visitField(a, name, prod=prod, depth=1) 602 args = [f.name for f in prod.fields] 603 args.extend([a.name for a in prod.attributes]) 604 self.emit("*out = %s(%s);" % (ast_func_name(name), 
self.buildArgs(args)), 1) 605 self.emit("return 0;", 1) 606 self.emit("failed:", 0) 607 self.emit("Py_XDECREF(tmp);", 1) 608 self.emit("return 1;", 1) 609 self.emit("}", 0) 610 self.emit("", 0) 611 612 def visitFieldDeclaration(self, field, name, sum=None, prod=None, depth=0): 613 ctype = get_c_type(field.type) 614 if field.seq: 615 if self.isSimpleType(field): 616 self.emit("asdl_int_seq* %s;" % field.name, depth) 617 else: 618 _type = field.type 619 self.emit(f"asdl_{field.type}_seq* {field.name};", depth) 620 else: 621 ctype = get_c_type(field.type) 622 self.emit("%s %s;" % (ctype, field.name), depth) 623 624 def isNumeric(self, field): 625 return get_c_type(field.type) in ("int", "bool") 626 627 def isSimpleType(self, field): 628 return field.type in self.metadata.simple_sums or self.isNumeric(field) 629 630 def visitField(self, field, name, sum=None, prod=None, depth=0): 631 ctype = get_c_type(field.type) 632 line = "if (_PyObject_LookupAttr(obj, state->%s, &tmp) < 0) {" 633 self.emit(line % field.name, depth) 634 self.emit("return 1;", depth+1) 635 self.emit("}", depth) 636 if not field.opt: 637 self.emit("if (tmp == NULL) {", depth) 638 message = "required field \\\"%s\\\" missing from %s" % (field.name, name) 639 format = "PyErr_SetString(PyExc_TypeError, \"%s\");" 640 self.emit(format % message, depth+1, reflow=False) 641 self.emit("return 1;", depth+1) 642 else: 643 self.emit("if (tmp == NULL || tmp == Py_None) {", depth) 644 self.emit("Py_CLEAR(tmp);", depth+1) 645 if self.isNumeric(field): 646 if field.name in self.attribute_special_defaults: 647 self.emit( 648 "%s = %s;" % (field.name, self.attribute_special_defaults[field.name]), 649 depth+1, 650 ) 651 else: 652 self.emit("%s = 0;" % field.name, depth+1) 653 elif not self.isSimpleType(field): 654 self.emit("%s = NULL;" % field.name, depth+1) 655 else: 656 raise TypeError("could not determine the default value for %s" % field.name) 657 self.emit("}", depth) 658 self.emit("else {", depth) 659 660 
self.emit("int res;", depth+1) 661 if field.seq: 662 self.emit("Py_ssize_t len;", depth+1) 663 self.emit("Py_ssize_t i;", depth+1) 664 self.emit("if (!PyList_Check(tmp)) {", depth+1) 665 self.emit("PyErr_Format(PyExc_TypeError, \"%s field \\\"%s\\\" must " 666 "be a list, not a %%.200s\", _PyType_Name(Py_TYPE(tmp)));" % 667 (name, field.name), 668 depth+2, reflow=False) 669 self.emit("goto failed;", depth+2) 670 self.emit("}", depth+1) 671 self.emit("len = PyList_GET_SIZE(tmp);", depth+1) 672 if self.isSimpleType(field): 673 self.emit("%s = _Py_asdl_int_seq_new(len, arena);" % field.name, depth+1) 674 else: 675 self.emit("%s = _Py_asdl_%s_seq_new(len, arena);" % (field.name, field.type), depth+1) 676 self.emit("if (%s == NULL) goto failed;" % field.name, depth+1) 677 self.emit("for (i = 0; i < len; i++) {", depth+1) 678 self.emit("%s val;" % ctype, depth+2) 679 self.emit("PyObject *tmp2 = PyList_GET_ITEM(tmp, i);", depth+2) 680 self.emit("Py_INCREF(tmp2);", depth+2) 681 with self.recursive_call(name, depth+2): 682 self.emit("res = obj2ast_%s(state, tmp2, &val, arena);" % 683 field.type, depth+2, reflow=False) 684 self.emit("Py_DECREF(tmp2);", depth+2) 685 self.emit("if (res != 0) goto failed;", depth+2) 686 self.emit("if (len != PyList_GET_SIZE(tmp)) {", depth+2) 687 self.emit("PyErr_SetString(PyExc_RuntimeError, \"%s field \\\"%s\\\" " 688 "changed size during iteration\");" % 689 (name, field.name), 690 depth+3, reflow=False) 691 self.emit("goto failed;", depth+3) 692 self.emit("}", depth+2) 693 self.emit("asdl_seq_SET(%s, i, val);" % field.name, depth+2) 694 self.emit("}", depth+1) 695 else: 696 with self.recursive_call(name, depth+1): 697 self.emit("res = obj2ast_%s(state, tmp, &%s, arena);" % 698 (field.type, field.name), depth+1) 699 self.emit("if (res != 0) goto failed;", depth+1) 700 701 self.emit("Py_CLEAR(tmp);", depth+1) 702 self.emit("}", depth) 703 704 705class SequenceConstructorVisitor(EmitVisitor): 706 def visitModule(self, mod): 707 for dfn in 
mod.dfns: 708 self.visit(dfn) 709 710 def visitType(self, type): 711 self.visit(type.value, type.name) 712 713 def visitProduct(self, prod, name): 714 self.emit_sequence_constructor(name, get_c_type(name)) 715 716 def visitSum(self, sum, name): 717 if not is_simple(sum): 718 self.emit_sequence_constructor(name, get_c_type(name)) 719 720 def emit_sequence_constructor(self, name, type): 721 self.emit(f"GENERATE_ASDL_SEQ_CONSTRUCTOR({name}, {type})", depth=0) 722 723class PyTypesDeclareVisitor(PickleVisitor): 724 725 def visitProduct(self, prod, name): 726 self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0) 727 if prod.attributes: 728 self.emit("static const char * const %s_attributes[] = {" % name, 0) 729 for a in prod.attributes: 730 self.emit('"%s",' % a.name, 1) 731 self.emit("};", 0) 732 if prod.fields: 733 self.emit("static const char * const %s_fields[]={" % name,0) 734 for f in prod.fields: 735 self.emit('"%s",' % f.name, 1) 736 self.emit("};", 0) 737 738 def visitSum(self, sum, name): 739 if sum.attributes: 740 self.emit("static const char * const %s_attributes[] = {" % name, 0) 741 for a in sum.attributes: 742 self.emit('"%s",' % a.name, 1) 743 self.emit("};", 0) 744 ptype = "void*" 745 if is_simple(sum): 746 ptype = get_c_type(name) 747 self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0) 748 for t in sum.types: 749 self.visitConstructor(t, name) 750 751 def visitConstructor(self, cons, name): 752 if cons.fields: 753 self.emit("static const char * const %s_fields[]={" % cons.name, 0) 754 for t in cons.fields: 755 self.emit('"%s",' % t.name, 1) 756 self.emit("};",0) 757 758 759class PyTypesVisitor(PickleVisitor): 760 761 def visitModule(self, mod): 762 self.emit(""" 763 764typedef struct { 765 PyObject_HEAD 766 PyObject *dict; 767} AST_object; 768 769static void 770ast_dealloc(AST_object *self) 771{ 772 /* bpo-31095: UnTrack is needed before calling any callbacks */ 773 PyTypeObject *tp = 
Py_TYPE(self); 774 PyObject_GC_UnTrack(self); 775 Py_CLEAR(self->dict); 776 freefunc free_func = PyType_GetSlot(tp, Py_tp_free); 777 assert(free_func != NULL); 778 free_func(self); 779 Py_DECREF(tp); 780} 781 782static int 783ast_traverse(AST_object *self, visitproc visit, void *arg) 784{ 785 Py_VISIT(Py_TYPE(self)); 786 Py_VISIT(self->dict); 787 return 0; 788} 789 790static int 791ast_clear(AST_object *self) 792{ 793 Py_CLEAR(self->dict); 794 return 0; 795} 796 797static int 798ast_type_init(PyObject *self, PyObject *args, PyObject *kw) 799{ 800 struct ast_state *state = get_ast_state(); 801 if (state == NULL) { 802 return -1; 803 } 804 805 Py_ssize_t i, numfields = 0; 806 int res = -1; 807 PyObject *key, *value, *fields; 808 if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) { 809 goto cleanup; 810 } 811 if (fields) { 812 numfields = PySequence_Size(fields); 813 if (numfields == -1) { 814 goto cleanup; 815 } 816 } 817 818 res = 0; /* if no error occurs, this stays 0 to the end */ 819 if (numfields < PyTuple_GET_SIZE(args)) { 820 PyErr_Format(PyExc_TypeError, "%.400s constructor takes at most " 821 "%zd positional argument%s", 822 _PyType_Name(Py_TYPE(self)), 823 numfields, numfields == 1 ? 
"" : "s"); 824 res = -1; 825 goto cleanup; 826 } 827 for (i = 0; i < PyTuple_GET_SIZE(args); i++) { 828 /* cannot be reached when fields is NULL */ 829 PyObject *name = PySequence_GetItem(fields, i); 830 if (!name) { 831 res = -1; 832 goto cleanup; 833 } 834 res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i)); 835 Py_DECREF(name); 836 if (res < 0) { 837 goto cleanup; 838 } 839 } 840 if (kw) { 841 i = 0; /* needed by PyDict_Next */ 842 while (PyDict_Next(kw, &i, &key, &value)) { 843 int contains = PySequence_Contains(fields, key); 844 if (contains == -1) { 845 res = -1; 846 goto cleanup; 847 } else if (contains == 1) { 848 Py_ssize_t p = PySequence_Index(fields, key); 849 if (p == -1) { 850 res = -1; 851 goto cleanup; 852 } 853 if (p < PyTuple_GET_SIZE(args)) { 854 PyErr_Format(PyExc_TypeError, 855 "%.400s got multiple values for argument '%U'", 856 Py_TYPE(self)->tp_name, key); 857 res = -1; 858 goto cleanup; 859 } 860 } 861 res = PyObject_SetAttr(self, key, value); 862 if (res < 0) { 863 goto cleanup; 864 } 865 } 866 } 867 cleanup: 868 Py_XDECREF(fields); 869 return res; 870} 871 872/* Pickling support */ 873static PyObject * 874ast_type_reduce(PyObject *self, PyObject *unused) 875{ 876 struct ast_state *state = get_ast_state(); 877 if (state == NULL) { 878 return NULL; 879 } 880 881 PyObject *dict; 882 if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) { 883 return NULL; 884 } 885 if (dict) { 886 return Py_BuildValue("O()N", Py_TYPE(self), dict); 887 } 888 return Py_BuildValue("O()", Py_TYPE(self)); 889} 890 891static PyMemberDef ast_type_members[] = { 892 {"__dictoffset__", T_PYSSIZET, offsetof(AST_object, dict), READONLY}, 893 {NULL} /* Sentinel */ 894}; 895 896static PyMethodDef ast_type_methods[] = { 897 {"__reduce__", ast_type_reduce, METH_NOARGS, NULL}, 898 {NULL} 899}; 900 901static PyGetSetDef ast_type_getsets[] = { 902 {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict}, 903 {NULL} 904}; 905 906static PyType_Slot 
AST_type_slots[] = { 907 {Py_tp_dealloc, ast_dealloc}, 908 {Py_tp_getattro, PyObject_GenericGetAttr}, 909 {Py_tp_setattro, PyObject_GenericSetAttr}, 910 {Py_tp_traverse, ast_traverse}, 911 {Py_tp_clear, ast_clear}, 912 {Py_tp_members, ast_type_members}, 913 {Py_tp_methods, ast_type_methods}, 914 {Py_tp_getset, ast_type_getsets}, 915 {Py_tp_init, ast_type_init}, 916 {Py_tp_alloc, PyType_GenericAlloc}, 917 {Py_tp_new, PyType_GenericNew}, 918 {Py_tp_free, PyObject_GC_Del}, 919 {0, 0}, 920}; 921 922static PyType_Spec AST_type_spec = { 923 "ast.AST", 924 sizeof(AST_object), 925 0, 926 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, 927 AST_type_slots 928}; 929 930static PyObject * 931make_type(struct ast_state *state, const char *type, PyObject* base, 932 const char* const* fields, int num_fields, const char *doc) 933{ 934 PyObject *fnames, *result; 935 int i; 936 fnames = PyTuple_New(num_fields); 937 if (!fnames) return NULL; 938 for (i = 0; i < num_fields; i++) { 939 PyObject *field = PyUnicode_InternFromString(fields[i]); 940 if (!field) { 941 Py_DECREF(fnames); 942 return NULL; 943 } 944 PyTuple_SET_ITEM(fnames, i, field); 945 } 946 result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOOOs}", 947 type, base, 948 state->_fields, fnames, 949 state->__match_args__, fnames, 950 state->__module__, 951 state->ast, 952 state->__doc__, doc); 953 Py_DECREF(fnames); 954 return result; 955} 956 957static int 958add_attributes(struct ast_state *state, PyObject *type, const char * const *attrs, int num_fields) 959{ 960 int i, result; 961 PyObject *s, *l = PyTuple_New(num_fields); 962 if (!l) 963 return 0; 964 for (i = 0; i < num_fields; i++) { 965 s = PyUnicode_InternFromString(attrs[i]); 966 if (!s) { 967 Py_DECREF(l); 968 return 0; 969 } 970 PyTuple_SET_ITEM(l, i, s); 971 } 972 result = PyObject_SetAttr(type, state->_attributes, l) >= 0; 973 Py_DECREF(l); 974 return result; 975} 976 977/* Conversion AST -> Python */ 978 979static PyObject* 
ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*)) 980{ 981 Py_ssize_t i, n = asdl_seq_LEN(seq); 982 PyObject *result = PyList_New(n); 983 PyObject *value; 984 if (!result) 985 return NULL; 986 for (i = 0; i < n; i++) { 987 value = func(state, asdl_seq_GET_UNTYPED(seq, i)); 988 if (!value) { 989 Py_DECREF(result); 990 return NULL; 991 } 992 PyList_SET_ITEM(result, i, value); 993 } 994 return result; 995} 996 997static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o) 998{ 999 if (!o) 1000 o = Py_None; 1001 Py_INCREF((PyObject*)o); 1002 return (PyObject*)o; 1003} 1004#define ast2obj_constant ast2obj_object 1005#define ast2obj_identifier ast2obj_object 1006#define ast2obj_string ast2obj_object 1007 1008static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b) 1009{ 1010 return PyLong_FromLong(b); 1011} 1012 1013/* Conversion Python -> AST */ 1014 1015static int obj2ast_object(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena) 1016{ 1017 if (obj == Py_None) 1018 obj = NULL; 1019 if (obj) { 1020 if (_PyArena_AddPyObject(arena, obj) < 0) { 1021 *out = NULL; 1022 return -1; 1023 } 1024 Py_INCREF(obj); 1025 } 1026 *out = obj; 1027 return 0; 1028} 1029 1030static int obj2ast_constant(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena) 1031{ 1032 if (_PyArena_AddPyObject(arena, obj) < 0) { 1033 *out = NULL; 1034 return -1; 1035 } 1036 Py_INCREF(obj); 1037 *out = obj; 1038 return 0; 1039} 1040 1041static int obj2ast_identifier(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena) 1042{ 1043 if (!PyUnicode_CheckExact(obj) && obj != Py_None) { 1044 PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str"); 1045 return 1; 1046 } 1047 return obj2ast_object(state, obj, out, arena); 1048} 1049 1050static int obj2ast_string(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena) 1051{ 
1052 if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) { 1053 PyErr_SetString(PyExc_TypeError, "AST string must be of type str"); 1054 return 1; 1055 } 1056 return obj2ast_object(state, obj, out, arena); 1057} 1058 1059static int obj2ast_int(struct ast_state* Py_UNUSED(state), PyObject* obj, int* out, PyArena* arena) 1060{ 1061 int i; 1062 if (!PyLong_Check(obj)) { 1063 PyErr_Format(PyExc_ValueError, "invalid integer value: %R", obj); 1064 return 1; 1065 } 1066 1067 i = _PyLong_AsInt(obj); 1068 if (i == -1 && PyErr_Occurred()) 1069 return 1; 1070 *out = i; 1071 return 0; 1072} 1073 1074static int add_ast_fields(struct ast_state *state) 1075{ 1076 PyObject *empty_tuple; 1077 empty_tuple = PyTuple_New(0); 1078 if (!empty_tuple || 1079 PyObject_SetAttrString(state->AST_type, "_fields", empty_tuple) < 0 || 1080 PyObject_SetAttrString(state->AST_type, "__match_args__", empty_tuple) < 0 || 1081 PyObject_SetAttrString(state->AST_type, "_attributes", empty_tuple) < 0) { 1082 Py_XDECREF(empty_tuple); 1083 return -1; 1084 } 1085 Py_DECREF(empty_tuple); 1086 return 0; 1087} 1088 1089""", 0, reflow=False) 1090 1091 self.file.write(textwrap.dedent(''' 1092 static int 1093 init_types(struct ast_state *state) 1094 { 1095 // init_types() must not be called after _PyAST_Fini() 1096 // has been called 1097 assert(state->initialized >= 0); 1098 1099 if (state->initialized) { 1100 return 1; 1101 } 1102 if (init_identifiers(state) < 0) { 1103 return 0; 1104 } 1105 state->AST_type = PyType_FromSpec(&AST_type_spec); 1106 if (!state->AST_type) { 1107 return 0; 1108 } 1109 if (add_ast_fields(state) < 0) { 1110 return 0; 1111 } 1112 ''')) 1113 for dfn in mod.dfns: 1114 self.visit(dfn) 1115 self.file.write(textwrap.dedent(''' 1116 state->recursion_depth = 0; 1117 state->recursion_limit = 0; 1118 state->initialized = 1; 1119 return 1; 1120 } 1121 ''')) 1122 1123 def visitProduct(self, prod, name): 1124 if prod.fields: 1125 fields = name+"_fields" 1126 else: 1127 fields = "NULL" 1128 
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' % 1129 (name, name, fields, len(prod.fields)), 1) 1130 self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False) 1131 self.emit("if (!state->%s_type) return 0;" % name, 1) 1132 if prod.attributes: 1133 self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" % 1134 (name, name, len(prod.attributes)), 1) 1135 else: 1136 self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1) 1137 self.emit_defaults(name, prod.fields, 1) 1138 self.emit_defaults(name, prod.attributes, 1) 1139 1140 def visitSum(self, sum, name): 1141 self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' % 1142 (name, name), 1) 1143 self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False) 1144 self.emit("if (!state->%s_type) return 0;" % name, 1) 1145 if sum.attributes: 1146 self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" % 1147 (name, name, len(sum.attributes)), 1) 1148 else: 1149 self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1) 1150 self.emit_defaults(name, sum.attributes, 1) 1151 simple = is_simple(sum) 1152 for t in sum.types: 1153 self.visitConstructor(t, name, simple) 1154 1155 def visitConstructor(self, cons, name, simple): 1156 if cons.fields: 1157 fields = cons.name+"_fields" 1158 else: 1159 fields = "NULL" 1160 self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' % 1161 (cons.name, cons.name, name, fields, len(cons.fields)), 1) 1162 self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False) 1163 self.emit("if (!state->%s_type) return 0;" % cons.name, 1) 1164 self.emit_defaults(cons.name, cons.fields, 1) 1165 if simple: 1166 self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)" 1167 "state->%s_type, NULL, NULL);" % 1168 (cons.name, cons.name), 1) 1169 self.emit("if 
class ASTModuleVisitor(PickleVisitor):
    """Emit the C code that defines and populates the ``_ast`` module.

    ``visitModule`` writes ``astmodule_exec()``; while doing so every ASDL
    definition is visited and contributes one ``PyModule_AddObjectRef`` call
    per type (see ``addObj``).  The static slot table, module definition and
    ``PyInit__ast`` follow.
    """

    def visitModule(self, mod):
        """Emit ``astmodule_exec()`` and the static module definition."""

        def bail_out_if(condition):
            # Every step of astmodule_exec() fails the same way: return -1.
            self.emit('if (%s) {' % condition, 1)
            self.emit('return -1;', 2)
            self.emit('}', 1)

        for header_line in ("static int", "astmodule_exec(PyObject *m)", "{"):
            self.emit(header_line, 0)
        self.emit('struct ast_state *state = get_ast_state();', 1)
        bail_out_if('state == NULL')
        bail_out_if('PyModule_AddObjectRef(m, "AST", state->AST_type) < 0')
        for macro in (
            "PyCF_ALLOW_TOP_LEVEL_AWAIT",
            "PyCF_ONLY_AST",
            "PyCF_TYPE_COMMENTS",
        ):
            bail_out_if('PyModule_AddIntMacro(m, %s) < 0' % macro)

        # One PyModule_AddObjectRef per product, sum and constructor.
        for dfn in mod.dfns:
            self.visit(dfn)

        self.emit("return 0;", 1)
        self.emit("}", 0)
        self.emit("", 0)

        module_definition = """
static PyModuleDef_Slot astmodule_slots[] = {
    {Py_mod_exec, astmodule_exec},
    {0, NULL}
};

static struct PyModuleDef _astmodule = {
    PyModuleDef_HEAD_INIT,
    .m_name = "_ast",
    // The _ast module uses a per-interpreter state (PyInterpreterState.ast)
    .m_size = 0,
    .m_slots = astmodule_slots,
};

PyMODINIT_FUNC
PyInit__ast(void)
{
    return PyModuleDef_Init(&_astmodule);
}
""".strip()
        # Pre-formatted C: emit verbatim, never reflow.
        self.emit(module_definition, 0, reflow=False)

    def visitProduct(self, prod, name):
        """Register the product's type on the module."""
        self.addObj(name)

    def visitSum(self, sum, name):
        """Register the sum's type, then the type of each of its constructors."""
        self.addObj(name)
        for constructor in sum.types:
            self.visitConstructor(constructor, name)

    def visitConstructor(self, cons, name):
        """Register the constructor's type; the enclosing sum ``name`` is unused."""
        self.addObj(cons.name)

    def addObj(self, name):
        """Emit the C that adds ``state->{name}_type`` to module ``m`` as ``name``."""
        self.emit(
            "if (PyModule_AddObjectRef(m, \"%s\", state->%s_type) < 0) {" % (name, name),
            1,
        )
        self.emit("return -1;", 2)
        self.emit('}', 1)
class StaticVisitor(PickleVisitor):
    """Visitor that writes a fixed piece of text instead of walking the tree.

    Subclasses override ``CODE`` with the text to emit.
    """

    # Placeholder text; concrete subclasses replace it.
    CODE = '''Very simple, always emit this static code. Override CODE'''

    def visit(self, object):
        """Emit ``CODE`` at depth 0 without reflowing; ``object`` is not inspected."""
        # reflow=False: CODE is written out exactly as laid out in the source.
        self.emit(self.CODE, 0, reflow=False)
class ObjVisitor(PickleVisitor):
    """Emit the ``ast2obj_*`` C functions that turn C AST nodes into Python objects.

    One function is generated per product and per sum.  Simple sums (enums on
    the C side) get a ``switch`` that returns the cached singleton of the
    matching constructor; every other type allocates a fresh instance with
    ``PyType_GenericNew`` and fills in its fields and attributes.
    """

    def func_begin(self, name):
        """Emit the prologue of ``ast2obj_{name}``.

        The prologue casts the ``void*`` argument to the node's C type,
        returns ``None`` for a NULL node and bumps/checks the recursion
        counter kept in ``state``.
        """
        ctype = get_c_type(name)
        self.emit("PyObject*", 0)
        self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
        self.emit("{", 0)
        self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
        self.emit("PyObject *result = NULL, *value = NULL;", 1)
        self.emit("PyTypeObject *tp;", 1)
        self.emit('if (!o) {', 1)
        self.emit("Py_RETURN_NONE;", 2)
        self.emit("}", 1)
        self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
        self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
        self.emit('"maximum recursion depth exceeded during ast construction");', 3)
        self.emit("return 0;", 2)
        self.emit("}", 1)

    def func_end(self):
        """Emit the epilogue: the success return and the ``failed:`` cleanup path."""
        self.emit("state->recursion_depth--;", 1)
        self.emit("return result;", 1)
        self.emit("failed:", 0)
        self.emit("Py_XDECREF(value);", 1)
        self.emit("Py_XDECREF(result);", 1)
        self.emit("return NULL;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def _emit_attributes(self, attributes):
        """Emit the code that copies each attribute of ``o`` onto ``result``."""
        for attr in attributes:
            self.emit("value = ast2obj_%s(state, o->%s);" % (attr.type, attr.name), 1)
            self.emit("if (!value) goto failed;", 1)
            self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % attr.name, 1)
            self.emit('goto failed;', 2)
            self.emit('Py_DECREF(value);', 1)

    def visitSum(self, sum, name):
        """Emit ``ast2obj_{name}`` for a sum type.

        Simple sums are delegated to ``simpleSum``; otherwise a ``switch`` on
        ``o->kind`` is generated with one case per constructor, followed by
        the attribute assignments shared by all constructors.
        """
        if is_simple(sum):
            self.simpleSum(sum, name)
            return
        self.func_begin(name)
        self.emit("switch (o->kind) {", 1)
        # Constructors are numbered from 1, in declaration order.
        for enum, cons in enumerate(sum.types, 1):
            self.visitConstructor(cons, enum, name)
        self.emit("}", 1)
        self._emit_attributes(sum.attributes)
        self.func_end()

    def simpleSum(self, sum, name):
        """Emit ``ast2obj_{name}`` for a simple sum: return the cached singleton."""
        self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
        self.emit("{", 0)
        self.emit("switch(o) {", 1)
        for cons in sum.types:
            self.emit("case %s:" % cons.name, 2)
            self.emit("Py_INCREF(state->%s_singleton);" % cons.name, 3)
            self.emit("return state->%s_singleton;" % cons.name, 3)
        self.emit("}", 1)
        self.emit("Py_UNREACHABLE();", 1)
        self.emit("}", 0)

    def visitProduct(self, prod, name):
        """Emit ``ast2obj_{name}`` for a product type."""
        self.func_begin(name)
        self.emit("tp = (PyTypeObject *)state->%s_type;" % name, 1)
        self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 1)
        self.emit("if (!result) return NULL;", 1)
        for field in prod.fields:
            self.visitField(field, name, 1, True)
        self._emit_attributes(prod.attributes)
        self.func_end()

    def visitConstructor(self, cons, enum, name):
        """Emit the ``case`` that builds the Python object for one constructor.

        ``enum`` and ``name`` are accepted for interface compatibility but
        are not used by the emitted code.
        """
        self.emit("case %s_kind:" % cons.name, 1)
        self.emit("tp = (PyTypeObject *)state->%s_type;" % cons.name, 2)
        self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 2)
        self.emit("if (!result) goto failed;", 2)
        for field in cons.fields:
            self.visitField(field, cons.name, 2, False)
        self.emit("break;", 2)

    def visitField(self, field, name, depth, product):
        """Emit the conversion of one field and its assignment onto ``result``.

        Product fields live directly on the struct (``o->field``); constructor
        fields live inside the union (``o->v.{name}.field``).
        """
        def emit(s, d):
            self.emit(s, depth + d)
        if product:
            value = "o->%s" % field.name
        else:
            value = "o->v.%s.%s" % (name, field.name)
        self.set(field, value, depth)
        emit("if (!value) goto failed;", 0)
        emit("if (PyObject_SetAttr(result, state->%s, value) == -1)" % field.name, 0)
        emit("goto failed;", 1)
        emit("Py_DECREF(value);", 0)

    def set(self, field, value, depth):
        """Emit the statement(s) that convert the C expression ``value`` into ``value``.

        Scalars call ``ast2obj_{type}`` directly and generic sequences go
        through ``ast2obj_list``.  Sequences of simple sums are expanded into
        an explicit loop, see the comment below.
        """
        if not field.seq:
            self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
            return
        if field.type not in self.metadata.simple_sums:
            self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
            return
        # While the sequence elements are stored as void*,
        # simple sums expects an enum
        self.emit("{", depth)
        self.emit("Py_ssize_t i, n = asdl_seq_LEN(%s);" % value, depth+1)
        self.emit("value = PyList_New(n);", depth+1)
        self.emit("if (!value) goto failed;", depth+1)
        self.emit("for(i = 0; i < n; i++)", depth+1)
        # This cannot fail, so no need for error handling
        self.emit(
            "PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
                field.type,
                value
            ),
            depth + 2,
            reflow=False,
        )
        self.emit("}", depth)
class PartingShots(StaticVisitor):
    """Static C code appended to the generated file.

    ``CODE`` defines ``PyAST_mod2obj``, ``PyAST_obj2mod`` and ``PyAST_Check``.
    """

    # In PyAST_mod2obj the recursion-depth-mismatch branch only runs when
    # ``result`` is a live reference, so that reference is released before
    # the error return; otherwise the freshly built tree would be leaked.
    CODE = """
PyObject* PyAST_mod2obj(mod_ty t)
{
    struct ast_state *state = get_ast_state();
    if (state == NULL) {
        return NULL;
    }

    int recursion_limit = Py_GetRecursionLimit();
    int starting_recursion_depth;
    /* Be careful here to prevent overflow. */
    int COMPILER_STACK_FRAME_SCALE = 3;
    PyThreadState *tstate = _PyThreadState_GET();
    if (!tstate) {
        return 0;
    }
    state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
        recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
    int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
    starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
        recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
    state->recursion_depth = starting_recursion_depth;

    PyObject *result = ast2obj_mod(state, t);

    /* Check that the recursion depth counting balanced correctly */
    if (result && state->recursion_depth != starting_recursion_depth) {
        PyErr_Format(PyExc_SystemError,
            "AST constructor recursion depth mismatch (before=%d, after=%d)",
            starting_recursion_depth, state->recursion_depth);
        Py_DECREF(result);
        return NULL;
    }
    return result;
}

/* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */
mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode)
{
    const char * const req_name[] = {"Module", "Expression", "Interactive"};
    int isinstance;

    if (PySys_Audit("compile", "OO", ast, Py_None) < 0) {
        return NULL;
    }

    struct ast_state *state = get_ast_state();
    if (state == NULL) {
        return NULL;
    }

    PyObject *req_type[3];
    req_type[0] = state->Module_type;
    req_type[1] = state->Expression_type;
    req_type[2] = state->Interactive_type;

    assert(0 <= mode && mode <= 2);

    isinstance = PyObject_IsInstance(ast, req_type[mode]);
    if (isinstance == -1)
        return NULL;
    if (!isinstance) {
        PyErr_Format(PyExc_TypeError, "expected %s node, got %.400s",
                     req_name[mode], _PyType_Name(Py_TYPE(ast)));
        return NULL;
    }

    mod_ty res = NULL;
    if (obj2ast_mod(state, ast, &res, arena) != 0)
        return NULL;
    else
        return res;
}

int PyAST_Check(PyObject* obj)
{
    struct ast_state *state = get_ast_state();
    if (state == NULL) {
        return -1;
    }
    return PyObject_IsInstance(obj, state->AST_type);
}
"""
class ChainOfVisitors:
    """Run several visitors, in order, over the same object.

    The chain's ``metadata`` is handed to each visitor right before that
    visitor runs.
    """

    def __init__(self, *visitors, metadata = None):
        self.metadata = metadata
        self.visitors = visitors

    def visit(self, object):
        """Visit ``object`` with each visitor, then emit an empty separator line."""
        for visitor in self.visitors:
            visitor.metadata = self.metadata
            visitor.visit(object)
            # Blank line between the outputs of consecutive visitors.
            visitor.emit("", 0)
def generate_ast_state(module_state, f):
    """Write the C ``struct ast_state`` definition to ``f``.

    The struct holds three ``int`` bookkeeping members followed by one
    ``PyObject *`` member per name in ``module_state`` (written in the order
    given).  No newline is written after the closing ``};``.
    """
    f.write('struct ast_state {\n')
    f.write(' int initialized;\n')
    f.write(' int recursion_depth;\n')
    f.write(' int recursion_limit;\n')
    for member in module_state:
        f.write(' PyObject *' + member + ';\n')
    f.write('};')


def generate_ast_fini(module_state, f):
    """Write the C function ``_PyAST_Fini`` to ``f``.

    The generated function clears every member named in ``module_state`` and
    then marks the state as finalized: ``initialized`` becomes -1 unless
    ``NDEBUG`` is defined, in which case it becomes 0.
    """
    f.write(textwrap.dedent("""
        void _PyAST_Fini(PyInterpreterState *interp)
        {
            struct ast_state *state = &interp->ast;

    """))
    for member in module_state:
        f.write(" Py_CLEAR(state->" + member + ');\n')
    f.write(textwrap.dedent("""
        #if !defined(NDEBUG)
            state->initialized = -1;
        #else
            state->initialized = 0;
        #endif
        }

    """))


def generate_module_def(mod, metadata, f, internal_h):
    """Write the state struct and the module-level boilerplate of the C file.

    ``struct ast_state`` goes to ``internal_h``; the ``#include`` block,
    ``get_ast_state()``, ``_PyAST_Fini()`` and ``init_identifiers()`` go to
    ``f``.  ``mod`` is accepted for interface symmetry and is not used.

    ``metadata`` must provide the iterables ``identifiers``, ``singletons``
    and ``types``.
    """
    # Gather all the data needed for ModuleSpec
    state_strings = {
        "ast",
        "_fields",
        "__match_args__",
        "__doc__",
        "__dict__",
        "__module__",
        "_attributes",
        *metadata.identifiers
    }

    # The state struct holds the interned strings plus one slot per
    # singleton and per type.
    module_state = state_strings.copy()
    module_state.update(
        "%s_singleton" % singleton
        for singleton in metadata.singletons
    )
    module_state.update(
        "%s_type" % type_name
        for type_name in metadata.types
    )

    # Sorted so that the generated files are stable from run to run.
    state_strings = sorted(state_strings)
    module_state = sorted(module_state)

    generate_ast_state(module_state, internal_h)

    print(textwrap.dedent("""
        #include "Python.h"
        #include "pycore_ast.h"
        #include "pycore_ast_state.h" // struct ast_state
        #include "pycore_ceval.h" // _Py_EnterRecursiveCall
        #include "pycore_interp.h" // _PyInterpreterState.ast
        #include "pycore_pystate.h" // _PyInterpreterState_GET()
        #include "structmember.h"
        #include <stddef.h>

        // Forward declaration
        static int init_types(struct ast_state *state);

        static struct ast_state*
        get_ast_state(void)
        {
            PyInterpreterState *interp = _PyInterpreterState_GET();
            struct ast_state *state = &interp->ast;
            if (!init_types(state)) {
                return NULL;
            }
            return state;
        }
    """).strip(), file=f)

    generate_ast_fini(module_state, f)

    f.write('static int init_identifiers(struct ast_state *state)\n')
    f.write('{\n')
    for identifier in state_strings:
        f.write(
            ' if ((state->' + identifier
            + ' = PyUnicode_InternFromString("'
            + identifier + '")) == NULL) return 0;\n'
        )
    f.write(' return 1;\n')
    # A function definition is closed by a bare brace: a trailing ';' here
    # would be a stray empty declaration at file scope.
    f.write('}\n\n')
def write_header(mod, metadata, f):
    """Write the internal AST header to ``f``.

    The output is, in order: the include guard and ``extern "C"`` opening,
    the declarations produced by the type/sequence/struct visitors, the
    constructor prototypes, and a fixed tail of function declarations that
    closes the guard again.
    """
    prologue = textwrap.dedent("""
        #ifndef Py_INTERNAL_AST_H
        #define Py_INTERNAL_AST_H
        #ifdef __cplusplus
        extern "C" {
        #endif

        #ifndef Py_BUILD_CORE
        # error "this header requires Py_BUILD_CORE define"
        #endif

        #include "pycore_asdl.h"

    """).lstrip()
    f.write(prologue)

    ChainOfVisitors(
        TypeDefVisitor(f),
        SequenceDefVisitor(f),
        StructVisitor(f),
        metadata=metadata
    ).visit(mod)

    f.write("// Note: these macros affect function definitions, not only call sites.\n")
    PrototypeVisitor(f, metadata=metadata).visit(mod)

    epilogue = textwrap.dedent("""

        PyObject* PyAST_mod2obj(mod_ty t);
        mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode);
        int PyAST_Check(PyObject* obj);

        extern int _PyAST_Validate(mod_ty);

        /* _PyAST_ExprAsUnicode is defined in ast_unparse.c */
        extern PyObject* _PyAST_ExprAsUnicode(expr_ty);

        /* Return the borrowed reference to the first literal string in the
           sequence of statements or NULL if it doesn't start from a literal string.
           Doesn't set exception. */
        extern PyObject* _PyAST_GetDocString(asdl_stmt_seq *);

        #ifdef __cplusplus
        }
        #endif
        #endif /* !Py_INTERNAL_AST_H */
    """)
    f.write(epilogue)
def write_internal_h_header(mod, f):
    """Print the opening of the internal state header to ``f``.

    Covers the include guard, the ``extern "C"`` opening and the
    ``Py_BUILD_CORE`` check.  ``mod`` is not used.
    """
    opening = textwrap.dedent("""
        #ifndef Py_INTERNAL_AST_STATE_H
        #define Py_INTERNAL_AST_STATE_H
        #ifdef __cplusplus
        extern "C" {
        #endif

        #ifndef Py_BUILD_CORE
        # error "this header requires Py_BUILD_CORE define"
        #endif
    """).lstrip()
    print(opening, file=f)


def write_internal_h_footer(mod, f):
    """Print the closing of the internal state header to ``f``; ``mod`` is not used."""
    closing = textwrap.dedent("""

        #ifdef __cplusplus
        }
        #endif
        #endif /* !Py_INTERNAL_AST_STATE_H */
    """)
    print(closing, file=f)


def write_source(mod, metadata, f, internal_h_file):
    """Write the generated C source to ``f``.

    The module boilerplate comes first (which also writes the state struct
    to ``internal_h_file``), then the output of each visitor in the order
    listed below.
    """
    generate_module_def(mod, metadata, f, internal_h_file)

    visitor_classes = (
        SequenceConstructorVisitor,
        PyTypesDeclareVisitor,
        PyTypesVisitor,
        Obj2ModPrototypeVisitor,
        FunctionVisitor,
        ObjVisitor,
        Obj2ModVisitor,
        ASTModuleVisitor,
        PartingShots,
    )
    chain = ChainOfVisitors(
        *[visitor_class(f) for visitor_class in visitor_classes],
        metadata=metadata
    )
    chain.visit(mod)


def main(input_filename, c_filename, h_filename, internal_h_filename, dump_module=False):
    """Parse the ASDL file and regenerate the C source and both headers.

    The three output arguments must offer an ``open("w")`` method (the
    command line passes ``pathlib.Path`` objects).  Exits with status 1 when
    ``asdl.check`` rejects the parsed module.
    """
    script_name = "/".join(Path(__file__).parts[-2:])
    auto_gen_msg = AUTOGEN_MESSAGE.format(script_name)

    mod = asdl.parse(input_filename)
    if dump_module:
        print('Parsed Module:')
        print(mod)
    if not asdl.check(mod):
        sys.exit(1)

    collector = MetadataVisitor()
    collector.visit(mod)
    metadata = collector.metadata

    with c_filename.open("w") as c_file, \
         h_filename.open("w") as h_file, \
         internal_h_filename.open("w") as internal_h_file:
        # Every generated file starts with the same banner.
        for generated in (c_file, h_file, internal_h_file):
            generated.write(auto_gen_msg)

        # The internal header is opened first and closed last because
        # write_source() writes the state struct into it.
        write_internal_h_header(mod, internal_h_file)
        write_source(mod, metadata, c_file, internal_h_file)
        write_header(mod, metadata, h_file)
        write_internal_h_footer(mod, internal_h_file)

    print(f"{c_filename}, {h_filename}, {internal_h_filename} regenerated.")
if __name__ == "__main__":
    # Command line: the ASDL input plus the three files to regenerate.
    parser = ArgumentParser()
    parser.add_argument("input_file", type=Path)
    for option_flags in (
        ("-C", "--c-file"),
        ("-H", "--h-file"),
        ("-I", "--internal-h-file"),
    ):
        # All three output paths are mandatory.
        parser.add_argument(*option_flags, type=Path, required=True)
    parser.add_argument("-d", "--dump-module", action="store_true")

    args = parser.parse_args()
    main(
        args.input_file,
        args.c_file,
        args.h_file,
        args.internal_h_file,
        args.dump_module,
    )