zmc
2023-08-08 e792e9a60d958b93aef96050644f369feb25d61b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import attr
import pytest
from ... import _core, _abc
from .tutil import check_sequence_matches
 
 
@attr.s(eq=False, hash=False)
class TaskRecorder:
    """Instrument that logs every hook call it receives, in order.

    Run-level hooks append a one-element tuple (``("before_run",)`` /
    ``("after_run",)``); per-task hooks append ``(event_name, task)``.
    """

    # Chronological log of received hook events (a fresh list per instance).
    record = attr.ib(factory=list)

    def before_run(self):
        self.record.append(("before_run",))

    def task_scheduled(self, task):
        self.record.append(("schedule", task))

    def before_task_step(self, task):
        # The stepping task must already be the current task here.
        assert task is _core.current_task()
        self.record.append(("before", task))

    def after_task_step(self, task):
        # ...and must still be the current task once the step finishes.
        assert task is _core.current_task()
        self.record.append(("after", task))

    def after_run(self):
        self.record.append(("after_run",))

    def filter_tasks(self, tasks):
        """Yield the recorded events that are relevant to *tasks*.

        Run-level events are always yielded; per-task events only when
        their task is a member of *tasks*.
        """
        for entry in self.record:
            kind = entry[0]
            if kind in ("before_run", "after_run"):
                yield entry
            elif kind in ("schedule", "before", "after") and entry[1] in tasks:
                yield entry
 
 
def test_instruments(recwarn):
    r1 = TaskRecorder()
    r2 = TaskRecorder()
    r3 = TaskRecorder()

    # Filled in by the child task so the expectations below can refer to it.
    task = None

    # A child task is used on purpose: the main task does extra bookkeeping
    # that would leak into the instrument records.
    async def task_fn():
        nonlocal task
        task = _core.current_task()

        for _ in range(4):
            await _core.checkpoint()
        # Swap r2 out for r3 mid-run, to check instruments can change live.
        _core.remove_instrument(r2)
        # Removing an instrument that is no longer installed is a KeyError.
        with pytest.raises(KeyError):
            _core.remove_instrument(r2)
        # Adding is idempotent: the second add must not double-register.
        _core.add_instrument(r3)
        _core.add_instrument(r3)
        await _core.checkpoint()

    async def main():
        async with _core.open_nursery() as nursery:
            nursery.start_soon(task_fn)

    _core.run(main, instruments=[r1, r2])

    # Five checkpoints means six steps. checkpoint() reschedules the task as
    # it yields, so each "schedule" lands before that step's "after" event.
    one_step = [("before", task), ("schedule", task), ("after", task)]
    expected = [("before_run",), ("schedule", task)]
    expected += one_step * 5
    expected += [("before", task), ("after", task), ("after_run",)]
    assert r1.record == r2.record + r3.record
    assert list(r1.filter_tasks([task])) == expected
 
 
def test_instruments_interleave():
    tasks = {}

    async def two_step1():
        tasks["t1"] = _core.current_task()
        await _core.checkpoint()

    async def two_step2():
        tasks["t2"] = _core.current_task()
        await _core.checkpoint()

    async def main():
        async with _core.open_nursery() as nursery:
            nursery.start_soon(two_step1)
            nursery.start_soon(two_step2)

    recorder = TaskRecorder()
    _core.run(main, instruments=[recorder])

    t1, t2 = tasks["t1"], tasks["t2"]
    # The set entries presumably match their events in any order
    # (see check_sequence_matches): within one scheduling batch the two
    # tasks' events may interleave.
    first_batch = {
        ("before", t1),
        ("schedule", t1),
        ("after", t1),
        ("before", t2),
        ("schedule", t2),
        ("after", t2),
    }
    second_batch = {
        ("before", t1),
        ("after", t1),
        ("before", t2),
        ("after", t2),
    }
    expected = [
        ("before_run",),
        ("schedule", t1),
        ("schedule", t2),
        first_batch,
        second_batch,
        ("after_run",),
    ]
    observed = list(recorder.filter_tasks(tasks.values()))
    # Printed so the actual sequence shows up in pytest's output on failure.
    print(observed)
    check_sequence_matches(observed, expected)
 
 
def test_null_instrument():
    """Instrument hooks that an instrument does not define are skipped."""

    class NullInstrument:
        # Not an instrument hook; excluded from coverage because it is never
        # expected to be called.
        def something_unrelated(self):
            pass  # pragma: no cover

    async def main():
        await _core.checkpoint()

    instrument = NullInstrument()
    _core.run(main, instruments=[instrument])
 
 
def test_instrument_before_after_run():
    """before_run and after_run each fire exactly once, in that order."""
    events = []

    class BeforeAfterRun:
        def before_run(self):
            events.append("before_run")

        def after_run(self):
            events.append("after_run")

    async def main():
        pass

    _core.run(main, instruments=[BeforeAfterRun()])
    assert events == ["before_run", "after_run"]
 
 
def test_instrument_task_spawn_exit():
    """task_spawned and task_exited are both reported for the main task."""
    seen = []

    class SpawnExitRecorder:
        def task_spawned(self, task):
            seen.append(("spawned", task))

        def task_exited(self, task):
            seen.append(("exited", task))

    async def main():
        return _core.current_task()

    main_task = _core.run(main, instruments=[SpawnExitRecorder()])
    for event in ("spawned", "exited"):
        assert (event, main_task) in seen
 
 
# This test also covers a crash that happens before the initial task has even
# been spawned, which is a very difficult case to handle.
def test_instruments_crash(caplog):
    record = []

    class BrokenInstrument:
        def task_scheduled(self, task):
            record.append("scheduled")
            raise ValueError("oops")

        def close(self):
            # Should never be called; checks that the instrument-disabling
            # logic works correctly.
            record.append("closed")  # pragma: no cover

    async def main():
        record.append("main ran")
        return _core.current_task()

    recorder = TaskRecorder()
    main_task = _core.run(main, instruments=[recorder, BrokenInstrument()])
    # The broken hook fired once and was not called again after it raised.
    assert record == ["scheduled", "main ran"]
    # The healthy recorder kept receiving events even though the
    # BrokenInstrument was disabled.
    assert ("after", main_task) in recorder.record
    assert ("after_run",) in recorder.record
    # The failure was reported through logging.
    log_entry = caplog.records[0]
    exc_type, exc_value, _ = log_entry.exc_info
    assert exc_type is ValueError
    assert str(exc_value) == "oops"
    assert "Instrument has been disabled" in log_entry.message
 
 
def test_instruments_monkeypatch():
    class NullInstrument(_abc.Instrument):
        pass

    instrument = NullInstrument()

    async def main():
        calls = []

        # Patching a new hook onto an already-installed instrument does not
        # make the runner start calling it right away.
        instrument.before_task_step = calls.append
        for _ in range(2):
            await _core.checkpoint()
        assert not calls

        # Removing and re-adding the instrument makes the new hook take
        # effect.
        _core.remove_instrument(instrument)
        _core.add_instrument(instrument)
        for _ in range(2):
            await _core.checkpoint()
        assert calls.count(_core.current_task()) == 2

        # Once removed for good, the hook stops firing and the count stays put.
        _core.remove_instrument(instrument)
        for _ in range(2):
            await _core.checkpoint()
        assert calls.count(_core.current_task()) == 2

    _core.run(main, instruments=[instrument])
 
    _core.run(main, instruments=[instrument])
 
 
def test_instrument_that_raises_on_getattr():
    class EvilInstrument:
        def task_exited(self, task):
            assert False  # pragma: no cover

        # Merely reading this hook attribute raises, so add_instrument fails.
        @property
        def after_run(self):
            raise ValueError("oops")

    async def main():
        # The error from the attribute lookup propagates out of add_instrument.
        with pytest.raises(ValueError):
            _core.add_instrument(EvilInstrument())

        # The failed registration must leave nothing behind in the
        # per-method lists, including for hooks that were read successfully.
        runner = _core.current_task()._runner
        for hook_name in ("after_run", "task_exited"):
            assert hook_name not in runner.instruments

    _core.run(main)