1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import pytest
 
import attr
 
from ..abc import SendStream, ReceiveStream
from .._highlevel_generic import StapledStream
 
 
@attr.s
class RecordSendStream(SendStream):
    """A fake send stream that only logs the calls made on it.

    ``send_all`` logs a ``("send_all", data)`` tuple. The other methods log
    their own name as a plain string. Nothing is actually transmitted.
    """

    # In-order log of calls; ``factory=list`` gives each instance its own list.
    record = attr.ib(factory=list)

    async def send_all(self, data):
        entry = ("send_all", data)
        self.record.append(entry)

    async def wait_send_all_might_not_block(self):
        self.record.append("wait_send_all_might_not_block")

    async def aclose(self):
        self.record.append("aclose")
 
 
@attr.s
class RecordReceiveStream(ReceiveStream):
    """A fake receive stream that only logs the calls made on it.

    ``receive_some`` logs a ``("receive_some", max_bytes)`` tuple and returns
    ``None``. ``aclose`` logs the string ``"aclose"``.
    """

    # In-order log of calls; ``factory=list`` gives each instance its own list.
    record = attr.ib(factory=list)

    async def receive_some(self, max_bytes=None):
        call = ("receive_some", max_bytes)
        self.record.append(call)

    async def aclose(self):
        self.record.append("aclose")
 
 
async def test_StapledStream():
    """Check that StapledStream routes each operation to the right half.

    Send-side calls must reach only the send stream. ``receive_some`` must
    reach only the receive stream. ``send_eof`` must use the send stream's
    own ``send_eof`` when one exists and fall back to ``aclose`` otherwise.
    ``aclose`` must close both halves.
    """
    sender = RecordSendStream()
    receiver = RecordReceiveStream()
    stapled = StapledStream(sender, receiver)

    # The wrapped streams are exposed as-is.
    assert stapled.send_stream is sender
    assert stapled.receive_stream is receiver

    # Send-side operations are forwarded to the send stream, in call order.
    await stapled.send_all(b"foo")
    await stapled.wait_send_all_might_not_block()
    expected_sends = [
        ("send_all", b"foo"),
        "wait_send_all_might_not_block",
    ]
    assert sender.record == expected_sends
    sender.record.clear()

    # RecordSendStream defines no send_eof, so the stapled send_eof is
    # expected to close the send stream instead.
    await stapled.send_eof()
    assert sender.record == ["aclose"]
    sender.record.clear()

    # After a send_eof is attached to the send stream, it should be the one
    # that gets called.
    async def fake_send_eof():
        sender.record.append("send_eof")

    sender.send_eof = fake_send_eof
    await stapled.send_eof()
    assert sender.record == ["send_eof"]

    sender.record.clear()
    # None of the send-side traffic above should have reached the receiver.
    assert receiver.record == []

    # receive_some goes to the receive stream only, argument included.
    await stapled.receive_some(1234)
    assert receiver.record == [("receive_some", 1234)]
    assert sender.record == []
    receiver.record.clear()

    # Closing the stapled stream closes both halves.
    await stapled.aclose()
    assert receiver.record == ["aclose"]
    assert sender.record == ["aclose"]
 
 
async def test_StapledStream_with_erroring_close():
    """Check that a failing aclose on one half does not skip the other half.

    Both halves log the call and then raise ValueError from ``aclose``. The
    exception escaping ``StapledStream.aclose`` should have a ValueError as
    its implicit ``__context__``, meaning the first failure is chained rather
    than dropped. Both halves should have recorded their ``aclose``.
    """

    class BrokenSendStream(RecordSendStream):
        async def aclose(self):
            # Log the close first so the test can see it was attempted.
            await super().aclose()
            raise ValueError

    class BrokenReceiveStream(RecordReceiveStream):
        async def aclose(self):
            await super().aclose()
            raise ValueError

    stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())

    with pytest.raises(ValueError) as excinfo:
        await stapled.aclose()
    assert isinstance(excinfo.value.__context__, ValueError)

    # Checked in the same order as before: send half, then receive half.
    for half in (stapled.send_stream, stapled.receive_stream):
        assert half.record == ["aclose"]