1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
#! /usr/bin/env python3
# -*- coding: utf-8 -`-
"""
Code generation script for class methods
to be exported as public API
"""
import argparse
import ast
import astor
import os
from pathlib import Path
import sys
 
from textwrap import indent
 
PREFIX = "_generated"
 
HEADER = """# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
 
# fmt: off
"""
 
FOOTER = """# fmt: on
"""
 
TEMPLATE = """locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
    return{}GLOBAL_RUN_CONTEXT.{}.{}
except AttributeError:
    raise RuntimeError("must be called from async context")
"""
 
 
def is_function(node):
    """Check if the AST node is either a function
    or an async function
    """
    if isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef):
        return True
    return False
 
 
def is_public(node):
    """Check if the AST node has a _public decorator"""
    if not is_function(node):
        return False
    for decorator in node.decorator_list:
        if isinstance(decorator, ast.Name) and decorator.id == "_public":
            return True
    return False
 
 
def get_public_methods(tree):
    """Return a list of methods marked as public.
    The function walks the given tree and extracts
    all objects that are functions which are marked
    public.
    """
    for node in ast.walk(tree):
        if is_public(node):
            yield node
 
 
def create_passthrough_args(funcdef):
    """Given a function definition, create a string that represents taking all
    the arguments from the function, and passing them through to another
    invocation of the same function.
 
    Example input: ast.parse("def f(a, *, b): ...")
    Example output: "(a, b=b)"
    """
    call_args = []
    for arg in funcdef.args.args:
        call_args.append(arg.arg)
    if funcdef.args.vararg:
        call_args.append("*" + funcdef.args.vararg.arg)
    for arg in funcdef.args.kwonlyargs:
        call_args.append(arg.arg + "=" + arg.arg)
    if funcdef.args.kwarg:
        call_args.append("**" + funcdef.args.kwarg.arg)
    return "({})".format(", ".join(call_args))
 
 
def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
    """Scan the given .py file for @_public decorators, and generate wrapper
    functions.
 
    """
    generated = [HEADER]
    source = astor.code_to_ast.parse_file(source_path)
    for method in get_public_methods(source):
        # Remove self from arguments
        assert method.args.args[0].arg == "self"
        del method.args.args[0]
 
        # Remove decorators
        method.decorator_list = []
 
        # Create pass through arguments
        new_args = create_passthrough_args(method)
 
        # Remove method body without the docstring
        if ast.get_docstring(method) is None:
            del method.body[:]
        else:
            # The first entry is always the docstring
            del method.body[1:]
 
        # Create the function definition including the body
        func = astor.to_source(method, indent_with=" " * 4)
 
        # Create export function body
        template = TEMPLATE.format(
            " await " if isinstance(method, ast.AsyncFunctionDef) else " ",
            lookup_path,
            method.name + new_args,
        )
 
        # Assemble function definition arguments and body
        snippet = func + indent(template, " " * 4)
 
        # Append the snippet to the corresponding module
        generated.append(snippet)
    generated.append(FOOTER)
    return "\n\n".join(generated)
 
 
def matches_disk_files(new_files):
    for new_path, new_source in new_files.items():
        if not os.path.exists(new_path):
            return False
        with open(new_path, "r", encoding="utf-8") as old_file:
            old_source = old_file.read()
        if old_source != new_source:
            return False
    return True
 
 
def process(sources_and_lookups, *, do_test):
    new_files = {}
    for source_path, lookup_path in sources_and_lookups:
        print("Scanning:", source_path)
        new_source = gen_public_wrappers_source(source_path, lookup_path)
        dirname, basename = os.path.split(source_path)
        new_path = os.path.join(dirname, PREFIX + basename)
        new_files[new_path] = new_source
    if do_test:
        if not matches_disk_files(new_files):
            print("Generated sources are outdated. Please regenerate.")
            sys.exit(1)
        else:
            print("Generated sources are up to date.")
    else:
        for new_path, new_source in new_files.items():
            with open(new_path, "w", encoding="utf-8") as f:
                f.write(new_source)
        print("Regenerated sources successfully.")
 
 
# This is in fact run in CI, but only in the formatting check job, which
# doesn't collect coverage.
def main():  # pragma: no cover
    parser = argparse.ArgumentParser(
        description="Generate python code for public api wrappers"
    )
    parser.add_argument(
        "--test", "-t", action="store_true", help="test if code is still up to date"
    )
    parsed_args = parser.parse_args()
 
    source_root = Path.cwd()
    # Double-check we found the right directory
    assert (source_root / "LICENSE").exists()
    core = source_root / "trio/_core"
    to_wrap = [
        (core / "_run.py", "runner"),
        (core / "_instrumentation.py", "runner.instruments"),
        (core / "_io_windows.py", "runner.io_manager"),
        (core / "_io_epoll.py", "runner.io_manager"),
        (core / "_io_kqueue.py", "runner.io_manager"),
    ]
 
    process(to_wrap, do_test=parsed_args.test)
 
 
if __name__ == "__main__":  # pragma: no cover
    main()