1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from collections import defaultdict
 
import attr
from async_generator import asynccontextmanager
 
from .. import _core
from .. import _util
from .. import Event
 
if False:
    from typing import DefaultDict, Set
 
 
@attr.s(eq=False, hash=False)
class Sequencer(metaclass=_util.Final):
    """Force code running in separate tasks to execute in an explicit linear order.

    Calling an instance with an integer position returns an async context
    manager. The position says where that block of code belongs in the
    overall sequence. The block at position 0 may run right away. The block
    at position N is held back until the block at position N-1 has exited,
    whether that block exited normally or by raising. Each position can be
    used only once per instance.

    If a task is cancelled while it waits for its turn, the sequencer is
    marked broken. Every task currently waiting is woken and raises
    ``RuntimeError``, and any later attempt to enter a new position also
    raises ``RuntimeError``.

    Example:
      An extremely elaborate way to print the numbers 0-5, in order::

         async def worker1(seq):
             async with seq(0):
                 print(0)
             async with seq(4):
                 print(4)

         async def worker2(seq):
             async with seq(2):
                 print(2)
             async with seq(5):
                 print(5)

         async def worker3(seq):
             async with seq(1):
                 print(1)
             async with seq(3):
                 print(3)

         async def main():
             seq = trio.testing.Sequencer()
             async with trio.open_nursery() as nursery:
                 nursery.start_soon(worker1, seq)
                 nursery.start_soon(worker2, seq)
                 nursery.start_soon(worker3, seq)

         trio.run(main)

    """

    # One Event per position, created on first access. Setting the Event for
    # position N is what allows the block at position N to start.
    _sequence_points = attr.ib(
        factory=lambda: defaultdict(Event), init=False
    )  # type: DefaultDict[int, Event]
    # Positions already handed out; used to reject re-use of a position.
    _claimed = attr.ib(factory=set, init=False)  # type: Set[int]
    # Becomes True once a waiting task is cancelled; from then on, waiting
    # and newly-entering callers fail with RuntimeError.
    _broken = attr.ib(default=False, init=False)

    async def _wait_for_turn(self, position: int):
        """Wait until ``position`` is released; break the sequence if cancelled.

        Raises:
          RuntimeError: if this wait is cancelled, or if the wake-up came from
            the sequence being broken rather than from the predecessor.
        """
        try:
            await self._sequence_points[position].wait()
        except _core.Cancelled:
            # The required order can no longer be guaranteed, so mark the
            # sequence broken and release every known waiter. They then fail
            # loudly instead of hanging forever.
            self._broken = True
            for pending in self._sequence_points.values():
                pending.set()
            raise RuntimeError("Sequencer wait cancelled -- sequence broken")
        # We may have been woken by the teardown above (in another task)
        # rather than by our predecessor finishing.
        if self._broken:
            raise RuntimeError("sequence broken!")

    @asynccontextmanager
    async def __call__(self, position: int):
        """Hold the caller until ``position`` is next in the sequence.

        Args:
          position: This block's slot in the linear order. Position 0 runs
            immediately; position N waits for position N-1 to exit.

        Raises:
          RuntimeError: if ``position`` was already used on this instance,
            if the sequence is or becomes broken, or if the caller is
            cancelled while waiting for its turn.
        """
        already_used = position in self._claimed
        if already_used:
            raise RuntimeError("Attempted to re-use sequence point {}".format(position))
        if self._broken:
            raise RuntimeError("sequence broken!")
        self._claimed.add(position)
        # Position 0 has no predecessor, so it never waits.
        if position != 0:
            await self._wait_for_turn(position)
        try:
            yield
        finally:
            # Hand off to the next position even if the body raised.
            self._sequence_points[position + 1].set()