1#! /usr/bin/env python
2"""Generate C code from an ASDL description."""
3
4import os
5import sys
6import textwrap
7import types
8
9from argparse import ArgumentParser
10from contextlib import contextmanager
11from pathlib import Path
12
13import asdl
14
# Number of spaces used for one indentation level of emitted code.
TABSIZE = 4
# Column limit that reflow_lines() wraps emitted lines against.
MAX_COL = 80
# Header comment template; "{}" is a str.format() placeholder
# (presumably filled with the generator's name -- confirm at the call site).
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n"
18
def get_c_type(name):
    """Map an ASDL type name to the name of its C type.

    Names listed in ``asdl.builtin_types`` are returned unchanged;
    every other name gets a ``_ty`` suffix appended.
    """
    is_builtin = name in asdl.builtin_types
    return name if is_builtin else "%s_ty" % name
28
def reflow_lines(s, depth):
    """Wrap the line s so it fits within MAX_COL at the given depth.

    The available width is MAX_COL minus depth * TABSIZE.  A line that
    is shorter than that width is returned unchanged as a one-element
    list.  Otherwise the line is broken at spaces.  Continuation lines
    are padded so that they start just after the first "{ " of the
    first output line or, when there is no brace, just after its first
    "(".  The indentation for depth itself is NOT added here.
    """
    width = MAX_COL - depth * TABSIZE
    if len(s) < width:
        return [s]

    wrapped = []
    hang = ""
    rest = s
    while len(rest) > width:
        cut = rest.rfind(' ', 0, width)
        # XXX this should be fixed for real
        if cut == -1 and 'GeneratorExp' in rest:
            cut = width + 3
        assert cut != -1, "Impossible line %d to reflow: %r" % (width, s)
        wrapped.append(hang + rest[:cut])
        if len(wrapped) == 1:
            # Only the first chunk decides the continuation column:
            # a brace counts itself plus the space after it, a paren
            # counts only itself.
            for opener, extra in (('{', 2), ('(', 1)):
                pos = rest.find(opener, 0, cut)
                if pos >= 0:
                    offset = pos + extra
                    width -= offset
                    hang = " " * offset
                    break
        rest = rest[cut + 1:]
    wrapped.append(hang + rest)
    return wrapped
70
def reflow_c_string(s, depth):
    """Quote s as a C string literal, split at its newlines.

    Each newline in s becomes an escaped ``\\n`` that closes the
    current literal, followed by a real newline and a new literal
    opened at an indentation of depth * TABSIZE spaces.
    """
    indent = ' ' * depth * TABSIZE
    line_break = '\\n"\n%s"' % indent
    return '"%s"' % s.replace('\n', line_break)
73
def is_simple(sum_type):
    """Return True if the sum type is "simple".

    A sum is simple when it has no attributes of its own and none of
    its constructors has any fields, e.g.
    unaryop = Invert | Not | UAdd | USub
    Instances of such types are cached at the C level and act like
    singletons when parser generated nodes are propagated to the
    Python level.
    """
    if sum_type.attributes:
        return False
    for constructor in sum_type.types:
        if constructor.fields:
            return False
    return True
88
def asdl_of(name, obj):
    """Render the ASDL text of the definition called name.

    Products and constructors are rendered as ``name(field, ...)`` (or
    just ``name`` when they have no fields).  Any other object is
    treated as a sum and rendered as ``name = alt | alt ...``; for a
    non-simple sum every alternative goes on its own line, with the
    ``|`` aligned one column past the end of name.
    """
    if isinstance(obj, (asdl.Product, asdl.Constructor)):
        rendered = ", ".join(str(field) for field in obj.fields)
        suffix = "({})".format(rendered) if rendered else rendered
        return "{}{}".format(name, suffix)

    if is_simple(obj):
        alternatives = " | ".join(cons.name for cons in obj.types)
    else:
        joiner = "\n{}| ".format(" " * (len(name) + 1))
        alternatives = joiner.join(
            asdl_of(cons.name, cons) for cons in obj.types
        )
    return "{} = {}".format(name, alternatives)
104
class EmitVisitor(asdl.VisitorBase):
    """Base visitor that writes generated source lines to a file object.

    Attributes:
        file: object whose ``write()`` method receives the output.
        metadata: metadata object attached to the visitor; reading it
            before one has been provided raises ValueError.
    """

    def __init__(self, file, metadata=None):
        """Remember the output file and the optional metadata."""
        self.file = file
        self._metadata = metadata
        super().__init__()

    def emit(self, s, depth, reflow=True):
        """Write s to the output file, indented by depth levels.

        Args:
            s: text to write.
            depth: indentation level; each level is TABSIZE spaces.
            reflow: when true, s is first wrapped by reflow_lines();
                otherwise it is written as one chunk.

        Every chunk is terminated with a newline.  Empty chunks are
        written without indentation, so blank lines carry no trailing
        whitespace.
        """
        lines = reflow_lines(s, depth) if reflow else [s]
        indent = " " * TABSIZE * depth
        for line in lines:
            if line:
                line = indent + line
            self.file.write(line + "\n")

    @property
    def metadata(self):
        """Metadata attached to this visitor.

        Raises:
            ValueError: if no metadata has been provided yet.
        """
        if self._metadata is None:
            raise ValueError(
                "%s was expecting to be annotated with metadata"
                % type(self).__name__
            )
        return self._metadata

    @metadata.setter
    def metadata(self, value):
        self._metadata = value
136
class MetadataVisitor(asdl.VisitorBase):
    """Collect naming metadata about every definition of a module.

    After a visit, ``self.metadata`` holds four sets:
      - simple_sums: names of the sum types for which is_simple() is
                     true (no attributes, no constructor fields).
      - identifiers: names of all fields and attributes that were seen.
      - singletons:  names of the constructors of the simple sums.
      - types:       names of all sums, products and constructors,
                     plus ROOT_TYPE.
    """

    ROOT_TYPE = "AST"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.metadata = types.SimpleNamespace(
            simple_sums=set(),
            identifiers=set(),
            singletons=set(),
            types={self.ROOT_TYPE},
        )

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        collected = self.metadata
        collected.types.add(name)

        simple = is_simple(sum)
        if simple:
            collected.simple_sums.add(name)

        for cons in sum.types:
            if simple:
                collected.singletons.add(cons.name)
            self.visitConstructor(cons)
        self.visitFields(sum.attributes)

    def visitConstructor(self, constructor):
        self.metadata.types.add(constructor.name)
        self.visitFields(constructor.fields)

    def visitProduct(self, product, name):
        self.metadata.types.add(name)
        # Attributes are recorded before the fields.
        for group in (product.attributes, product.fields):
            self.visitFields(group)

    def visitFields(self, fields):
        for entry in fields:
            self.visitField(entry)

    def visitField(self, field):
        self.metadata.identifiers.add(field.name)
195
196
class TypeDefVisitor(EmitVisitor):
    """Emit one C typedef per ASDL type.

    Simple sums become a C enum whose members are numbered from 1;
    every other sum and every product becomes a pointer-to-struct
    typedef.
    """

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        handler = self.simple_sum if is_simple(sum) else self.sum_with_constructors
        handler(sum, name, depth)

    def simple_sum(self, sum, name, depth):
        members = ", ".join(
            "%s=%d" % (cons.name, number)
            for number, cons in enumerate(sum.types, 1)
        )
        declaration = "typedef enum _%s { %s } %s;" % (
            name, members, get_c_type(name))
        self.emit(declaration, depth)
        self.emit("", depth)

    def sum_with_constructors(self, sum, name, depth):
        self._emit_pointer_typedef(name, depth)

    def visitProduct(self, product, name, depth):
        self._emit_pointer_typedef(name, depth)

    def _emit_pointer_typedef(self, name, depth):
        # Shared by non-simple sums and products: both are declared as
        # a pointer to "struct _<name>", followed by a blank line.
        mapping = {"name": name, "ctype": get_c_type(name)}
        self.emit("typedef struct _%(name)s *%(ctype)s;" % mapping, depth)
        self.emit("", depth)
233
class SequenceDefVisitor(EmitVisitor):
    """Emit the typed sequence struct and its constructor prototype.

    Output is produced for every product and for every sum that is not
    simple; simple sums produce nothing.
    """

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        if not is_simple(sum):
            self.emit_sequence_constructor(name, depth)

    def emit_sequence_constructor(self, name,depth):
        mapping = {"name": name, "ctype": get_c_type(name)}
        # The struct is written as one unwrapped chunk; only its first
        # line receives the indentation for depth.
        struct_text = """\
typedef struct {
    _ASDL_SEQ_HEAD
    %(ctype)s typed_elements[1];
} asdl_%(name)s_seq;""" % mapping
        self.emit(struct_text, depth, reflow=False)
        self.emit("", depth)
        prototype = "asdl_%(name)s_seq *_Py_asdl_%(name)s_seq_new(Py_ssize_t size, PyArena *arena);" % mapping
        self.emit(prototype, depth)
        self.emit("", depth)

    def visitProduct(self, product, name, depth):
        self.emit_sequence_constructor(name, depth)
260
class StructVisitor(EmitVisitor):
    """Visitor to generate the C struct definitions for the AST.

    Non-simple sums become a ``struct _<name>`` holding a ``kind`` enum
    and a union ``v`` with one anonymous struct per constructor that
    has fields.  Products become a plain ``struct _<name>``.  Simple
    sums produce no output here.
    """

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        if not is_simple(sum):
            self.sum_with_constructors(sum, name, depth)

    def sum_with_constructors(self, sum, name, depth):
        """Emit the kind enum and the tagged-union struct of a sum.

        Kind values are numbered from 1 in constructor order.
        """
        # Text is formatted explicitly instead of being interpolated
        # against the caller's frame locals (sys._getframe), which is a
        # CPython-specific hack and would misinterpret any '%' that is
        # already part of the text.
        kinds = ", ".join(
            "%s_kind=%d" % (cons.name, number)
            for number, cons in enumerate(sum.types, 1)
        )
        self.emit("enum _%s_kind {%s};" % (name, kinds), depth)

        self.emit("struct _%s {" % name, depth)
        self.emit("enum _%s_kind kind;" % name, depth + 1)
        self.emit("union {", depth + 1)
        for cons in sum.types:
            self.visit(cons, depth + 2)
        self.emit("} v;", depth + 1)
        self._emit_attributes(sum.attributes, depth + 1)
        self.emit("};", depth)
        self.emit("", depth)

    def visitConstructor(self, cons, depth):
        """Emit the union member of a constructor; nothing if it has no fields."""
        if cons.fields:
            self.emit("struct {", depth)
            for f in cons.fields:
                self.visit(f, depth + 1)
            self.emit("} %s;" % cons.name, depth)
            self.emit("", depth)

    def visitField(self, field, depth):
        """Emit the struct member declaration for one field.

        Sequences of simple sums are declared as ``asdl_int_seq *``,
        other sequences as ``asdl_<type>_seq *`` and everything else
        with the type returned by get_c_type().
        """
        # XXX need to lookup field.type, because it might be something
        # like a builtin...
        if not field.seq:
            self.emit("%s %s;" % (get_c_type(field.type), field.name), depth)
        elif field.type in self.metadata.simple_sums:
            self.emit("asdl_int_seq *%s;" % field.name, depth)
        else:
            self.emit("asdl_%s_seq *%s;" % (field.type, field.name), depth)

    def visitProduct(self, product, name, depth):
        """Emit the struct of a product: its fields, then its attributes."""
        self.emit("struct _%s {" % name, depth)
        for f in product.fields:
            self.visit(f, depth + 1)
        self._emit_attributes(product.attributes, depth + 1)
        self.emit("};", depth)
        self.emit("", depth)

    def _emit_attributes(self, attributes, depth):
        """Emit one member declaration per attribute.

        Attribute handling is rudimentary: only builtin ASDL types are
        accepted, and the ASDL type name is used verbatim as C type.
        """
        for attribute in attributes:
            ctype = str(attribute.type)
            assert ctype in asdl.builtin_types, ctype
            self.emit("%s %s;" % (ctype, attribute.name), depth)
332
333
def ast_func_name(name):
    """Return the name of the C constructor function for node *name*."""
    return "_PyAST_{}".format(name)
336
337
class PrototypeVisitor(EmitVisitor):
    """Generate function prototypes for the .h file.

    One prototype is emitted per product and per constructor of every
    non-simple sum.  Subclasses can override emit_function() to emit
    something else for the same set of functions.
    """

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        # XXX simple sums currently get no prototypes at all.
        if is_simple(sum):
            return
        for t in sum.types:
            self.visit(t, name, sum.attributes)

    def get_args(self, fields):
        """Return list of C argument info, one for each field.

        Argument info is 3-tuple of a C type, variable name, and flag
        that is true if type can be NULL (the field is optional or a
        sequence).

        A field without a name is named after its type; further
        unnamed fields of the same type are called name1, name2, ...
        """
        args = []
        unnamed = {}
        for f in fields:
            if f.name is None:
                name = f.type
                c = unnamed[name] = unnamed.get(name, 0) + 1
                if c > 1:
                    name = "name%d" % (c - 1)
            else:
                name = f.name
            # XXX should extend get_c_type() to handle this
            if not f.seq:
                ctype = get_c_type(f.type)
            elif f.type in self.metadata.simple_sums:
                ctype = "asdl_int_seq *"
            else:
                ctype = f"asdl_{f.type}_seq *"
            args.append((ctype, name, f.opt or f.seq))
        return args

    def visitConstructor(self, cons, type, attrs):
        """Emit the function for constructor cons of the sum named type."""
        args = self.get_args(cons.fields)
        attrs = self.get_args(attrs)
        ctype = get_c_type(type)
        self.emit_function(cons.name, ctype, args, attrs)

    def emit_function(self, name, ctype, args, attrs, union=True):
        """Emit the prototype of the constructor function for name.

        Args:
            name: node name; the C function name is ast_func_name(name).
            ctype: C return type of the function.
            args: argument info tuples (see get_args()) for the fields.
            attrs: argument info tuples for the attributes.
            union: unused here; part of the interface shared with
                subclasses that override this method.

        The parameter list is the fields, then the attributes, then a
        trailing ``PyArena *arena``.
        """
        params = ["%s %s" % (atype, aname) for atype, aname, opt in args + attrs]
        params.append("PyArena *arena")
        argstr = ", ".join(params)
        # The original call passed ``False`` positionally here, which
        # lands in ``depth`` (not ``reflow``) and therefore meant
        # "depth 0, reflow enabled".  Spell that out explicitly; the
        # emitted text is unchanged.
        self.emit("%s %s(%s);" % (ctype, ast_func_name(name), argstr), 0)

    def visitProduct(self, prod, name):
        self.emit_function(name, get_c_type(name),
                           self.get_args(prod.fields),
                           self.get_args(prod.attributes),
                           union=False)
403
404
class FunctionVisitor(PrototypeVisitor):
    """Visitor to generate constructor functions for AST.

    Reuses the traversal of PrototypeVisitor but emits a full C
    function definition instead of a prototype.
    """

    def emit_function(self, name, ctype, args, attrs, union=True):
        params = ["%s %s" % (atype, aname) for atype, aname, _ in args + attrs]
        params.append("PyArena *arena")
        argstr = ", ".join(params)

        # Signature and local declaration.
        self.emit("%s" % ctype, 0)
        self.emit("%s(%s)" % (ast_func_name(name), argstr), 0)
        self.emit("{", 0)
        self.emit("%s p;" % ctype, 1)

        # NULL checks for required, non-int field arguments.
        for argtype, argname, opt in args:
            if opt or argtype == "int":
                continue
            self.emit("if (!%s) {" % argname, 1)
            self.emit("PyErr_SetString(PyExc_ValueError,", 2)
            msg = "field '%s' is required for %s" % (argname, name)
            self.emit('                "%s");' % msg, 2, reflow=False)
            self.emit('return NULL;', 2)
            self.emit('}', 1)

        # Allocation, member initialisation and return.
        self.emit("p = (%s)_PyArena_Malloc(arena, sizeof(*p));" % ctype, 1)
        self.emit("if (!p)", 1)
        self.emit("return NULL;", 2)
        fill_body = self.emit_body_union if union else self.emit_body_struct
        fill_body(name, args, attrs)
        self.emit("return p;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def emit_body_union(self, name, args, attrs):
        # Fields live inside the union member named after the
        # constructor; attributes live directly on the struct.
        self.emit("p->kind = %s_kind;" % name, 1)
        for _, argname, _ in args:
            self.emit("p->v.%s.%s = %s;" % (name, argname, argname), 1)
        for _, argname, _ in attrs:
            self.emit("p->%s = %s;" % (argname, argname), 1)

    def emit_body_struct(self, name, args, attrs):
        # Fields first, then attributes; both are direct members.
        for group in (args, attrs):
            for _, argname, _ in group:
                self.emit("p->%s = %s;" % (argname, argname), 1)
458
459
class PickleVisitor(EmitVisitor):
    """Traversal skeleton with do-nothing hooks.

    Walks the definitions of a module and dispatches on each type's
    value; the per-node hooks below do nothing and are meant to be
    overridden by subclasses.
    """

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        """Do nothing."""

    def visitProduct(self, sum, name):
        """Do nothing."""

    def visitConstructor(self, cons, name):
        """Do nothing."""

    def visitField(self, sum):
        """Do nothing."""
480
481
class Obj2ModPrototypeVisitor(PickleVisitor):
    """Emit the static prototype of obj2ast_<name>() for sums and products."""

    def visitProduct(self, prod, name):
        prototype = (
            "static int obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena);"
            % (name, get_c_type(name))
        )
        self.emit(prototype, 0)

    # Sums get exactly the same prototype as products.
    visitSum = visitProduct
488
489
class Obj2ModVisitor(PickleVisitor):
    """Emit the C ``obj2ast_<type>()`` converter functions.

    Each generated function reads attributes from the Python object
    ``obj``, converts them and stores the resulting node in ``*out``.
    The generated functions return 0 on success and 1 on failure.
    """

    # Optional numeric attributes listed here default to the value of
    # the named C variable instead of 0 when they are missing or None.
    attribute_special_defaults = {
        "end_lineno": "lineno",
        "end_col_offset": "col_offset",
    }

    @contextmanager
    def recursive_call(self, node, level):
        """Wrap the code emitted inside the ``with`` block in a
        _Py_EnterRecursiveCall()/_Py_LeaveRecursiveCall() pair.

        The emitted guard jumps to the ``failed`` label when entering
        the recursive call fails.
        """
        self.emit('if (_Py_EnterRecursiveCall(" while traversing \'%s\' node")) {' % node, level, reflow=False)
        self.emit('goto failed;', level + 1)
        self.emit('}', level)
        yield
        self.emit('_Py_LeaveRecursiveCall();', level)

    def funcHeader(self, name):
        """Emit the signature, opening brace and ``isinstance`` local
        of the converter for the sum called name."""
        ctype = get_c_type(name)
        self.emit("int", 0)
        self.emit("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
        self.emit("{", 0)
        self.emit("int isinstance;", 1)
        self.emit("", 0)

    def sumTrailer(self, name, add_label=False):
        """Emit the fall-through error path that ends a sum converter.

        When add_label is true a ``failed:`` label releasing ``tmp`` is
        emitted before the final ``return 1;``.
        """
        self.emit("", 0)
        # there's really nothing more we can do if this fails ...
        error = "expected some sort of %s, but got %%R" % name
        template = "PyErr_Format(PyExc_TypeError, \"%s\", obj);"
        self.emit(template % error, 1, reflow=False)
        if add_label:
            self.emit("failed:", 1)
            self.emit("Py_XDECREF(tmp);", 1)
        self.emit("return 1;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def simpleSum(self, sum, name):
        """Emit the converter of a simple sum: a chain of isinstance
        checks, each storing the matching enum value in ``*out``."""
        self.funcHeader(name)
        for t in sum.types:
            line = ("isinstance = PyObject_IsInstance(obj, "
                    "state->%s_type);")
            self.emit(line % (t.name,), 1)
            self.emit("if (isinstance == -1) {", 1)
            self.emit("return 1;", 2)
            self.emit("}", 1)
            self.emit("if (isinstance) {", 1)
            self.emit("*out = %s;" % t.name, 2)
            self.emit("return 0;", 2)
            self.emit("}", 1)
        self.sumTrailer(name)

    def buildArgs(self, fields):
        """Return the C call argument list: the given names plus ``arena``."""
        return ", ".join(fields + ["arena"])

    def complexSum(self, sum, name):
        """Emit the converter of a sum that is not simple.

        The sum's attributes are converted first; then each
        constructor type is tried in turn, converting its fields and
        calling the matching constructor function.
        """
        self.funcHeader(name)
        self.emit("PyObject *tmp = NULL;", 1)
        self.emit("PyObject *tp;", 1)
        for a in sum.attributes:
            self.visitAttributeDeclaration(a, name, sum=sum)
        self.emit("", 0)
        # XXX: should we only do this for 'expr'?
        self.emit("if (obj == Py_None) {", 1)
        self.emit("*out = NULL;", 2)
        self.emit("return 0;", 2)
        self.emit("}", 1)
        for a in sum.attributes:
            self.visitField(a, name, sum=sum, depth=1)
        for t in sum.types:
            self.emit("tp = state->%s_type;" % (t.name,), 1)
            self.emit("isinstance = PyObject_IsInstance(obj, tp);", 1)
            self.emit("if (isinstance == -1) {", 1)
            self.emit("return 1;", 2)
            self.emit("}", 1)
            self.emit("if (isinstance) {", 1)
            for f in t.fields:
                self.visitFieldDeclaration(f, t.name, sum=sum, depth=2)
            self.emit("", 0)
            for f in t.fields:
                self.visitField(f, t.name, sum=sum, depth=2)
            args = [f.name for f in t.fields] + [a.name for a in sum.attributes]
            self.emit("*out = %s(%s);" % (ast_func_name(t.name), self.buildArgs(args)), 2)
            self.emit("if (*out == NULL) goto failed;", 2)
            self.emit("return 0;", 2)
            self.emit("}", 1)
        self.sumTrailer(name, True)

    def visitAttributeDeclaration(self, a, name, sum=None):
        """Emit the C local variable declaration for attribute a.

        name and sum are accepted for symmetry with the other visit
        methods and are not used.  (The default of sum used to be the
        builtin ``sum`` function by accident; it is now None.)
        """
        ctype = get_c_type(a.type)
        self.emit("%s %s;" % (ctype, a.name), 1)

    def visitSum(self, sum, name):
        if is_simple(sum):
            self.simpleSum(sum, name)
        else:
            self.complexSum(sum, name)

    def visitProduct(self, prod, name):
        """Emit the converter of a product type.

        Fields are declared and converted before attributes, and the
        constructor function receives them in that same order.
        """
        ctype = get_c_type(name)
        self.emit("int", 0)
        self.emit("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
        self.emit("{", 0)
        self.emit("PyObject* tmp = NULL;", 1)
        for f in prod.fields:
            self.visitFieldDeclaration(f, name, prod=prod, depth=1)
        for a in prod.attributes:
            self.visitFieldDeclaration(a, name, prod=prod, depth=1)
        self.emit("", 0)
        for f in prod.fields:
            self.visitField(f, name, prod=prod, depth=1)
        for a in prod.attributes:
            self.visitField(a, name, prod=prod, depth=1)
        args = [f.name for f in prod.fields]
        args.extend([a.name for a in prod.attributes])
        self.emit("*out = %s(%s);" % (ast_func_name(name), self.buildArgs(args)), 1)
        # The constructor function can return NULL; report that as a
        # failure instead of returning success with *out == NULL.  This
        # mirrors the check emitted by complexSum().
        self.emit("if (*out == NULL) goto failed;", 1)
        self.emit("return 0;", 1)
        self.emit("failed:", 0)
        self.emit("Py_XDECREF(tmp);", 1)
        self.emit("return 1;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def visitFieldDeclaration(self, field, name, sum=None, prod=None, depth=0):
        """Emit the C local variable declaration for one field.

        Sequences of simple types are declared as ``asdl_int_seq*``,
        other sequences as ``asdl_<type>_seq*`` and everything else
        with the type returned by get_c_type().
        """
        if not field.seq:
            self.emit("%s %s;" % (get_c_type(field.type), field.name), depth)
        elif self.isSimpleType(field):
            self.emit("asdl_int_seq* %s;" % field.name, depth)
        else:
            self.emit(f"asdl_{field.type}_seq* {field.name};", depth)

    def isNumeric(self, field):
        """Return True if the field's C type is ``int`` or ``bool``."""
        return get_c_type(field.type) in ("int", "bool")

    def isSimpleType(self, field):
        """Return True if the field is numeric or its type is a simple sum."""
        return field.type in self.metadata.simple_sums or self.isNumeric(field)

    def visitField(self, field, name, sum=None, prod=None, depth=0):
        """Emit the C code that reads and converts one field of ``obj``.

        The attribute is looked up into ``tmp``.  A missing required
        field raises TypeError in the generated code; a missing (or
        None) optional field gets a default value.  Otherwise the value
        is converted with the matching ``obj2ast_<type>()`` function,
        element by element for sequence fields.

        Raises:
            TypeError: at generation time, if an optional field is of
                a simple (non-numeric) type, for which no default
                value is known.
        """
        ctype = get_c_type(field.type)
        line = "if (_PyObject_LookupAttr(obj, state->%s, &tmp) < 0) {"
        self.emit(line % field.name, depth)
        self.emit("return 1;", depth+1)
        self.emit("}", depth)
        if not field.opt:
            self.emit("if (tmp == NULL) {", depth)
            message = "required field \\\"%s\\\" missing from %s" % (field.name, name)
            template = "PyErr_SetString(PyExc_TypeError, \"%s\");"
            self.emit(template % message, depth+1, reflow=False)
            self.emit("return 1;", depth+1)
        else:
            self.emit("if (tmp == NULL || tmp == Py_None) {", depth)
            self.emit("Py_CLEAR(tmp);", depth+1)
            if self.isNumeric(field):
                if field.name in self.attribute_special_defaults:
                    self.emit(
                        "%s = %s;" % (field.name, self.attribute_special_defaults[field.name]),
                        depth+1,
                    )
                else:
                    self.emit("%s = 0;" % field.name, depth+1)
            elif not self.isSimpleType(field):
                self.emit("%s = NULL;" % field.name, depth+1)
            else:
                raise TypeError("could not determine the default value for %s" % field.name)
        self.emit("}", depth)
        self.emit("else {", depth)

        self.emit("int res;", depth+1)
        if field.seq:
            self.emit("Py_ssize_t len;", depth+1)
            self.emit("Py_ssize_t i;", depth+1)
            self.emit("if (!PyList_Check(tmp)) {", depth+1)
            self.emit("PyErr_Format(PyExc_TypeError, \"%s field \\\"%s\\\" must "
                      "be a list, not a %%.200s\", _PyType_Name(Py_TYPE(tmp)));" %
                      (name, field.name),
                      depth+2, reflow=False)
            self.emit("goto failed;", depth+2)
            self.emit("}", depth+1)
            self.emit("len = PyList_GET_SIZE(tmp);", depth+1)
            if self.isSimpleType(field):
                self.emit("%s = _Py_asdl_int_seq_new(len, arena);" % field.name, depth+1)
            else:
                self.emit("%s = _Py_asdl_%s_seq_new(len, arena);" % (field.name, field.type), depth+1)
            self.emit("if (%s == NULL) goto failed;" % field.name, depth+1)
            self.emit("for (i = 0; i < len; i++) {", depth+1)
            self.emit("%s val;" % ctype, depth+2)
            self.emit("PyObject *tmp2 = PyList_GET_ITEM(tmp, i);", depth+2)
            self.emit("Py_INCREF(tmp2);", depth+2)
            with self.recursive_call(name, depth+2):
                self.emit("res = obj2ast_%s(state, tmp2, &val, arena);" %
                          field.type, depth+2, reflow=False)
            self.emit("Py_DECREF(tmp2);", depth+2)
            self.emit("if (res != 0) goto failed;", depth+2)
            # The generated code re-checks the list length after each
            # element conversion and fails if it changed.
            self.emit("if (len != PyList_GET_SIZE(tmp)) {", depth+2)
            self.emit("PyErr_SetString(PyExc_RuntimeError, \"%s field \\\"%s\\\" "
                      "changed size during iteration\");" %
                      (name, field.name),
                      depth+3, reflow=False)
            self.emit("goto failed;", depth+3)
            self.emit("}", depth+2)
            self.emit("asdl_seq_SET(%s, i, val);" % field.name, depth+2)
            self.emit("}", depth+1)
        else:
            with self.recursive_call(name, depth+1):
                self.emit("res = obj2ast_%s(state, tmp, &%s, arena);" %
                          (field.type, field.name), depth+1)
            self.emit("if (res != 0) goto failed;", depth+1)

        self.emit("Py_CLEAR(tmp);", depth+1)
        self.emit("}", depth)
703
704
class SequenceConstructorVisitor(EmitVisitor):
    """Emit a GENERATE_ASDL_SEQ_CONSTRUCTOR() line per sequence type.

    A line is produced for every product and for every sum that is
    not simple.
    """

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitProduct(self, prod, name):
        self.emit_sequence_constructor(name, get_c_type(name))

    def visitSum(self, sum, name):
        if is_simple(sum):
            return
        self.emit_sequence_constructor(name, get_c_type(name))

    def emit_sequence_constructor(self, name, type):
        macro_call = "GENERATE_ASDL_SEQ_CONSTRUCTOR({}, {})".format(name, type)
        self.emit(macro_call, depth=0)
722
class PyTypesDeclareVisitor(PickleVisitor):
    """Emit the static declarations used by the ast2obj code.

    For every sum and product this emits the ``ast2obj_<name>()``
    prototype plus string arrays naming its attributes and fields
    (arrays are only emitted when they would be non-empty).
    """

    def _emit_name_array(self, opening, entries):
        # Emit: the opening line, one quoted name per entry, "};".
        self.emit(opening, 0)
        for entry in entries:
            self.emit('"%s",' % entry.name, 1)
        self.emit("};", 0)

    def visitProduct(self, prod, name):
        self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
        if prod.attributes:
            self._emit_name_array(
                "static const char * const %s_attributes[] = {" % name,
                prod.attributes)
        if prod.fields:
            self._emit_name_array(
                "static const char * const %s_fields[]={" % name,
                prod.fields)

    def visitSum(self, sum, name):
        if sum.attributes:
            self._emit_name_array(
                "static const char * const %s_attributes[] = {" % name,
                sum.attributes)
        # Simple sums are passed by value as their C type, all other
        # sums as an untyped pointer.
        ptype = get_c_type(name) if is_simple(sum) else "void*"
        self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
        for cons in sum.types:
            self.visitConstructor(cons, name)

    def visitConstructor(self, cons, name):
        if cons.fields:
            self._emit_name_array(
                "static const char * const %s_fields[]={" % cons.name,
                cons.fields)
757
758
759class PyTypesVisitor(PickleVisitor):
760
761    def visitModule(self, mod):
762        self.emit("""
763
764typedef struct {
765    PyObject_HEAD
766    PyObject *dict;
767} AST_object;
768
769static void
770ast_dealloc(AST_object *self)
771{
772    /* bpo-31095: UnTrack is needed before calling any callbacks */
773    PyTypeObject *tp = Py_TYPE(self);
774    PyObject_GC_UnTrack(self);
775    Py_CLEAR(self->dict);
776    freefunc free_func = PyType_GetSlot(tp, Py_tp_free);
777    assert(free_func != NULL);
778    free_func(self);
779    Py_DECREF(tp);
780}
781
782static int
783ast_traverse(AST_object *self, visitproc visit, void *arg)
784{
785    Py_VISIT(Py_TYPE(self));
786    Py_VISIT(self->dict);
787    return 0;
788}
789
790static int
791ast_clear(AST_object *self)
792{
793    Py_CLEAR(self->dict);
794    return 0;
795}
796
797static int
798ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
799{
800    struct ast_state *state = get_ast_state();
801    if (state == NULL) {
802        return -1;
803    }
804
805    Py_ssize_t i, numfields = 0;
806    int res = -1;
807    PyObject *key, *value, *fields;
808    if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
809        goto cleanup;
810    }
811    if (fields) {
812        numfields = PySequence_Size(fields);
813        if (numfields == -1) {
814            goto cleanup;
815        }
816    }
817
818    res = 0; /* if no error occurs, this stays 0 to the end */
819    if (numfields < PyTuple_GET_SIZE(args)) {
820        PyErr_Format(PyExc_TypeError, "%.400s constructor takes at most "
821                     "%zd positional argument%s",
822                     _PyType_Name(Py_TYPE(self)),
823                     numfields, numfields == 1 ? "" : "s");
824        res = -1;
825        goto cleanup;
826    }
827    for (i = 0; i < PyTuple_GET_SIZE(args); i++) {
828        /* cannot be reached when fields is NULL */
829        PyObject *name = PySequence_GetItem(fields, i);
830        if (!name) {
831            res = -1;
832            goto cleanup;
833        }
834        res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i));
835        Py_DECREF(name);
836        if (res < 0) {
837            goto cleanup;
838        }
839    }
840    if (kw) {
841        i = 0;  /* needed by PyDict_Next */
842        while (PyDict_Next(kw, &i, &key, &value)) {
843            int contains = PySequence_Contains(fields, key);
844            if (contains == -1) {
845                res = -1;
846                goto cleanup;
847            } else if (contains == 1) {
848                Py_ssize_t p = PySequence_Index(fields, key);
849                if (p == -1) {
850                    res = -1;
851                    goto cleanup;
852                }
853                if (p < PyTuple_GET_SIZE(args)) {
854                    PyErr_Format(PyExc_TypeError,
855                        "%.400s got multiple values for argument '%U'",
856                        Py_TYPE(self)->tp_name, key);
857                    res = -1;
858                    goto cleanup;
859                }
860            }
861            res = PyObject_SetAttr(self, key, value);
862            if (res < 0) {
863                goto cleanup;
864            }
865        }
866    }
867  cleanup:
868    Py_XDECREF(fields);
869    return res;
870}
871
872/* Pickling support */
873static PyObject *
874ast_type_reduce(PyObject *self, PyObject *unused)
875{
876    struct ast_state *state = get_ast_state();
877    if (state == NULL) {
878        return NULL;
879    }
880
881    PyObject *dict;
882    if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) {
883        return NULL;
884    }
885    if (dict) {
886        return Py_BuildValue("O()N", Py_TYPE(self), dict);
887    }
888    return Py_BuildValue("O()", Py_TYPE(self));
889}
890
891static PyMemberDef ast_type_members[] = {
892    {"__dictoffset__", T_PYSSIZET, offsetof(AST_object, dict), READONLY},
893    {NULL}  /* Sentinel */
894};
895
896static PyMethodDef ast_type_methods[] = {
897    {"__reduce__", ast_type_reduce, METH_NOARGS, NULL},
898    {NULL}
899};
900
901static PyGetSetDef ast_type_getsets[] = {
902    {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict},
903    {NULL}
904};
905
906static PyType_Slot AST_type_slots[] = {
907    {Py_tp_dealloc, ast_dealloc},
908    {Py_tp_getattro, PyObject_GenericGetAttr},
909    {Py_tp_setattro, PyObject_GenericSetAttr},
910    {Py_tp_traverse, ast_traverse},
911    {Py_tp_clear, ast_clear},
912    {Py_tp_members, ast_type_members},
913    {Py_tp_methods, ast_type_methods},
914    {Py_tp_getset, ast_type_getsets},
915    {Py_tp_init, ast_type_init},
916    {Py_tp_alloc, PyType_GenericAlloc},
917    {Py_tp_new, PyType_GenericNew},
918    {Py_tp_free, PyObject_GC_Del},
919    {0, 0},
920};
921
922static PyType_Spec AST_type_spec = {
923    "ast.AST",
924    sizeof(AST_object),
925    0,
926    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,
927    AST_type_slots
928};
929
930static PyObject *
931make_type(struct ast_state *state, const char *type, PyObject* base,
932          const char* const* fields, int num_fields, const char *doc)
933{
934    PyObject *fnames, *result;
935    int i;
936    fnames = PyTuple_New(num_fields);
937    if (!fnames) return NULL;
938    for (i = 0; i < num_fields; i++) {
939        PyObject *field = PyUnicode_InternFromString(fields[i]);
940        if (!field) {
941            Py_DECREF(fnames);
942            return NULL;
943        }
944        PyTuple_SET_ITEM(fnames, i, field);
945    }
946    result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOOOs}",
947                    type, base,
948                    state->_fields, fnames,
949                    state->__match_args__, fnames,
950                    state->__module__,
951                    state->ast,
952                    state->__doc__, doc);
953    Py_DECREF(fnames);
954    return result;
955}
956
957static int
958add_attributes(struct ast_state *state, PyObject *type, const char * const *attrs, int num_fields)
959{
960    int i, result;
961    PyObject *s, *l = PyTuple_New(num_fields);
962    if (!l)
963        return 0;
964    for (i = 0; i < num_fields; i++) {
965        s = PyUnicode_InternFromString(attrs[i]);
966        if (!s) {
967            Py_DECREF(l);
968            return 0;
969        }
970        PyTuple_SET_ITEM(l, i, s);
971    }
972    result = PyObject_SetAttr(type, state->_attributes, l) >= 0;
973    Py_DECREF(l);
974    return result;
975}
976
977/* Conversion AST -> Python */
978
979static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
980{
981    Py_ssize_t i, n = asdl_seq_LEN(seq);
982    PyObject *result = PyList_New(n);
983    PyObject *value;
984    if (!result)
985        return NULL;
986    for (i = 0; i < n; i++) {
987        value = func(state, asdl_seq_GET_UNTYPED(seq, i));
988        if (!value) {
989            Py_DECREF(result);
990            return NULL;
991        }
992        PyList_SET_ITEM(result, i, value);
993    }
994    return result;
995}
996
997static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
998{
999    if (!o)
1000        o = Py_None;
1001    Py_INCREF((PyObject*)o);
1002    return (PyObject*)o;
1003}
1004#define ast2obj_constant ast2obj_object
1005#define ast2obj_identifier ast2obj_object
1006#define ast2obj_string ast2obj_object
1007
1008static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
1009{
1010    return PyLong_FromLong(b);
1011}
1012
1013/* Conversion Python -> AST */
1014
1015static int obj2ast_object(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
1016{
1017    if (obj == Py_None)
1018        obj = NULL;
1019    if (obj) {
1020        if (_PyArena_AddPyObject(arena, obj) < 0) {
1021            *out = NULL;
1022            return -1;
1023        }
1024        Py_INCREF(obj);
1025    }
1026    *out = obj;
1027    return 0;
1028}
1029
1030static int obj2ast_constant(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
1031{
1032    if (_PyArena_AddPyObject(arena, obj) < 0) {
1033        *out = NULL;
1034        return -1;
1035    }
1036    Py_INCREF(obj);
1037    *out = obj;
1038    return 0;
1039}
1040
1041static int obj2ast_identifier(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena)
1042{
1043    if (!PyUnicode_CheckExact(obj) && obj != Py_None) {
1044        PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str");
1045        return 1;
1046    }
1047    return obj2ast_object(state, obj, out, arena);
1048}
1049
1050static int obj2ast_string(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena)
1051{
1052    if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) {
1053        PyErr_SetString(PyExc_TypeError, "AST string must be of type str");
1054        return 1;
1055    }
1056    return obj2ast_object(state, obj, out, arena);
1057}
1058
1059static int obj2ast_int(struct ast_state* Py_UNUSED(state), PyObject* obj, int* out, PyArena* arena)
1060{
1061    int i;
1062    if (!PyLong_Check(obj)) {
1063        PyErr_Format(PyExc_ValueError, "invalid integer value: %R", obj);
1064        return 1;
1065    }
1066
1067    i = _PyLong_AsInt(obj);
1068    if (i == -1 && PyErr_Occurred())
1069        return 1;
1070    *out = i;
1071    return 0;
1072}
1073
1074static int add_ast_fields(struct ast_state *state)
1075{
1076    PyObject *empty_tuple;
1077    empty_tuple = PyTuple_New(0);
1078    if (!empty_tuple ||
1079        PyObject_SetAttrString(state->AST_type, "_fields", empty_tuple) < 0 ||
1080        PyObject_SetAttrString(state->AST_type, "__match_args__", empty_tuple) < 0 ||
1081        PyObject_SetAttrString(state->AST_type, "_attributes", empty_tuple) < 0) {
1082        Py_XDECREF(empty_tuple);
1083        return -1;
1084    }
1085    Py_DECREF(empty_tuple);
1086    return 0;
1087}
1088
1089""", 0, reflow=False)
1090
1091        self.file.write(textwrap.dedent('''
1092            static int
1093            init_types(struct ast_state *state)
1094            {
1095                // init_types() must not be called after _PyAST_Fini()
1096                // has been called
1097                assert(state->initialized >= 0);
1098
1099                if (state->initialized) {
1100                    return 1;
1101                }
1102                if (init_identifiers(state) < 0) {
1103                    return 0;
1104                }
1105                state->AST_type = PyType_FromSpec(&AST_type_spec);
1106                if (!state->AST_type) {
1107                    return 0;
1108                }
1109                if (add_ast_fields(state) < 0) {
1110                    return 0;
1111                }
1112        '''))
1113        for dfn in mod.dfns:
1114            self.visit(dfn)
1115        self.file.write(textwrap.dedent('''
1116                state->recursion_depth = 0;
1117                state->recursion_limit = 0;
1118                state->initialized = 1;
1119                return 1;
1120            }
1121        '''))
1122
1123    def visitProduct(self, prod, name):
1124        if prod.fields:
1125            fields = name+"_fields"
1126        else:
1127            fields = "NULL"
1128        self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' %
1129                        (name, name, fields, len(prod.fields)), 1)
1130        self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False)
1131        self.emit("if (!state->%s_type) return 0;" % name, 1)
1132        if prod.attributes:
1133            self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" %
1134                            (name, name, len(prod.attributes)), 1)
1135        else:
1136            self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
1137        self.emit_defaults(name, prod.fields, 1)
1138        self.emit_defaults(name, prod.attributes, 1)
1139
1140    def visitSum(self, sum, name):
1141        self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' %
1142                  (name, name), 1)
1143        self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False)
1144        self.emit("if (!state->%s_type) return 0;" % name, 1)
1145        if sum.attributes:
1146            self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" %
1147                            (name, name, len(sum.attributes)), 1)
1148        else:
1149            self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
1150        self.emit_defaults(name, sum.attributes, 1)
1151        simple = is_simple(sum)
1152        for t in sum.types:
1153            self.visitConstructor(t, name, simple)
1154
1155    def visitConstructor(self, cons, name, simple):
1156        if cons.fields:
1157            fields = cons.name+"_fields"
1158        else:
1159            fields = "NULL"
1160        self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' %
1161                            (cons.name, cons.name, name, fields, len(cons.fields)), 1)
1162        self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False)
1163        self.emit("if (!state->%s_type) return 0;" % cons.name, 1)
1164        self.emit_defaults(cons.name, cons.fields, 1)
1165        if simple:
1166            self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)"
1167                      "state->%s_type, NULL, NULL);" %
1168                             (cons.name, cons.name), 1)
1169            self.emit("if (!state->%s_singleton) return 0;" % cons.name, 1)
1170
1171    def emit_defaults(self, name, fields, depth):
1172        for field in fields:
1173            if field.opt:
1174                self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1)' %
1175                            (name, field.name), depth)
1176                self.emit("return 0;", depth+1)
1177
1178
class ASTModuleVisitor(PickleVisitor):
    """Emit ``astmodule_exec()`` and the static ``_ast`` module definition."""

    def visitModule(self, mod):
        """Emit the module exec function, then the module definition code.

        The exec function adds ``AST``, three ``PyCF_*`` integer macros and
        whatever visiting each entry of ``mod.dfns`` emits.
        """
        opening = (
            ("static int", 0),
            ("astmodule_exec(PyObject *m)", 0),
            ("{", 0),
            ('struct ast_state *state = get_ast_state();', 1),
            ('if (state == NULL) {', 1),
            ('return -1;', 2),
            ('}', 1),
            ('if (PyModule_AddObjectRef(m, "AST", state->AST_type) < 0) {', 1),
            ('return -1;', 2),
            ('}', 1),
        )
        for text, depth in opening:
            self.emit(text, depth)

        int_macros = (
            "PyCF_ALLOW_TOP_LEVEL_AWAIT",
            "PyCF_ONLY_AST",
            "PyCF_TYPE_COMMENTS",
        )
        for macro in int_macros:
            self.emit('if (PyModule_AddIntMacro(m, %s) < 0) {' % macro, 1)
            self.emit("return -1;", 2)
            self.emit('}', 1)

        for dfn in mod.dfns:
            self.visit(dfn)

        for text, depth in (("return 0;", 1), ("}", 0), ("", 0)):
            self.emit(text, depth)

        module_definition = """
static PyModuleDef_Slot astmodule_slots[] = {
    {Py_mod_exec, astmodule_exec},
    {0, NULL}
};

static struct PyModuleDef _astmodule = {
    PyModuleDef_HEAD_INIT,
    .m_name = "_ast",
    // The _ast module uses a per-interpreter state (PyInterpreterState.ast)
    .m_size = 0,
    .m_slots = astmodule_slots,
};

PyMODINIT_FUNC
PyInit__ast(void)
{
    return PyModuleDef_Init(&_astmodule);
}
""".strip()
        self.emit(module_definition, 0, reflow=False)

    def visitProduct(self, prod, name):
        """Register the product type *name* on the module."""
        self.addObj(name)

    def visitSum(self, sum, name):
        """Register the sum type *name* and each of its constructors."""
        self.addObj(name)
        for constructor in sum.types:
            self.visitConstructor(constructor, name)

    def visitConstructor(self, cons, name):
        """Register the constructor type ``cons.name`` on the module."""
        self.addObj(cons.name)

    def addObj(self, name):
        """Emit the C that adds ``state-><name>_type`` to the module as *name*."""
        add_call = (
            f'if (PyModule_AddObjectRef(m, "{name}", '
            f'state->{name}_type) < 0) {{'
        )
        self.emit(add_call, 1)
        self.emit("return -1;", 2)
        self.emit('}', 1)
1243
1244
class StaticVisitor(PickleVisitor):
    """Visitor that emits the fixed text stored in ``CODE``.

    Subclasses are expected to override ``CODE``.
    """

    CODE = '''Very simple, always emit this static code.  Override CODE'''

    def visit(self, object):
        """Emit ``self.CODE`` at depth 0 without reflowing; *object* is unused."""
        static_code = self.CODE
        self.emit(static_code, 0, reflow=False)
1250
1251
class ObjVisitor(PickleVisitor):
    """Emit the ``ast2obj_*`` C functions (C AST node -> Python object)."""

    def func_begin(self, name):
        """Emit the signature and common prologue of ``ast2obj_<name>()``.

        The prologue casts the ``void*`` argument, returns ``None`` for a
        NULL node and performs the recursion-depth check.
        """
        ctype = get_c_type(name)
        prologue = (
            ("PyObject*", 0),
            ("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0),
            ("{", 0),
            ("%s o = (%s)_o;" % (ctype, ctype), 1),
            ("PyObject *result = NULL, *value = NULL;", 1),
            ("PyTypeObject *tp;", 1),
            ('if (!o) {', 1),
            ("Py_RETURN_NONE;", 2),
            ("}", 1),
            ("if (++state->recursion_depth > state->recursion_limit) {", 1),
            ("PyErr_SetString(PyExc_RecursionError,", 2),
            ('"maximum recursion depth exceeded during ast construction");', 3),
            ("return 0;", 2),
            ("}", 1),
        )
        for text, depth in prologue:
            self.emit(text, depth)

    def func_end(self):
        """Emit the success return, the ``failed:`` cleanup path and the closing brace."""
        epilogue = (
            ("state->recursion_depth--;", 1),
            ("return result;", 1),
            ("failed:", 0),
            ("Py_XDECREF(value);", 1),
            ("Py_XDECREF(result);", 1),
            ("return NULL;", 1),
            ("}", 0),
            ("", 0),
        )
        for text, depth in epilogue:
            self.emit(text, depth)

    def _emit_attributes(self, attributes):
        """Emit the conversion and ``PyObject_SetAttr`` code for each attribute."""
        for attr in attributes:
            self.emit("value = ast2obj_%s(state, o->%s);" % (attr.type, attr.name), 1)
            self.emit("if (!value) goto failed;", 1)
            self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % attr.name, 1)
            self.emit('goto failed;', 2)
            self.emit('Py_DECREF(value);', 1)

    def visitSum(self, sum, name):
        """Emit ``ast2obj_<name>()`` for a sum type.

        Simple sums are delegated to ``simpleSum()``; otherwise the function
        switches on ``o->kind`` with one case per constructor.
        """
        if is_simple(sum):
            self.simpleSum(sum, name)
            return
        self.func_begin(name)
        self.emit("switch (o->kind) {", 1)
        # Constructors are numbered starting at 1.
        for number, constructor in enumerate(sum.types, start=1):
            self.visitConstructor(constructor, number, name)
        self.emit("}", 1)
        self._emit_attributes(sum.attributes)
        self.func_end()

    def simpleSum(self, sum, name):
        """Emit ``ast2obj_<name>()`` returning the singleton matching the enum value."""
        self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
        self.emit("{", 0)
        self.emit("switch(o) {", 1)
        for constructor in sum.types:
            tag = constructor.name
            self.emit("case %s:" % tag, 2)
            self.emit("Py_INCREF(state->%s_singleton);" % tag, 3)
            self.emit("return state->%s_singleton;" % tag, 3)
        self.emit("}", 1)
        self.emit("Py_UNREACHABLE();", 1)
        self.emit("}", 0)

    def visitProduct(self, prod, name):
        """Emit ``ast2obj_<name>()`` for a product type."""
        self.func_begin(name)
        self.emit("tp = (PyTypeObject *)state->%s_type;" % name, 1)
        self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 1)
        self.emit("if (!result) return NULL;", 1)
        for field in prod.fields:
            self.visitField(field, name, 1, True)
        self._emit_attributes(prod.attributes)
        self.func_end()

    def visitConstructor(self, cons, enum, name):
        """Emit the ``case <cons.name>_kind:`` branch of a sum's switch.

        *enum* and *name* are accepted but not used by this method.
        """
        tag = cons.name
        self.emit("case %s_kind:" % tag, 1)
        self.emit("tp = (PyTypeObject *)state->%s_type;" % tag, 2)
        self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 2)
        self.emit("if (!result) goto failed;", 2)
        for field in cons.fields:
            self.visitField(field, tag, 2, False)
        self.emit("break;", 2)

    def visitField(self, field, name, depth, product):
        """Emit the code converting one field and storing it on ``result``.

        For a product the C value is read from ``o-><field>``; otherwise
        from the constructor's union member ``o->v.<name>.<field>``.
        """
        if product:
            source = "o->%s" % field.name
        else:
            source = "o->v.%s.%s" % (name, field.name)
        self.set(field, source, depth)
        self.emit("if (!value) goto failed;", depth)
        self.emit("if (PyObject_SetAttr(result, state->%s, value) == -1)" % field.name, depth)
        self.emit("goto failed;", depth + 1)
        self.emit("Py_DECREF(value);", depth)

    def set(self, field, value, depth):
        """Emit the C statement(s) assigning the converted *value* to ``value``."""
        if not field.seq:
            self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
            return

        if field.type not in self.metadata.simple_sums:
            self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
            return

        # While the sequence elements are stored as void*,
        # simple sums expects an enum
        inner = depth + 1
        self.emit("{", depth)
        self.emit("Py_ssize_t i, n = asdl_seq_LEN(%s);" % value, inner)
        self.emit("value = PyList_New(n);", inner)
        self.emit("if (!value) goto failed;", inner)
        self.emit("for(i = 0; i < n; i++)", inner)
        # This cannot fail, so no need for error handling
        set_item = "PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
            field.type,
            value
        )
        self.emit(set_item, depth + 2, reflow=False)
        self.emit("}", depth)
1372
1373
class PartingShots(StaticVisitor):
    """Emit the fixed C definitions of the AST conversion entry points.

    ``CODE`` defines ``PyAST_mod2obj()``, ``PyAST_obj2mod()`` and
    ``PyAST_Check()``.
    """

    # C source emitted as-is; it does not depend on the ASDL input.
    CODE = """
PyObject* PyAST_mod2obj(mod_ty t)
{
    struct ast_state *state = get_ast_state();
    if (state == NULL) {
        return NULL;
    }

    int recursion_limit = Py_GetRecursionLimit();
    int starting_recursion_depth;
    /* Be careful here to prevent overflow. */
    int COMPILER_STACK_FRAME_SCALE = 3;
    PyThreadState *tstate = _PyThreadState_GET();
    if (!tstate) {
        return 0;
    }
    state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
        recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
    int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
    starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
        recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
    state->recursion_depth = starting_recursion_depth;

    PyObject *result = ast2obj_mod(state, t);

    /* Check that the recursion depth counting balanced correctly */
    if (result && state->recursion_depth != starting_recursion_depth) {
        PyErr_Format(PyExc_SystemError,
            "AST constructor recursion depth mismatch (before=%d, after=%d)",
            starting_recursion_depth, state->recursion_depth);
        return 0;
    }
    return result;
}

/* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */
mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode)
{
    const char * const req_name[] = {"Module", "Expression", "Interactive"};
    int isinstance;

    if (PySys_Audit("compile", "OO", ast, Py_None) < 0) {
        return NULL;
    }

    struct ast_state *state = get_ast_state();
    if (state == NULL) {
        return NULL;
    }

    PyObject *req_type[3];
    req_type[0] = state->Module_type;
    req_type[1] = state->Expression_type;
    req_type[2] = state->Interactive_type;

    assert(0 <= mode && mode <= 2);

    isinstance = PyObject_IsInstance(ast, req_type[mode]);
    if (isinstance == -1)
        return NULL;
    if (!isinstance) {
        PyErr_Format(PyExc_TypeError, "expected %s node, got %.400s",
                     req_name[mode], _PyType_Name(Py_TYPE(ast)));
        return NULL;
    }

    mod_ty res = NULL;
    if (obj2ast_mod(state, ast, &res, arena) != 0)
        return NULL;
    else
        return res;
}

int PyAST_Check(PyObject* obj)
{
    struct ast_state *state = get_ast_state();
    if (state == NULL) {
        return -1;
    }
    return PyObject_IsInstance(obj, state->AST_type);
}
"""
1458
class ChainOfVisitors:
    """Run several visitors, in order, over the same object."""

    def __init__(self, *visitors, metadata=None):
        """Remember *visitors* and the *metadata* handed to each before it runs."""
        self.metadata = metadata
        self.visitors = visitors

    def visit(self, object):
        """Visit *object* with every visitor in turn.

        Each visitor first receives this chain's ``metadata``, then visits
        *object*, and finally emits one empty line at depth 0.
        """
        shared_metadata = self.metadata
        for visitor in self.visitors:
            visitor.metadata = shared_metadata
            visitor.visit(object)
            visitor.emit("", 0)
1469
1470
def generate_ast_state(module_state, f):
    """Write the C definition of ``struct ast_state`` to *f*.

    The struct holds three fixed ``int`` members followed by one
    ``PyObject *`` member per name in *module_state*, in iteration order.
    No newline is written after the closing ``};``.
    """
    members = ["int initialized", "int recursion_depth", "int recursion_limit"]
    members.extend("PyObject *" + state_name for state_name in module_state)

    f.write("struct ast_state {\n")
    for member in members:
        f.write("    " + member + ";\n")
    f.write("};")
1479
1480
def generate_ast_fini(module_state, f):
    """Write the C definition of ``_PyAST_Fini()`` to *f*.

    The generated function clears every ``state`` member named in
    *module_state* with ``Py_CLEAR`` and then sets ``state->initialized``
    to -1 when ``NDEBUG`` is not defined, or to 0 otherwise.
    """
    opening = (
        "\n"
        "void _PyAST_Fini(PyInterpreterState *interp)\n"
        "{\n"
        "    struct ast_state *state = &interp->ast;\n"
        "\n"
    )
    closing = (
        "\n"
        "#if !defined(NDEBUG)\n"
        "    state->initialized = -1;\n"
        "#else\n"
        "    state->initialized = 0;\n"
        "#endif\n"
        "}\n"
        "\n"
    )

    f.write(opening)
    for member in module_state:
        f.write("    Py_CLEAR(state->" + member + ");\n")
    f.write(closing)
1499
1500
def generate_module_def(mod, metadata, f, internal_h):
    """Write the state struct and the leading part of the generated C source.

    ``struct ast_state`` is written to *internal_h*.  The ``#include``
    lines, ``get_ast_state()``, ``_PyAST_Fini()`` and
    ``init_identifiers()`` are written to *f*.

    *metadata* must provide the ``identifiers``, ``singletons`` and
    ``types`` iterables of names.  *mod* is accepted but not used here.
    """
    # Strings interned by init_identifiers(): a fixed set plus every
    # identifier collected in the metadata.
    state_strings = {
        "ast",
        "_fields",
        "__match_args__",
        "__doc__",
        "__dict__",
        "__module__",
        "_attributes",
        *metadata.identifiers
    }

    # Members of struct ast_state: the interned strings plus one slot per
    # singleton and per type.
    module_state = state_strings.copy()
    module_state.update(
        "%s_singleton" % singleton
        for singleton in metadata.singletons
    )
    module_state.update(
        "%s_type" % type_name
        for type_name in metadata.types
    )

    # Sort so the generated output does not depend on set iteration order.
    state_strings = sorted(state_strings)
    module_state = sorted(module_state)

    generate_ast_state(module_state, internal_h)

    print(textwrap.dedent("""
        #include "Python.h"
        #include "pycore_ast.h"
        #include "pycore_ast_state.h"     // struct ast_state
        #include "pycore_ceval.h"         // _Py_EnterRecursiveCall
        #include "pycore_interp.h"        // _PyInterpreterState.ast
        #include "pycore_pystate.h"       // _PyInterpreterState_GET()
        #include "structmember.h"
        #include <stddef.h>

        // Forward declaration
        static int init_types(struct ast_state *state);

        static struct ast_state*
        get_ast_state(void)
        {
            PyInterpreterState *interp = _PyInterpreterState_GET();
            struct ast_state *state = &interp->ast;
            if (!init_types(state)) {
                return NULL;
            }
            return state;
        }
    """).strip(), file=f)

    generate_ast_fini(module_state, f)

    f.write('static int init_identifiers(struct ast_state *state)\n')
    f.write('{\n')
    for identifier in state_strings:
        f.write(
            '    if ((state->%s = PyUnicode_InternFromString("%s")) == NULL) return 0;\n'
            % (identifier, identifier)
        )
    f.write('    return 1;\n')
    # A function definition ends with a bare brace: a trailing ';' here
    # would be a stray empty declaration at file scope.
    f.write('}\n\n')
1564
def write_header(mod, metadata, f):
    """Write the complete internal AST header to *f*.

    The output is the include-guard preamble, then whatever the type
    definition, sequence definition and struct visitors emit for *mod*,
    then the prototype visitor's output, and finally the fixed function
    declarations together with the closing guards.
    """
    preamble = textwrap.dedent("""
        #ifndef Py_INTERNAL_AST_H
        #define Py_INTERNAL_AST_H
        #ifdef __cplusplus
        extern "C" {
        #endif

        #ifndef Py_BUILD_CORE
        #  error "this header requires Py_BUILD_CORE define"
        #endif

        #include "pycore_asdl.h"

    """).lstrip()
    trailer = textwrap.dedent("""

        PyObject* PyAST_mod2obj(mod_ty t);
        mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode);
        int PyAST_Check(PyObject* obj);

        extern int _PyAST_Validate(mod_ty);

        /* _PyAST_ExprAsUnicode is defined in ast_unparse.c */
        extern PyObject* _PyAST_ExprAsUnicode(expr_ty);

        /* Return the borrowed reference to the first literal string in the
           sequence of statements or NULL if it doesn't start from a literal string.
           Doesn't set exception. */
        extern PyObject* _PyAST_GetDocString(asdl_stmt_seq *);

        #ifdef __cplusplus
        }
        #endif
        #endif /* !Py_INTERNAL_AST_H */
    """)

    f.write(preamble)

    declaration_chain = ChainOfVisitors(
        TypeDefVisitor(f),
        SequenceDefVisitor(f),
        StructVisitor(f),
        metadata=metadata
    )
    declaration_chain.visit(mod)

    f.write("// Note: these macros affect function definitions, not only call sites.\n")
    PrototypeVisitor(f, metadata=metadata).visit(mod)

    f.write(trailer)
1614
1615
def write_internal_h_header(mod, f):
    """Write the opening guards of the AST state header to *f*.

    This covers the include guard, the ``extern "C"`` opener and the
    ``Py_BUILD_CORE`` check, followed by one blank line.  *mod* is
    accepted but not used.
    """
    guard_lines = [
        "#ifndef Py_INTERNAL_AST_STATE_H",
        "#define Py_INTERNAL_AST_STATE_H",
        "#ifdef __cplusplus",
        'extern "C" {',
        "#endif",
        "",
        "#ifndef Py_BUILD_CORE",
        '#  error "this header requires Py_BUILD_CORE define"',
        "#endif",
        "",
    ]
    # print() appends the final newline, giving the trailing blank line.
    print("\n".join(guard_lines), file=f)
1628
1629
def write_internal_h_footer(mod, f):
    """Write the closing guards of the AST state header to *f*.

    Two blank lines come first, then the ``extern "C"`` closer and the
    include-guard ``#endif``, then one blank line.  *mod* is accepted but
    not used.
    """
    closing_lines = [
        "",
        "",
        "#ifdef __cplusplus",
        "}",
        "#endif",
        "#endif /* !Py_INTERNAL_AST_STATE_H */",
        "",
    ]
    # print() appends the final newline, giving the trailing blank line.
    print("\n".join(closing_lines), file=f)
1638
def write_source(mod, metadata, f, internal_h_file):
    """Write the generated C source for *mod* to *f*.

    ``generate_module_def()`` runs first (it also writes to
    *internal_h_file*); the code-emitting visitors then run over *mod*
    in the order listed below, all sharing *metadata*.
    """
    generate_module_def(mod, metadata, f, internal_h_file)

    visitor_classes = (
        SequenceConstructorVisitor,
        PyTypesDeclareVisitor,
        PyTypesVisitor,
        Obj2ModPrototypeVisitor,
        FunctionVisitor,
        ObjVisitor,
        Obj2ModVisitor,
        ASTModuleVisitor,
        PartingShots,
    )
    emitters = [visitor_class(f) for visitor_class in visitor_classes]
    ChainOfVisitors(*emitters, metadata=metadata).visit(mod)
1655
def main(input_filename, c_filename, h_filename, internal_h_filename, dump_module=False):
    """Generate the C source and both headers from an ASDL description.

    *input_filename* is parsed with ``asdl.parse`` and validated with
    ``asdl.check``; a failed check exits the process with status 1.
    The three output paths are opened for writing through their
    ``open()`` method.  When *dump_module* is true the parsed module is
    printed before validation.
    """
    generator_path = "/".join(Path(__file__).parts[-2:])
    auto_gen_msg = AUTOGEN_MESSAGE.format(generator_path)

    mod = asdl.parse(input_filename)
    if dump_module:
        print('Parsed Module:')
        print(mod)
    if not asdl.check(mod):
        sys.exit(1)

    collector = MetadataVisitor()
    collector.visit(mod)
    metadata = collector.metadata

    with c_filename.open("w") as c_file, \
         h_filename.open("w") as h_file, \
         internal_h_filename.open("w") as internal_h_file:
        # Every output starts with the same "automatically generated" banner.
        for output in (c_file, h_file, internal_h_file):
            output.write(auto_gen_msg)

        # The internal header's struct is written by write_source(), between
        # the opening and closing guards.
        write_internal_h_header(mod, internal_h_file)
        write_source(mod, metadata, c_file, internal_h_file)
        write_header(mod, metadata, h_file)
        write_internal_h_footer(mod, internal_h_file)

    print(f"{c_filename}, {h_filename}, {internal_h_filename} regenerated.")
1682
if __name__ == "__main__":
    # Command-line entry point: one positional ASDL input file, three
    # required output paths and an optional flag to dump the parsed module.
    parser = ArgumentParser()
    parser.add_argument("input_file", type=Path)
    for short_flag, long_flag in (
        ("-C", "--c-file"),
        ("-H", "--h-file"),
        ("-I", "--internal-h-file"),
    ):
        parser.add_argument(short_flag, long_flag, type=Path, required=True)
    parser.add_argument("-d", "--dump-module", action="store_true")

    args = parser.parse_args()
    main(
        input_filename=args.input_file,
        c_filename=args.c_file,
        h_filename=args.h_file,
        internal_h_filename=args.internal_h_file,
        dump_module=args.dump_module,
    )
1694