1 module concurrency.nursery;
2 
3 import concurrency.stoptoken : StopSource, StopToken, StopCallback, onStop;
4 import concurrency.thread : LocalThreadExecutor;
5 import concurrency.receiver : getStopToken;
6 import concurrency.scheduler : SchedulerObjectBase;
7 import std.typecons : Nullable;
8 
/// A Nursery is a place for senders to be run in, while being a Sender itself.
/// Stopping the Nursery cancels all senders.
/// When any Sender completes with an exception, all Senders are canceled as well.
/// Cancellation is signaled with a StopToken.
/// Senders themselves bear the responsibility to respond to stop requests.
/// When cancellation happens, all Senders are waited on for completion.
/// Senders can be added to the Nursery at any time.
/// Senders are only started when the Nursery itself is being awaited on.
class Nursery : StopSource {
  import concurrency.sender : isSender, OperationalStateBase;
  import core.sync.mutex : Mutex;

  alias Value = void;
  private {
    // Operations added through `run` that have not completed yet.
    Node[] operations;
    // A heap-connected operation paired with the id its receiver reports back.
    struct Node {
      OperationalStateBase state;
      void start() @safe nothrow {
        state.start();
      }
      size_t id;
    }
    // Held while reading/updating `operations`, `receiver` and `stopCallback`.
    Mutex mutex;
    // Number of operations added via `run` that have not completed yet.
    shared size_t busy = 0;
    // Incremented once per `run`; its new value becomes the operation id.
    shared size_t counter = 0;
    Exception exception; // first exception from sender, if any
    // Type-erased receiver of the current await; null when not being awaited.
    ReceiverObject receiver;
    // Callback registered on the awaiting receiver's stop token (see `connect`).
    StopCallback stopCallback;
    // Casts away `shared`; callers are responsible for synchronising access
    // (via `mutex` or atomics).
    Nursery assumeThreadSafe() @trusted shared nothrow {
      return cast(Nursery)this;
    }
  }

  /// Creates an empty nursery and allocates its mutex.
  /// Also calls `resetScheduler()` from concurrency.utils.
  this() @safe shared {
    import concurrency.utils : resetScheduler;
    resetScheduler();
    with(assumeThreadSafe) mutex = new Mutex();
  }

  /// Returns a StopToken backed by this nursery. Senders run inside the
  /// nursery obtain it through their NurseryReceiver.
  StopToken getStopToken() nothrow @trusted shared {
    return StopToken(cast(Nursery)this);
  }

  /// Returns the scheduler of the receiver currently awaiting this nursery.
  /// NOTE(review): `receiver` is null before the nursery is awaited and after
  /// it completes; calling this at those times dereferences null.
  private auto getScheduler() nothrow @trusted shared {
    return (cast()receiver).getScheduler();
  }

  /// Records `e` as the nursery's error (only the first one is kept), marks
  /// operation `id` as finished and then requests a stop of the nursery.
  private void setError(Exception e, size_t id) nothrow @safe shared {
    import core.atomic : cas;
    with(assumeThreadSafe) cas(&exception, cast(Exception)null, e); // store exception if not already
    done(id);
    stop();
  }

  /// Marks operation `id` as finished. When it was the last outstanding
  /// operation, the await state is reset and the awaiting receiver (if any)
  /// is completed: with the stored exception if there is one, with `setDone`
  /// if a stop was requested, otherwise with `setValue` (an exception thrown
  /// by `setValue` is forwarded to `setError`).
  private void done(size_t id) nothrow @trusted shared {
    import std.algorithm : countUntil, remove;
    import core.atomic : atomicOp;

    with (assumeThreadSafe) {
      mutex.lock_nothrow();
      auto idx = operations.countUntil!(o => o.id == id);
      if (idx != -1)
        operations = operations.remove(idx);
      bool isDone = atomicOp!"-="(busy,1) == 0;
      // Snapshot under the lock; the receiver is signalled after unlocking.
      auto localReceiver = receiver;
      auto localException = exception;
      if (isDone) {
        exception = null;
        receiver = null;
        // `stopCallback` is a nullable reference that is only assigned while
        // awaited; never dispose through a null reference.
        if (stopCallback !is null)
          stopCallback.dispose();
        stopCallback = null;
      }
      mutex.unlock_nothrow();

      if (isDone && localReceiver !is null) {
        if (localException !is null) {
          localReceiver.setError(localException);
        } else if (isStopRequested()) {
          localReceiver.setDone();
        } else {
          try {
            localReceiver.setValue();
          } catch (Exception e) {
            localReceiver.setError(e);
          }
        }
      }
    }
  }

  /// Runs the sender held by `sender`; a null Nullable is ignored.
  void run(Sender)(Nullable!Sender sender) shared if (isSender!Sender) {
    if (!sender.isNull)
      run(sender.get());
  }

  /// Adds `sender` to the nursery. It is connected immediately. It is started
  /// right away when the nursery is currently being awaited; otherwise it is
  /// started when the nursery gets awaited. A null class/interface sender is
  /// ignored.
  void run(Sender)(Sender sender) shared @trusted if (isSender!Sender) {
    import core.atomic : atomicOp;
    import concurrency.sender : connectHeap;

    static if (is(Sender == class) || is(Sender == interface))
      if (sender is null)
        return;

    size_t id = atomicOp!"+="(counter, 1);
    auto op = sender.connectHeap(NurseryReceiver!(Sender.Value)(this, id));

    mutex.lock_nothrow();
    operations ~= cast(shared) Node(op, id);
    atomicOp!"+="(busy, 1);
    // `receiver` is only non-null while the nursery is being awaited.
    bool hasStarted = this.receiver !is null;
    mutex.unlock_nothrow();

    if (hasStarted)
      op.start();
  }

  /// Connects `receiver` to this nursery (forwards to the `shared` overload).
  auto connect(Receiver)(return Receiver receiver) @trusted scope {
    return (cast(shared)this).connect(receiver);
  }

  /// Connects `receiver` to this nursery. A stop request on the receiver's
  /// stop token stops the nursery; note the callback is registered here, at
  /// connect time. Starting the returned operation installs the receiver and
  /// starts every operation added so far.
  auto connect(Receiver)(Receiver receiver) shared scope @safe {
    final class ReceiverImpl : ReceiverObject {
      Receiver receiver;
      this(Receiver receiver) { this.receiver = receiver; }
      void setValue() @safe { receiver.setValue(); }
      void setDone() nothrow @safe { receiver.setDone(); }
      void setError(Exception e) nothrow @safe { receiver.setError(e); }
      SchedulerObjectBase getScheduler() nothrow @safe {
        import concurrency.scheduler : toSchedulerObject;
        return receiver.getScheduler().toSchedulerObject();
      }
    }
    static struct Op {
      shared Nursery nursery;
      StopCallback cb;
      ReceiverObject receiver;
      @disable this(ref return scope typeof(this) rhs);
      @disable this(this);
      this(shared Nursery n, StopCallback cb, ReceiverObject r) {
        nursery = n;
        this.cb = cb;
        receiver = r;
      }
      void start() nothrow scope @trusted {
        nursery.setReceiver(receiver, cb);
      }
    }
    auto stopToken = receiver.getStopToken();
    auto cb = (()@trusted => stopToken.onStop(() shared nothrow @trusted => cast(void)this.stop()))();
    return Op(this, cb, new ReceiverImpl(receiver));
  }

  /// Installs the awaiting receiver and its stop callback, then starts every
  /// operation that was added before this point. Asserts that no other await
  /// is in progress.
  private void setReceiver(ReceiverObject r, StopCallback cb) nothrow @safe shared {
    with(assumeThreadSafe) {
      mutex.lock_nothrow();
      assert(this.receiver is null, "Cannot await a nursery twice.");
      receiver = r;
      stopCallback = cb;
      auto ops = operations.dup();
      mutex.unlock_nothrow();

      // start all work (outside the lock, on a snapshot of the operations)
      foreach(op; ops)
        op.start();
    }
  }
}
177 
/// Type-erased view of the receiver that awaits a Nursery.
private interface ReceiverObject {
@safe:
  /// Completes the awaiting receiver successfully; this one may throw.
  void setValue();

nothrow:
  /// Completes the awaiting receiver without a value or an error.
  void setDone();
  /// Completes the awaiting receiver with the exception `e`.
  void setError(Exception e);
  /// Returns the awaiting receiver's scheduler as a SchedulerObjectBase.
  SchedulerObjectBase getScheduler();
}
184 
/// Receiver handed to every sender run inside a Nursery. Every completion
/// signal is reported back to the nursery under the operation's `id`.
private struct NurseryReceiver(Value) {
  shared Nursery nursery; // nursery to notify on completion
  size_t id;              // identifies this operation within the nursery

  this(shared Nursery nursery, size_t id) {
    this.id = id;
    this.nursery = nursery;
  }

  /// Completion without a value: mark the operation as finished.
  void setDone() nothrow @safe {
    nursery.done(id);
  }

  /// Failure: forward the exception to the nursery for this operation.
  void setError(Exception e) nothrow @safe {
    nursery.setError(e, id);
  }

  // Success is reported exactly like `setDone`; any produced value is dropped.
  static if (!is(Value == void)) {
    void setValue(Value val) shared @trusted {
      (cast() this).setDone();
    }
    void setValue(Value val) @safe {
      nursery.done(id);
    }
  } else {
    void setValue() shared @safe {
      (cast() this).setDone();
    }
    void setValue() @safe {
      setDone();
    }
  }

  /// The nursery's own stop token, so stopping the nursery reaches the sender.
  auto getStopToken() @safe {
    return nursery.getStopToken();
  }

  /// Scheduler exposed by the owning nursery.
  auto getScheduler() @safe {
    return nursery.getScheduler();
  }
}