1 module concurrency.sender;
2 
3 import concepts;
4 import std.traits : ReturnType, isCallable;
5 import core.time : Duration;
6 
7 // A Sender represents something that completes with either:
8 // 1. a value (which can be void)
9 // 2. completion, in response to cancellation
10 // 3. an Throwable
11 //
12 // Many things can be represented as a Sender.
13 // Threads, Fibers, coroutines, etc. In general, any async operation.
14 //
15 // A Sender is lazy. Work it represents is only started when
16 // the sender is connected to a receiver and explicitly started.
17 //
18 // Senders and Receivers go hand in hand. Senders send a value,
19 // Receivers receive one.
20 //
21 // Senders are useful because many Tasks can be represented as them,
22 // and any operation on top of senders then works on any one of those
23 // Tasks.
24 //
25 // The most common operation is `sync_wait`. It blocks the current
26 // execution context to await the Sender.
27 //
28 // There are many others as well. Like `when_all`, `retry`, `when_any`,
29 // etc. These algorithms can be used on any sender.
30 //
31 // Cancellation happens through StopTokens. A Sender can ask a Receiver
32 // for a StopToken. Default is a NeverStopToken but Receiver's can
33 // customize this.
34 //
35 // The StopToken can be polled or a callback can be registered with one.
36 //
37 // Senders enforce Structured Concurrency because work cannot be
38 // started unless it is awaited.
39 //
40 // These concepts are heavily inspired by several C++ proposals
41 // starting with http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0443r14.html
42 
43 /// checks that T is a Sender
44 void checkSender(T)() @safe {
45   import concurrency.scheduler : SchedulerObjectBase;
46   import concurrency.stoptoken : StopToken;
47   T t = T.init;
48   struct Receiver {
49     static if (is(T.Value == void))
50       void setValue() {}
51     else
52       void setValue(T.Value) {}
53     void setDone() nothrow {}
54     void setError(Throwable e) nothrow {}
55     StopToken getStopToken() nothrow { return StopToken.init; }
56     SchedulerObjectBase getScheduler() nothrow { return null; }
57   }
58   OpType!(T, Receiver) op = t.connect(Receiver.init);
59   static if (!isValidOp!(T, Receiver))
60     pragma(msg, "Warning: ", T, "'s operation state is not returned via the stack");
61 }
62 enum isSender(T) = is(typeof(checkSender!T));
63 
64 /// It is ok for the operation state to be on the heap, but if it is on the stack we need to ensure any copies are elided. We can't be 100% sure (the compiler may still blit), but this is the best we can do.
65 template isValidOp(Sender, Receiver) {
66   import std.traits : isPointer;
67   import std.meta : allSatisfy;
68   alias overloads = __traits(getOverloads, Sender, "connect", true);
69   template isRVO(alias connect) {
70     static if (__traits(isTemplate, connect))
71       enum isRVO = __traits(isReturnOnStack, connect!Receiver);
72     else
73       enum isRVO = __traits(isReturnOnStack, connect);
74   }
75   alias Op = OpType!(Sender, Receiver);
76   enum isValidOp = isPointer!Op || is(Op == OperationObject) || is(Op == class) || (allSatisfy!(isRVO, overloads) && !__traits(isPOD, Op));
77 }
78 
79 /// A Sender that sends a single value of type T
80 struct ValueSender(T) {
81   static assert (models!(typeof(this), isSender));
82   alias Value = T;
83   static struct Op(Receiver) {
84     Receiver receiver;
85     static if (!is(T == void))
86       T value;
87     void start() nothrow @trusted scope {
88       import concurrency.receiver : setValueOrError;
89       static if (!is(T == void))
90         receiver.setValueOrError(value);
91       else
92         receiver.setValueOrError();
93     }
94   }
95   static if (!is(T == void))
96     T value;
97   Op!Receiver connect(Receiver)(return Receiver receiver) @safe scope return {
98     // ensure NRVO
99     static if (!is(T == void))
100       auto op = Op!(Receiver)(receiver, value);
101     else
102       auto op = Op!(Receiver)(receiver);
103     return op;
104   }
105 }
106 
107 auto just(T...)(T t) {
108   import std.typecons : tuple, Tuple;
109   static if (T.length == 1)
110     return ValueSender!(T[0])(t);
111   else
112     return ValueSender!(Tuple!T)(tuple(t));
113 }
114 
115 struct JustFromSender(Fun) {
116   static assert (models!(typeof(this), isSender));
117   alias Value = ReturnType!fun;
118   static struct Op(Receiver) {
119     Receiver receiver;
120     Fun fun;
121     void start() @trusted nothrow {
122       import std.traits : hasFunctionAttributes;
123       static if (hasFunctionAttributes!(Fun, "nothrow")) {
124         set();
125       } else {
126         try {
127           set();
128         } catch (Exception e) {
129           receiver.setError(e);
130         }
131       }
132     }
133     private void set() @safe {
134       import concurrency.receiver : setValueOrError;
135       static if (is(Value == void)) {
136         fun();
137         if (receiver.getStopToken.isStopRequested)
138           receiver.setDone();
139         else
140           receiver.setValue();
141       } else {
142         auto r = fun();
143         if (receiver.getStopToken.isStopRequested)
144           receiver.setDone();
145         else
146           receiver.setValue(r);
147       }
148     }
149   }
150   Fun fun;
151   Op!Receiver connect(Receiver)(return Receiver receiver) @safe scope return {
152     // ensure NRVO
153     auto op = Op!(Receiver)(receiver, fun);
154     return op;
155   }
156 }
157 
158 JustFromSender!(Fun) justFrom(Fun)(Fun fun) if (isCallable!Fun) {
159   import std.traits : hasFunctionAttributes, isFunction, isFunctionPointer;
160   import concurrency.utils : isThreadSafeFunction;
161   static assert(isThreadSafeFunction!Fun);
162   return JustFromSender!Fun(fun);
163 }
164 
165 /// A polymorphic sender of type T
166 interface SenderObjectBase(T) {
167   import concurrency.receiver;
168   import concurrency.scheduler : SchedulerObjectBase;
169   import concurrency.stoptoken : StopToken, stopTokenObject;
170   static assert (models!(typeof(this), isSender));
171   alias Value = T;
172   alias Op = OperationObject;
173   OperationObject connect(ReceiverObjectBase!(T) receiver) @safe;
174   OperationObject connect(Receiver)(return Receiver receiver) @trusted scope {
175     return connect(new class(receiver) ReceiverObjectBase!T {
176       Receiver receiver;
177       this(Receiver receiver) {
178         this.receiver = receiver;
179       }
180       static if (is(T == void)) {
181         void setValue() {
182           receiver.setValueOrError();
183         }
184       } else {
185         void setValue(T value) {
186           receiver.setValueOrError(value);
187         }
188       }
189       void setDone() nothrow {
190         receiver.setDone();
191       }
192       void setError(Throwable e) nothrow {
193         receiver.setError(e);
194       }
195       StopToken getStopToken() nothrow {
196         return stopTokenObject(receiver.getStopToken());
197       }
198       SchedulerObjectBase getScheduler() nothrow @safe {
199         import concurrency.scheduler : toSchedulerObject;
200         return receiver.getScheduler().toSchedulerObject;
201       }
202     });
203   }
204 }
205 
206 /// Type-erased operational state object
207 /// used in polymorphic senders
208 struct OperationObject {
209   private void delegate() nothrow shared _start;
210   void start() nothrow @trusted { _start(); }
211 }
212 
213 interface OperationalStateBase {
214   void start() @safe nothrow;
215 }
216 
217 /// calls connect on the Sender but stores the OperationState on the heap
218 OperationalStateBase connectHeap(Sender, Receiver)(Sender sender, Receiver receiver) {
219   alias State = typeof(sender.connect(receiver));
220   return new class(sender, receiver) OperationalStateBase {
221     State state;
222     this(Sender sender, Receiver receiver) {
223       state = sender.connect(receiver);
224     }
225     void start() @safe nothrow {
226       state.start();
227     }
228   };
229 }
230 
231 /// A class extending from SenderObjectBase that wraps any Sender
232 class SenderObjectImpl(Sender) : SenderObjectBase!(Sender.Value) {
233   import concurrency.receiver : ReceiverObjectBase;
234   static assert (models!(typeof(this), isSender));
235   private Sender sender;
236   this(Sender sender) {
237     this.sender = sender;
238   }
239   OperationObject connect(ReceiverObjectBase!(Sender.Value) receiver) @trusted {
240     auto state = sender.connectHeap(receiver);
241     return OperationObject(cast(typeof(OperationObject._start))&state.start);
242   }
243   OperationObject connect(Receiver)(Receiver receiver) {
244     auto base = cast(SenderObjectBase!(Sender.Value))this;
245     return base.connect(receiver);
246   }
247 }
248 
249 /// Converts any Sender to a polymorphic SenderObject
250 auto toSenderObject(Sender)(Sender sender) {
251   static assert(models!(Sender, isSender));
252   static if (is(Sender : SenderObjectBase!(Sender.Value))) {
253     return sender;
254   } else
255     return cast(SenderObjectBase!(Sender.Value))new SenderObjectImpl!(Sender)(sender);
256 }
257 
258 /// A sender that always sets an error
259 struct ThrowingSender {
260   alias Value = void;
261   static struct Op(Receiver) {
262     Receiver receiver;
263     void start() {
264       receiver.setError(new Exception("ThrowingSender"));
265     }
266   }
267   auto connect(Receiver)(return Receiver receiver) @safe scope return {
268     // ensure NRVO
269     auto op = Op!Receiver(receiver);
270     return op;
271   }
272 }
273 
274 /// A sender that always calls setDone
275 struct DoneSender {
276   static assert (models!(typeof(this), isSender));
277   alias Value = void;
278   static struct DoneOp(Receiver) {
279     Receiver receiver;
280     void start() nothrow @trusted scope {
281       receiver.setDone();
282     }
283   }
284   auto connect(Receiver)(return Receiver receiver) @safe scope return {
285     // ensure NRVO
286     auto op = DoneOp!(Receiver)(receiver);
287     return op;
288   }
289 }
290 
291 /// A sender that always calls setValue with no args
292 struct VoidSender {
293   static assert (models!(typeof(this), isSender));
294   alias Value = void;
295   struct VoidOp(Receiver) {
296     Receiver receiver;
297     void start() nothrow @trusted scope {
298       import concurrency.receiver : setValueOrError;
299       receiver.setValueOrError();
300     }
301   }
302   auto connect(Receiver)(return Receiver receiver) @safe scope return{
303     // ensure NRVO
304     auto op = VoidOp!Receiver(receiver);
305     return op;
306   }
307 }
308 /// A sender that always calls setError
309 struct ErrorSender {
310   static assert (models!(typeof(this), isSender));
311   alias Value = void;
312   Throwable exception;
313   static struct ErrorOp(Receiver) {
314     Receiver receiver;
315     Throwable exception;
316     void start() nothrow @trusted scope {
317       receiver.setError(exception);
318     }
319   }
320   auto connect(Receiver)(return Receiver receiver) @safe scope return {
321     // ensure NRVO
322     auto op = ErrorOp!(Receiver)(receiver, exception);
323     return op;
324   }
325 }
326 
327 template OpType(Sender, Receiver) {
328   static if (is(Sender.Op)) {
329     alias OpType = Sender.Op;
330   } else {
331     import std.traits : ReturnType;
332     import std.meta : staticMap;
333     template GetOpType(alias connect) {
334       static if (__traits(isTemplate, connect)) {
335         alias GetOpType = ReturnType!(connect!Receiver);//(Receiver.init));
336       } else {
337         alias GetOpType = ReturnType!(connect);//(Receiver.init));
338       }
339     }
340     alias overloads = __traits(getOverloads, Sender, "connect", true);
341     alias opTypes = staticMap!(GetOpType, overloads);
342     alias OpType = opTypes[0];
343   }
344 }
345 
346 /// A sender that delays before calling setValue
347 struct DelaySender {
348   alias Value = void;
349   Duration dur;
350   auto connect(Receiver)(return Receiver receiver) @safe return scope {
351     // ensure NRVO
352     auto op = receiver.getScheduler().scheduleAfter(dur).connect(receiver);
353     return op;
354   }
355 }
356 
357 auto delay(Duration dur) {
358   return DelaySender(dur);
359 }
360 
361 struct PromiseSenderOp(T, Receiver) {
362   import concurrency.stoptoken;
363   alias Sender = PromiseSender!T;
364   alias InternalValue = Sender.InternalValue;
365   shared Sender parent;
366   Receiver receiver;
367   StopCallback cb;
368   void start() nothrow @trusted scope {
369     parent.add(&(cast(shared)this).onValue);
370     cb = receiver.getStopToken.onStop(&(cast(shared)this).onStop);
371   }
372   void onStop() nothrow @trusted shared {
373     with(unshared) {
374       parent.remove(&(cast(shared)this).onValue);
375       receiver.setDone();
376     }
377   }
378   void onValue(InternalValue value) nothrow @safe shared {
379     import mir.algebraic : match;
380     with(unshared) {
381       value.match!((Sender.ValueRep v){
382           try {
383             static if (is(Value == void))
384               receiver.setValue();
385             else
386               receiver.setValue(v);
387           } catch (Exception e) {
388             /// TODO: dispose needs to be called in all cases, except
389             /// this onValue can sometimes be called immediately,
390             /// leaving no room to set cb.dispose...
391             cb.dispose();
392             receiver.setError(e);
393           }
394         }, (Throwable e){
395           receiver.setError(e);
396         }, (Sender.Done d){
397           receiver.setDone();
398         });
399     }
400   }
401   private auto ref unshared() @trusted nothrow shared {
402     return cast()this;
403   }
404 }
405 
406 class PromiseSender(T) {
407   import std.traits : ReturnType;
408   import concurrency.slist;
409   import concurrency.bitfield;
410   import mir.algebraic : Algebraic, match, Nullable;
411   static assert(models!(typeof(this), isSender));
412   alias Value = T;
413   static if (is(Value == void)) {
414     static struct ValueRep{}
415   } else
416     alias ValueRep = Value;
417   static struct Done{}
418   alias InternalValue = Algebraic!(Throwable, ValueRep, Done);
419   alias DG = void delegate(InternalValue) nothrow @safe shared;
420   private {
421     shared SList!DG dgs;
422     Nullable!InternalValue value;
423     enum Flags {
424       locked = 0x1,
425       completed = 0x2
426     }
427     SharedBitField!Flags counter;
428     void add(DG dg) @safe nothrow shared {
429       with(unshared) {
430         with(counter.lock()) {
431           if (was(Flags.completed)) {
432             auto val = value.get;
433             release(); // release early
434             dg(val);
435           } else {
436             dgs.pushBack(dg);
437           }
438         }
439       }
440     }
441     void remove(DG dg) @safe nothrow shared {
442       with (counter.lock()) {
443         if (was(Flags.completed)) {
444           release(); // release early
445         } else {
446           dgs.remove(dg);
447         }
448       }
449     }
450     private auto ref unshared() @trusted nothrow shared {
451       return cast()this;
452     }
453   }
454   private void pushImpl(P)(P t) @safe shared {
455     import std.exception : enforce;
456     with (counter.lock(Flags.completed)) {
457       enforce(!was(Flags.completed), "Can only complete once");
458       InternalValue val = InternalValue(t);
459       (cast()value) = val;
460       auto localDgs = dgs.release();
461       release();
462       foreach(dg; localDgs)
463         dg(val);
464     }
465   }
466   void cancel() @safe shared {
467     pushImpl(Done());
468   }
469   void error(Throwable e) @safe shared {
470     pushImpl(e);
471   }
472   void fulfill(T t) @safe shared {
473     pushImpl(t);
474   }
475   bool isCompleted() @trusted shared {
476     import core.atomic : MemoryOrder;
477     return (counter.load!(MemoryOrder.acq) & Flags.completed) > 0;
478   }
479   this() {
480     this.dgs = new shared SList!DG;
481   }
482   auto connect(Receiver)(return Receiver receiver) @trusted scope {
483     // ensure NRVO
484     auto op = (cast(shared)this).connect(receiver);
485     return op;
486   }
487   auto connect(Receiver)(return Receiver receiver) @safe shared scope return {
488     // ensure NRVO
489     auto op = PromiseSenderOp!(T, Receiver)(this, receiver);
490     return op;
491   }
492 }
493 
494 shared(PromiseSender!T) promise(T)() {
495   return new shared PromiseSender!T();
496 }