1 module concurrency.sender;
2 
3 import concepts;
4 import std.traits : ReturnType, isCallable;
5 import core.time : Duration;
6 
7 // A Sender represents something that completes with either:
8 // 1. a value (which can be void)
9 // 2. completion, in response to cancellation
// 3. a Throwable
11 //
12 // Many things can be represented as a Sender.
13 // Threads, Fibers, coroutines, etc. In general, any async operation.
14 //
15 // A Sender is lazy. Work it represents is only started when
16 // the sender is connected to a receiver and explicitly started.
17 //
18 // Senders and Receivers go hand in hand. Senders send a value,
19 // Receivers receive one.
20 //
21 // Senders are useful because many Tasks can be represented as them,
22 // and any operation on top of senders then works on any one of those
23 // Tasks.
24 //
25 // The most common operation is `sync_wait`. It blocks the current
26 // execution context to await the Sender.
27 //
// There are many others as well, like `when_all`, `retry`, `when_any`,
// etc. These algorithms can be used on any sender.
30 //
// Cancellation happens through StopTokens. A Sender can ask a Receiver
// for a StopToken. The default is a NeverStopToken, but Receivers can
// customize this.
34 //
35 // The StopToken can be polled or a callback can be registered with one.
36 //
37 // Senders enforce Structured Concurrency because work cannot be
38 // started unless it is awaited.
39 //
40 // These concepts are heavily inspired by several C++ proposals
41 // starting with http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0443r14.html
42 
/// Compile-time probe for `isSender`: builds a mock receiver exposing the
/// receiver protocol (setValue/setDone/setError/getStopToken/getScheduler)
/// and connects a `T` to it. Only its well-formedness matters; `isSender`
/// inspects it through `is(typeof(...))`.
void checkSender(T)() @safe {
  import concurrency.scheduler : SchedulerObjectBase;
  import concurrency.stoptoken : StopToken;

  // Scheduler handed out by the mock receiver; both entry points yield a VoidSender.
  struct MockScheduler {
    import core.time : Duration;
    auto schedule() @safe { return VoidSender(); }
    auto scheduleAfter(Duration) @safe { return VoidSender(); }
  }

  static struct MockReceiver {
    int* i; // force it scope
    static if (!is(T.Value == void))
      void setValue(T.Value) @safe {}
    else
      void setValue() @safe {}
    void setDone() @safe nothrow {}
    void setError(Throwable e) @safe nothrow {}
    StopToken getStopToken() @safe nothrow { return StopToken.init; }
    MockScheduler getScheduler() @safe nothrow { return MockScheduler.init; }
  }

  T sender = T.init;
  scope receiver = MockReceiver.init;
  OpType!(T, MockReceiver) opState = sender.connect(receiver);
  static if (!isValidOp!(T, MockReceiver))
    pragma(msg, "Warning: ", T, "'s operation state is not returned via the stack");
}

/// True when `T` can be connected to a receiver (see `checkSender`).
template isSender(T) {
  enum isSender = is(typeof(checkSender!T));
}
70 
/// Whether the operation state from connecting `Sender` to `Receiver` is
/// acceptable to hold. That is the case if it lives on the heap (a pointer,
/// a class, or an `OperationObject`), or if it is a non-POD struct that every
/// `connect` overload returns on the stack, so copies should be elided.
/// The compiler may still blit, so this is a best-effort check.
template isValidOp(Sender, Receiver) {
  import std.traits : isPointer;
  import std.meta : allSatisfy;
  alias Op = OpType!(Sender, Receiver);
  alias connectOverloads = __traits(getOverloads, Sender, "connect", true);
  // Does this `connect` (instantiated for `Receiver` when it is a template)
  // return its result on the stack?
  template returnsOnStack(alias connectFn) {
    static if (!__traits(isTemplate, connectFn))
      enum returnsOnStack = __traits(isReturnOnStack, connectFn);
    else
      enum returnsOnStack = __traits(isReturnOnStack, connectFn!Receiver);
  }
  enum isValidOp = isPointer!Op
    || is(Op == OperationObject)
    || is(Op == class)
    || (allSatisfy!(returnsOnStack, connectOverloads) && !__traits(isPOD, Op));
}
85 
/// A Sender that completes by handing its stored `value` (nothing when
/// `T` is void) to the receiver through `setValueOrError`.
struct ValueSender(T) {
  static assert (models!(typeof(this), isSender));
  alias Value = T;
  static struct Op(Receiver) {
    Receiver receiver;
    static if (!is(T == void))
      T value;
    @disable this(ref return scope typeof(this) rhs);
    @disable this(this);
    void start() nothrow @trusted scope {
      import concurrency.receiver : setValueOrError;
      static if (is(T == void))
        receiver.setValueOrError();
      else
        receiver.setValueOrError(value);
    }
  }
  static if (!is(T == void))
    T value;
  Op!Receiver connect(Receiver)(return Receiver receiver) @safe {
    // Build the non-copyable op in a named local and return it (NRVO).
    static if (is(T == void))
      auto state = Op!(Receiver)(receiver);
    else
      auto state = Op!(Receiver)(receiver, value);
    return state;
  }
}
115 
/// Creates a sender for the given values. A single value is sent as-is;
/// any other number of values is packed into a `std.typecons.Tuple`.
auto just(T...)(T t) {
  import std.typecons : tuple, Tuple;
  static if (T.length != 1) {
    return ValueSender!(Tuple!T)(tuple(t));
  } else {
    return ValueSender!(T[0])(t);
  }
}
123 
/// A Sender that invokes `fun` when started and sends its result.
///
/// If the receiver's stop token reports a stop request once `fun` has
/// returned, `setDone` is called instead of delivering the value.
/// Exceptions thrown by `fun` are forwarded to the receiver's `setError`.
struct JustFromSender(Fun) {
  static assert (models!(typeof(this), isSender));
  alias Value = ReturnType!fun;
  static struct Op(Receiver) {
    Receiver receiver;
    Fun fun;
    @disable this(ref return scope typeof(this) rhs);
    @disable this(this);
    void start() @trusted nothrow {
      import std.traits : hasFunctionAttributes;
      // `set` delivers the value via setValueOrError, which the nothrow
      // ValueSender/VoidSender ops also call, so only `fun` itself can make
      // `set` throw. The try/catch is therefore needed only when `fun` is
      // not nothrow.
      static if (hasFunctionAttributes!(Fun, "nothrow")) {
        set();
      } else {
        try {
          set();
        } catch (Exception e) {
          receiver.setError(e);
        }
      }
    }
    private void set() @safe {
      import concurrency.receiver : setValueOrError;
      static if (is(Value == void)) {
        fun();
        if (receiver.getStopToken.isStopRequested)
          receiver.setDone();
        else
          receiver.setValueOrError();
      } else {
        auto r = fun();
        if (receiver.getStopToken.isStopRequested)
          receiver.setDone();
        else
          receiver.setValueOrError(r);
      }
    }
  }
  Fun fun;
  Op!Receiver connect(Receiver)(return Receiver receiver) @safe return scope {
    // ensure NRVO
    auto op = Op!(Receiver)(receiver, fun);
    return op;
  }
}
168 
/// Wraps `fun` in a `JustFromSender`; `fun` is only invoked once the
/// resulting operation is started. `fun` must satisfy `isThreadSafeFunction`.
JustFromSender!(Fun) justFrom(Fun)(Fun fun) if (isCallable!Fun) {
  import concurrency.utils : isThreadSafeFunction;
  static assert(isThreadSafeFunction!Fun);
  auto sender = JustFromSender!Fun(fun);
  return sender;
}
175 
/// Polymorphic (type-erased) Sender interface whose value type is `T`.
///
/// Implementations provide `connect` for a `ReceiverObjectBase!T`. The
/// templated `connect` adapts any concrete Receiver to that interface by
/// wrapping it in a GC-allocated adapter class and forwarding the call.
interface SenderObjectBase(T) {
  import concurrency.receiver;
  import concurrency.scheduler : SchedulerObjectBase;
  import concurrency.stoptoken : StopToken, stopTokenObject;
  static assert (models!(typeof(this), isSender));
  alias Value = T;
  alias Op = OperationObject;

  /// Connects an already type-erased receiver (implemented by subclasses).
  OperationObject connect(return ReceiverObjectBase!(T) receiver) @safe scope;

  /// Wraps `receiver` in a `ReceiverObjectBase!T` adapter and connects that.
  OperationObject connect(Receiver)(return Receiver receiver) @trusted scope {
    return connect(new class(receiver) ReceiverObjectBase!T {
      Receiver wrapped;
      this(Receiver r) {
        wrapped = r;
      }
      static if (is(T == void)) {
        void setValue() {
          wrapped.setValueOrError();
        }
      } else {
        void setValue(T value) {
          wrapped.setValueOrError(value);
        }
      }
      void setDone() nothrow {
        wrapped.setDone();
      }
      void setError(Throwable e) nothrow {
        wrapped.setError(e);
      }
      StopToken getStopToken() nothrow {
        // Erase the concrete stop-token type as well.
        return stopTokenObject(wrapped.getStopToken());
      }
      SchedulerObjectBase getScheduler() nothrow @safe scope {
        import concurrency.scheduler : toSchedulerObject;
        return wrapped.getScheduler().toSchedulerObject;
      }
    });
  }
}
216 
/// Type-erased operation state used by polymorphic senders. It holds only
/// the shared, nothrow delegate that starts the underlying operation.
struct OperationObject {
  private void delegate() nothrow shared _start;

  /// Starts the wrapped operation by invoking the stored delegate.
  /// NOTE(review): on a default-initialized object this calls a null delegate.
  void start() scope nothrow @trusted {
    this._start();
  }
}
223 
/// Interface for operation state kept behind a class reference
/// (as returned by `connectHeap`).
interface OperationalStateBase {
  /// Begins the operation.
  @safe nothrow void start();
}
227 
/// Connects `sender` to `receiver` just like `sender.connect(receiver)`, but
/// stores the resulting operation state inside a GC-allocated class
/// instance rather than returning it by value.
OperationalStateBase connectHeap(Sender, Receiver)(Sender sender, Receiver receiver) @safe {
  alias OpState = typeof(sender.connect(receiver));
  return new class(sender, receiver) OperationalStateBase {
    OpState op;
    this(return Sender s, return Receiver r) @trusted {
      // First assignment in the constructor initializes `op` in place.
      op = s.connect(r);
    }
    void start() @safe nothrow {
      op.start();
    }
  };
}
241 
/// Adapts any concrete `Sender` to the `SenderObjectBase` interface.
class SenderObjectImpl(Sender) : SenderObjectBase!(Sender.Value) {
  import concurrency.receiver : ReceiverObjectBase;
  static assert (models!(typeof(this), isSender));
  private Sender sender;
  this(Sender sender) {
    this.sender = sender;
  }
  /// Connects via `connectHeap` and exposes only the heap state's `start`
  /// method, cast to the `OperationObject` delegate type.
  OperationObject connect(return ReceiverObjectBase!(Sender.Value) receiver) @trusted scope {
    auto heapOp = sender.connectHeap(receiver);
    auto starter = &heapOp.start;
    return OperationObject(cast(typeof(OperationObject._start)) starter);
  }
  /// Forwards to the interface's templated `connect`, which wraps
  /// `receiver` into a `ReceiverObjectBase`.
  OperationObject connect(Receiver)(return Receiver receiver) @safe scope {
    auto asBase = cast(SenderObjectBase!(Sender.Value))this;
    return asBase.connect(receiver);
  }
}
259 
/// Returns `sender` as a `SenderObjectBase!(Sender.Value)`. Senders that
/// already implement that interface are returned unchanged; all others are
/// wrapped in a `SenderObjectImpl`.
auto toSenderObject(Sender)(Sender sender) {
  static assert(models!(Sender, isSender));
  alias Base = SenderObjectBase!(Sender.Value);
  static if (!is(Sender : Base)) {
    return cast(Base) new SenderObjectImpl!(Sender)(sender);
  } else {
    return sender;
  }
}
268 
/// A sender that always completes with an error. When started, it passes a
/// freshly allocated `Exception("ThrowingSender")` to the receiver's
/// `setError`.
struct ThrowingSender {
  // Same self-check as the other built-in senders.
  static assert (models!(typeof(this), isSender));
  alias Value = void;
  static struct Op(Receiver) {
    Receiver receiver;
    @disable this(ref return scope typeof(this) rhs);
    @disable this(this);
    // Explicit contract matching DoneSender/ErrorSender ops: starting never
    // throws (setError is nothrow; allocation failure is an Error).
    void start() nothrow @trusted scope {
      receiver.setError(new Exception("ThrowingSender"));
    }
  }
  auto connect(Receiver)(return Receiver receiver) @safe return scope {
    // ensure NRVO
    auto op = Op!Receiver(receiver);
    return op;
  }
}
286 
/// A sender that never produces a value: starting it always calls the
/// receiver's `setDone`.
struct DoneSender {
  static assert (models!(typeof(this), isSender));
  alias Value = void;
  static struct DoneOp(Receiver) {
    Receiver receiver;
    @disable this(ref return scope typeof(this) rhs);
    @disable this(this);
    void start() nothrow @trusted scope {
      this.receiver.setDone();
    }
  }
  auto connect(Receiver)(return Receiver receiver) @safe return scope {
    // Return a named local so the non-copyable op is built in place (NRVO).
    auto state = DoneOp!(Receiver)(receiver);
    return state;
  }
}
305 
/// A sender that always completes with an argument-less `setValue`, routed
/// through `setValueOrError`.
struct VoidSender {
  static assert (models!(typeof(this), isSender));
  alias Value = void;
  struct VoidOp(Receiver) {
    Receiver receiver;
    @disable this(ref return scope typeof(this) rhs);
    @disable this(this);
    void start() nothrow @safe {
      import concurrency.receiver : setValueOrError;
      this.receiver.setValueOrError();
    }
  }
  auto connect(Receiver)(return Receiver receiver) @safe return scope {
    // Return a named local so the non-copyable op is built in place (NRVO).
    auto state = VoidOp!Receiver(receiver);
    return state;
  }
}
325 
/// A sender that always completes by passing its stored `exception` to the
/// receiver's `setError`.
struct ErrorSender {
  static assert (models!(typeof(this), isSender));
  alias Value = void;
  Throwable exception;
  static struct ErrorOp(Receiver) {
    Receiver receiver;
    Throwable exception;
    @disable this(ref return scope typeof(this) rhs);
    @disable this(this);
    void start() nothrow @trusted scope {
      this.receiver.setError(this.exception);
    }
  }
  auto connect(Receiver)(return Receiver receiver) @safe return scope {
    // Return a named local so the non-copyable op is built in place (NRVO).
    auto state = ErrorOp!(Receiver)(receiver, this.exception);
    return state;
  }
}
346 
/// The operation-state type produced by connecting `Sender` to `Receiver`.
/// If `Sender` declares an `Op` type, that type is used. Otherwise it is the
/// return type of the first `connect` overload, instantiated for `Receiver`
/// when that overload is a template.
template OpType(Sender, Receiver) {
  static if (!is(Sender.Op)) {
    import std.traits : ReturnType;
    import std.meta : staticMap;
    template opOf(alias connectFn) {
      static if (!__traits(isTemplate, connectFn))
        alias opOf = ReturnType!(connectFn);
      else
        alias opOf = ReturnType!(connectFn!Receiver);
    }
    alias connectOverloads = __traits(getOverloads, Sender, "connect", true);
    alias candidates = staticMap!(opOf, connectOverloads);
    alias OpType = candidates[0];
  } else {
    alias OpType = Sender.Op;
  }
}
365 
/// A sender that delays before calling setValue. It delegates entirely to
/// `scheduleAfter(dur)` on the receiver's scheduler.
struct DelaySender {
  alias Value = void;
  Duration dur;
  auto connect(Receiver)(return Receiver receiver) @safe return scope {
    // Named local + return keeps NRVO for the non-copyable op state.
    auto state = receiver.getScheduler().scheduleAfter(dur).connect(receiver);
    return state;
  }
}
376 
/// Returns a `DelaySender` that completes via the receiver scheduler's
/// `scheduleAfter(dur)`.
auto delay(Duration dur) {
  auto sender = DelaySender(dur);
  return sender;
}
380 
/// Operation state created when a `PromiseSender` is connected to a Receiver.
///
/// `start` registers `onValue` with the parent Promise. Unless the value was
/// delivered inline, it also registers `onStop` with the receiver's stop
/// token. The `bitfield` coordinates callbacks that fire while `start` is
/// still setting up, so that the receiver is completed exactly once.
struct PromiseSenderOp(T, Receiver) {
  import concurrency.stoptoken;
  import concurrency.bitfield;
  private enum Flags : size_t {
    locked = 0x0,
    setup = 0x1,
    value = 0x2,
    stop = 0x4
  }
  alias Sender = Promise!T;
  alias InternalValue = Sender.InternalValue;
  shared Sender parent;
  Receiver receiver;
  StopCallback cb;
  shared SharedBitField!Flags bitfield;
  @disable this(ref return scope typeof(this) rhs);
  @disable this(this);
  void start() nothrow @trusted scope {
    // if already completed we can optimize
    if (parent.isCompleted) {
      bitfield.add(Flags.setup);
      parent.add(&(cast(shared)this).onValue);
      return;
    }
    // Otherwise we have to be a bit careful here,
    // both the onStop and the onValue we register
    // can be called from possibly different contexts.
    // We can't atomically connect both, so we have to
    // devise a scheme to handle one or both being called
    // before we are done here.

    // we use a simple atomic bitfield that we set after setup
    // is done. If `onValue` or `onStop` trigger before setup
    // is complete, they update the bitfield and return early.
    // After we setup both, we flip the setup bit and check
    // if any of the callbacks triggered in the meantime,
    // if they did we know we have to perform some cleanup
    // if they didn't the callbacks themselves will handle it

    bool triggeredInline = parent.add(&(cast(shared)this).onValue);
    // if triggeredInline there is no point in setting up the stop callback
    if (!triggeredInline)
      cb = receiver.getStopToken.onStop(&(cast(shared)this).onStop);

    with (bitfield.add(Flags.setup)) {
      if (has(Flags.value)) {
        // onValue fired before setup finished and returned early. The
        // Promise has already handed our delegate out (it is no longer
        // queued), so add it again: the Promise is completed, so it fires
        // inline, now with the setup bit set. A value that has already
        // arrived wins over a stop request that raced with it.
        parent.add(&(cast(shared)this).onValue);
      } else if (has(Flags.stop)) {
        // onStop fired before setup finished and returned early.
        // Only complete with setDone if onValue could still be
        // deregistered. If removal fails, the Promise completed in the
        // meantime and onValue (which now sees the setup bit) completes
        // the receiver. Calling setDone too would complete it twice.
        if (parent.remove(&(cast(shared)this).onValue))
          receiver.setDone();
      }
    }
  }
  void onStop() nothrow @trusted shared {
    // we set the stop bit and return early if setup bit isn't set
    with (bitfield.add(Flags.stop))
      if (!has(Flags.setup))
        return;
    with(unshared) {
      // If `parent.remove` returns true, onValue will never be called,
      // so we can call setDone ourselves.
      // If it returns false onStop and onValue are in a race, and we
      // let onValue pass.
      if (parent.remove(&(cast(shared)this).onValue))
        receiver.setDone();
    }
  }
  void onValue(InternalValue value) nothrow @safe shared {
    import mir.algebraic : match;
    // we set the value bit and return early if setup bit isn't set
    with (bitfield.add(Flags.value))
      if (!has(Flags.setup))
        return;
    with(unshared) {
      // `cb.dispose` will ensure onStop will never be called
      // after it returns. It will also block if it is currently
      // being executed.
      // This means that when it completes we are the only one
      // calling the receiver's termination functions.
      if (cb)
        cb.dispose();
      value.match!((Sender.ValueRep v){
          try {
            static if (is(Sender.Value == void))
              receiver.setValue();
            else
              receiver.setValue(v);
          } catch (Exception e) {
            receiver.setError(e);
          }
        }, (Throwable e){
          receiver.setError(e);
        }, (Sender.Done d){
          receiver.setDone();
        });
    }
  }
  private auto ref unshared() @trusted nothrow shared {
    return cast()this;
  }
}
486 
/// A one-shot completion source. The first of `fulfill`, `error` or
/// `cancel` takes effect, and later calls return false. Registered
/// delegates receive the completion. A delegate added after completion is
/// invoked immediately. The completion state is guarded by the lock bit in
/// `counter`.
class Promise(T) {
  import std.traits : ReturnType;
  import concurrency.slist;
  import concurrency.bitfield;
  import mir.algebraic : Algebraic, match, Nullable;
  alias Value = T;
  static if (!is(Value == void))
    alias ValueRep = Value;
  else
    static struct ValueRep{}
  static struct Done{}
  alias InternalValue = Algebraic!(Throwable, ValueRep, Done);
  alias DG = void delegate(InternalValue) nothrow @safe shared;
  private {
    shared SList!DG dgs;
    Nullable!InternalValue value;
    enum Flags {
      locked = 0x1,
      completed = 0x2
    }
    SharedBitField!Flags counter;
    // Queues `dg` and returns false. If already completed, it instead runs
    // `dg` with the stored completion (after releasing the lock) and
    // returns true.
    bool add(DG dg) @safe nothrow shared {
      with (unshared) {
        with (counter.lock()) {
          if (!was(Flags.completed)) {
            dgs.pushBack(dg);
            return false;
          }
          auto completedValue = value.get;
          release(); // release early
          dg(completedValue);
          return true;
        }
      }
    }
    // Dequeues `dg` and returns true. Returns false if the promise already
    // completed; the queued delegates have then been taken by the completer.
    bool remove(DG dg) @safe nothrow shared {
      with (counter.lock()) {
        if (!was(Flags.completed)) {
          dgs.remove(dg);
          return true;
        }
        release(); // release early
        return false;
      }
    }
    private auto ref unshared() @trusted nothrow shared {
      return cast()this;
    }
  }
  // Stores the completion and runs all queued delegates outside the lock.
  // NOTE(review): lock(Flags.completed) presumably sets the completed flag
  // while locking; `was` reports the prior state.
  private bool pushImpl(P)(P t) @safe shared nothrow {
    with (counter.lock(Flags.completed)) {
      if (was(Flags.completed))
        return false;
      InternalValue completion = InternalValue(t);
      (cast()value) = completion;
      auto waiting = dgs.release();
      release();
      foreach (dg; waiting)
        dg(completion);
      return true;
    }
  }
  /// Completes with `Done`; returns false if already completed.
  bool cancel() @safe shared nothrow {
    return pushImpl(Done());
  }
  /// Completes with error `e`; returns false if already completed.
  bool error(Throwable e) @safe shared nothrow {
    return pushImpl(e);
  }
  static if (!is(Value == void)) {
    /// Completes with value `t`; returns false if already completed.
    bool fulfill(T t) @safe shared nothrow {
      return pushImpl(t);
    }
  } else {
    /// Completes with the (empty) value; returns false if already completed.
    bool fulfill() @safe shared nothrow {
      return pushImpl(ValueRep());
    }
  }
  /// Whether the completed flag is set (acquire load).
  bool isCompleted() @trusted shared nothrow {
    import core.atomic : MemoryOrder;
    return (counter.load!(MemoryOrder.acq) & Flags.completed) != 0;
  }
  this() {
    dgs = new shared SList!DG;
  }
  /// A sender that completes with this promise's outcome.
  auto sender() shared @safe nothrow {
    return shared PromiseSender!T(this);
  }
}
578 
/// Allocates a fresh, not-yet-completed shared `Promise!T`.
shared(Promise!T) promise(T)() {
  auto created = new shared Promise!T();
  return created;
}
582 
/// A Sender that completes with the outcome of its `Promise`.
struct PromiseSender(T) {
  alias Value = T;
  static assert(models!(typeof(this), isSender));
  private shared Promise!T promise;

  // NOTE: keep this overload first; OpType uses the first `connect` overload.
  auto connect(Receiver)(return Receiver receiver) @safe return scope {
    // Forward to the shared overload; the named local keeps NRVO.
    auto state = asShared.connect(receiver);
    return state;
  }
  auto connect(Receiver)(return Receiver receiver) @safe shared return scope {
    // Named local + return keeps NRVO for the non-copyable op state.
    auto state = PromiseSenderOp!(T, Receiver)(promise, receiver);
    return state;
  }
  // This sender viewed as `shared`, used to reach the shared overload.
  private auto asShared() @trusted return scope {
    return cast(shared)this;
  }
}
602 
/// A Sender that calls `fun` at connect time to obtain the actual sender,
/// then connects the receiver to that sender.
struct Defer(Fun) {
  import concurrency.utils;
  static assert (isThreadSafeCallable!Fun);
  Fun fun;
  alias Sender = typeof(fun());
  static assert(models!(Sender, isSender));
  alias Value = Sender.Value;
  auto connect(Receiver)(return Receiver receiver) @safe {
    // Named local + return so the produced op benefits from NRVO.
    auto state = fun().connect(receiver);
    return state;
  }
}
616 
/// Wraps `fun` in a `Defer` sender; `fun` is called on each `connect` to
/// produce the sender that actually runs.
auto defer(Fun)(Fun fun) {
  auto deferred = Defer!Fun(fun);
  return deferred;
}