zmc
2023-08-08 e792e9a60d958b93aef96050644f369feb25d61b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import pytest
 
from functools import partial
import errno
 
import attr
 
import trio
from trio.testing import memory_stream_pair, wait_all_tasks_blocked
 
 
@attr.s(hash=False, eq=False)
class MemoryListener(trio.abc.Listener):
    closed = attr.ib(default=False)
    accepted_streams = attr.ib(factory=list)
    queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel(1)))
    accept_hook = attr.ib(default=None)
 
    async def connect(self):
        assert not self.closed
        client, server = memory_stream_pair()
        await self.queued_streams[0].send(server)
        return client
 
    async def accept(self):
        await trio.lowlevel.checkpoint()
        assert not self.closed
        if self.accept_hook is not None:
            await self.accept_hook()
        stream = await self.queued_streams[1].receive()
        self.accepted_streams.append(stream)
        return stream
 
    async def aclose(self):
        self.closed = True
        await trio.lowlevel.checkpoint()
 
 
async def test_serve_listeners_basic():
    listeners = [MemoryListener(), MemoryListener()]
 
    record = []
 
    def close_hook():
        # Make sure this is a forceful close
        assert trio.current_effective_deadline() == float("-inf")
        record.append("closed")
 
    async def handler(stream):
        await stream.send_all(b"123")
        assert await stream.receive_some(10) == b"456"
        stream.send_stream.close_hook = close_hook
        stream.receive_stream.close_hook = close_hook
 
    async def client(listener):
        s = await listener.connect()
        assert await s.receive_some(10) == b"123"
        await s.send_all(b"456")
 
    async def do_tests(parent_nursery):
        async with trio.open_nursery() as nursery:
            for listener in listeners:
                for _ in range(3):
                    nursery.start_soon(client, listener)
 
        await wait_all_tasks_blocked()
 
        # verifies that all 6 streams x 2 directions each were closed ok
        assert len(record) == 12
 
        parent_nursery.cancel_scope.cancel()
 
    async with trio.open_nursery() as nursery:
        l2 = await nursery.start(trio.serve_listeners, handler, listeners)
        assert l2 == listeners
        # This is just split into another function because gh-136 isn't
        # implemented yet
        nursery.start_soon(do_tests, nursery)
 
    for listener in listeners:
        assert listener.closed
 
 
async def test_serve_listeners_accept_unrecognized_error():
    for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]:
        listener = MemoryListener()
 
        async def raise_error():
            raise error
 
        listener.accept_hook = raise_error
 
        with pytest.raises(type(error)) as excinfo:
            await trio.serve_listeners(None, [listener])
        assert excinfo.value is error
 
 
async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog):
    listener = MemoryListener()
 
    async def raise_EMFILE():
        raise OSError(errno.EMFILE, "out of file descriptors")
 
    listener.accept_hook = raise_EMFILE
 
    # It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900
    # = 10 times total
    with trio.move_on_after(0.950):
        await trio.serve_listeners(None, [listener])
 
    assert len(caplog.records) == 10
    for record in caplog.records:
        assert "retrying" in record.msg
        assert record.exc_info[1].errno == errno.EMFILE
 
 
async def test_serve_listeners_connection_nursery(autojump_clock):
    listener = MemoryListener()
 
    async def handler(stream):
        await trio.sleep(1)
 
    class Done(Exception):
        pass
 
    async def connection_watcher(*, task_status=trio.TASK_STATUS_IGNORED):
        async with trio.open_nursery() as nursery:
            task_status.started(nursery)
            await wait_all_tasks_blocked()
            assert len(nursery.child_tasks) == 10
            raise Done
 
    with pytest.raises(Done):
        async with trio.open_nursery() as nursery:
            handler_nursery = await nursery.start(connection_watcher)
            await nursery.start(
                partial(
                    trio.serve_listeners,
                    handler,
                    [listener],
                    handler_nursery=handler_nursery,
                )
            )
            for _ in range(10):
                nursery.start_soon(listener.connect)