1 module concurrency.sender;
2 
3 import concepts;
4 import std.traits : ReturnType, isCallable;
5 import core.time : Duration;
6 
7 // A Sender represents something that completes with either:
8 // 1. a value (which can be void)
9 // 2. completion, in response to cancellation
// 3. a Throwable
11 //
12 // Many things can be represented as a Sender.
13 // Threads, Fibers, coroutines, etc. In general, any async operation.
14 //
15 // A Sender is lazy. Work it represents is only started when
16 // the sender is connected to a receiver and explicitly started.
17 //
18 // Senders and Receivers go hand in hand. Senders send a value,
19 // Receivers receive one.
20 //
21 // Senders are useful because many Tasks can be represented as them,
22 // and any operation on top of senders then works on any one of those
23 // Tasks.
24 //
25 // The most common operation is `sync_wait`. It blocks the current
26 // execution context to await the Sender.
27 //
28 // There are many others as well. Like `when_all`, `retry`, `when_any`,
29 // etc. These algorithms can be used on any sender.
30 //
// Cancellation happens through StopTokens. A Sender can ask a Receiver
// for a StopToken. The default is a NeverStopToken, but Receivers can
// customize this.
34 //
35 // The StopToken can be polled or a callback can be registered with one.
36 //
37 // Senders enforce Structured Concurrency because work cannot be
38 // started unless it is awaited.
39 //
40 // These concepts are heavily inspired by several C++ proposals
41 // starting with http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0443r14.html
42 
/// Compile-time probe for the Sender concept.
///
/// The body only has to type-check: it connects a default-initialised `T`
/// to a minimal receiver and stores the resulting operation state. When the
/// operation state does not satisfy `isValidOp` a compile-time warning
/// message is printed (compilation still succeeds).
void checkSender(T)() @safe {
  import concurrency.scheduler : SchedulerObjectBase;
  import concurrency.stoptoken : StopToken;

  // Minimal scheduler returned by the probe receiver's `getScheduler`.
  struct Scheduler {
    import core.time : Duration;
    auto schedule() @safe { return VoidSender(); }
    auto scheduleAfter(Duration) @safe { return VoidSender(); }
  }

  // Minimal receiver exposing every member a sender may call.
  struct Receiver {
    int* i; // force it scope
    static if (!is(T.Value == void))
      void setValue(T.Value) @safe {}
    else
      void setValue() @safe {}
    void setDone() @safe nothrow {}
    void setError(Throwable e) @safe nothrow {}
    StopToken getStopToken() @safe nothrow { return StopToken.init; }
    Scheduler getScheduler() @safe nothrow { return Scheduler.init; }
  }

  T candidate = T.init;
  scope probe = Receiver.init;
  OpType!(T, Receiver) state = candidate.connect(probe);

  static if (!isValidOp!(T, Receiver))
    pragma(msg, "Warning: ", T, "'s operation state is not returned via the stack");
}

/// `true` when `checkSender!T` compiles, i.e. when `T` is usable as a Sender.
template isSender(T) {
  enum isSender = is(typeof(checkSender!T));
}
70 
/// Checks how the operation state produced by `Sender.connect` is returned.
///
/// It is ok for the operation state to be on the heap, but if it is on the
/// stack we need to ensure any copies are elided. We can't be 100% sure
/// (the compiler may still blit), but this is the best we can do.
template isValidOp(Sender, Receiver) {
  import std.traits : isPointer;
  import std.meta : allSatisfy;

  alias Op = OpType!(Sender, Receiver);
  alias overloads = __traits(getOverloads, Sender, "connect", true);

  // True when the given `connect` overload returns its result via the stack.
  template isRVO(alias connect) {
    static if (!__traits(isTemplate, connect))
      enum isRVO = __traits(isReturnOnStack, connect);
    else
      enum isRVO = __traits(isReturnOnStack, connect!Receiver);
  }

  // Pointer, class or type-erased operation states are always acceptable.
  enum bool heapOrErased = isPointer!Op || is(Op == OperationObject) || is(Op == class);
  // Otherwise every overload must return on the stack and Op must not be POD.
  enum bool elidedOnStack = allSatisfy!(isRVO, overloads) && !__traits(isPOD, Op);

  enum isValidOp = heapOrErased || elidedOnStack;
}
85 
/// A Sender that sends a single value of type T (or nothing when T is void)
struct ValueSender(T) {
  static assert (models!(typeof(this), isSender));
  alias Value = T;

  /// Operation state: holds the receiver and, for non-void T, the value.
  static struct Op(Receiver) {
    Receiver receiver;
    static if (is(T == void)) {
      /// Completes the receiver without a value.
      void start() nothrow @trusted scope {
        import concurrency.receiver : setValueOrError;
        receiver.setValueOrError();
      }
    } else {
      T value;
      /// Completes the receiver with the stored value.
      void start() nothrow @trusted scope {
        import concurrency.receiver : setValueOrError;
        receiver.setValueOrError(value);
      }
    }
  }

  static if (!is(T == void))
    T value;

  /// Builds the operation state for `receiver`.
  Op!Receiver connect(Receiver)(return Receiver receiver) @safe scope return {
    // a single named local keeps the return eligible for NRVO
    static if (is(T == void))
      auto op = Op!Receiver(receiver);
    else
      auto op = Op!Receiver(receiver, value);
    return op;
  }
}
113 
/// Creates a ValueSender from the arguments: exactly one argument is sent
/// as-is, any other number of arguments is packed into a std.typecons Tuple.
auto just(T...)(T t) {
  import std.typecons : tuple, Tuple;
  static if (T.length != 1) {
    alias Packed = Tuple!T;
    return ValueSender!Packed(tuple(t));
  } else {
    return ValueSender!(T[0])(t);
  }
}
121 
/// A Sender that, when started, calls `fun` and forwards its result (or
/// nothing, when `fun` returns void) to the receiver. If a stop was
/// requested by the time `fun` returns, `setDone` is called instead.
struct JustFromSender(Fun) {
  static assert (models!(typeof(this), isSender));
  alias Value = ReturnType!fun;

  /// Operation state holding the receiver and the callable.
  static struct Op(Receiver) {
    Receiver receiver;
    Fun fun;

    /// Runs the callable and completes the receiver.
    ///
    /// The call is always guarded: even when `fun` itself is nothrow,
    /// `receiver.setValue` may still throw, and `start` is nothrow. Any
    /// Exception is therefore routed to `receiver.setError`.
    void start() @trusted nothrow {
      try {
        set();
      } catch (Exception e) {
        receiver.setError(e);
      }
    }

    // Invokes `fun` and signals either setDone (stop requested) or setValue.
    private void set() @safe {
      static if (is(Value == void)) {
        fun();
        if (receiver.getStopToken.isStopRequested)
          receiver.setDone();
        else
          receiver.setValue();
      } else {
        auto r = fun();
        if (receiver.getStopToken.isStopRequested)
          receiver.setDone();
        else
          receiver.setValue(r);
      }
    }
  }

  Fun fun;

  /// Builds the operation state for `receiver`.
  Op!Receiver connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = Op!(Receiver)(receiver, fun);
    return op;
  }
}
164 
/// Wraps a callable in a JustFromSender.
///
/// The callable must satisfy `isThreadSafeFunction` (checked at compile
/// time); otherwise compilation fails.
JustFromSender!(Fun) justFrom(Fun)(Fun fun) if (isCallable!Fun) {
  import concurrency.utils : isThreadSafeFunction;
  static assert(isThreadSafeFunction!Fun);
  return JustFromSender!Fun(fun);
}
171 
/// A polymorphic sender of type T
///
/// Implementations provide `connect` for a `ReceiverObjectBase!T`; the
/// templated overload adapts any other receiver by wrapping it in a
/// heap-allocated `ReceiverObjectBase!T` that forwards every call.
interface SenderObjectBase(T) {
  import concurrency.receiver;
  import concurrency.scheduler : SchedulerObjectBase;
  import concurrency.stoptoken : StopToken, stopTokenObject;
  static assert (models!(typeof(this), isSender));
  alias Value = T;
  alias Op = OperationObject;

  /// Connects a type-erased receiver.
  OperationObject connect(return ReceiverObjectBase!(T) receiver) @safe scope;

  /// Connects an arbitrary receiver by wrapping it in a forwarding adapter.
  OperationObject connect(Receiver)(return Receiver receiver) @trusted scope {
    return connect(new class(receiver) ReceiverObjectBase!T {
      // the wrapped, statically typed receiver
      Receiver receiver;

      this(Receiver wrapped) {
        this.receiver = wrapped;
      }

      void setDone() nothrow {
        receiver.setDone();
      }

      void setError(Throwable e) nothrow {
        receiver.setError(e);
      }

      static if (!is(T == void)) {
        void setValue(T value) {
          receiver.setValueOrError(value);
        }
      } else {
        void setValue() {
          receiver.setValueOrError();
        }
      }

      StopToken getStopToken() nothrow {
        return stopTokenObject(receiver.getStopToken());
      }

      SchedulerObjectBase getScheduler() nothrow @safe scope {
        import concurrency.scheduler : toSchedulerObject;
        return receiver.getScheduler().toSchedulerObject;
      }
    });
  }
}
212 
/// Type-erased operational state object
/// used in polymorphic senders
struct OperationObject {
  // delegate invoked by `start`
  private void delegate() nothrow shared _start;

  /// Invokes the stored delegate.
  void start() nothrow @trusted scope {
    _start();
  }
}
219 
/// Interface for operation states that can be started through a class
/// reference (see `connectHeap`).
interface OperationalStateBase {
  /// Starts the operation.
  void start() nothrow @safe;
}
223 
/// calls connect on the Sender but stores the OperationState on the heap
///
/// The returned object owns the operation state as a field and forwards
/// `start` to it.
OperationalStateBase connectHeap(Sender, Receiver)(Sender sender, Receiver receiver) @safe {
  alias State = typeof(sender.connect(receiver));
  return new class(sender, receiver) OperationalStateBase {
    // constructed directly inside the heap-allocated object
    State state;

    this(return Sender s, return Receiver r) @trusted {
      state = s.connect(r);
    }

    void start() nothrow @safe {
      state.start();
    }
  };
}
237 
/// A class extending from SenderObjectBase that wraps any Sender
class SenderObjectImpl(Sender) : SenderObjectBase!(Sender.Value) {
  import concurrency.receiver : ReceiverObjectBase;
  static assert (models!(typeof(this), isSender));

  private Sender sender;

  this(Sender sender) {
    this.sender = sender;
  }

  /// Connects the wrapped sender to a type-erased receiver. The operation
  /// state is placed on the heap and its `start` is exposed through an
  /// OperationObject.
  OperationObject connect(return ReceiverObjectBase!(Sender.Value) receiver) @trusted scope {
    alias StartDelegate = typeof(OperationObject._start);
    OperationalStateBase heapState = sender.connectHeap(receiver);
    return OperationObject(cast(StartDelegate)&heapState.start);
  }

  /// Connects any other receiver by delegating to the interface's
  /// templated `connect`.
  OperationObject connect(Receiver)(return Receiver receiver) @safe scope {
    alias Base = SenderObjectBase!(Sender.Value);
    auto base = cast(Base)this;
    return base.connect(receiver);
  }
}
255 
/// Converts any Sender to a polymorphic SenderObject
///
/// A sender that already converts to `SenderObjectBase!(Sender.Value)` is
/// returned unchanged; anything else is wrapped in a SenderObjectImpl.
auto toSenderObject(Sender)(Sender sender) {
  static assert(models!(Sender, isSender));
  alias Erased = SenderObjectBase!(Sender.Value);
  static if (!is(Sender : Erased)) {
    return cast(Erased)new SenderObjectImpl!(Sender)(sender);
  } else {
    return sender;
  }
}
264 
/// A sender that always sets an error
struct ThrowingSender {
  alias Value = void;

  /// Operation state: reports a fresh Exception to the receiver on start.
  static struct Op(Receiver) {
    Receiver receiver;

    void start() {
      auto failure = new Exception("ThrowingSender");
      receiver.setError(failure);
    }
  }

  /// Builds the operation state for `receiver`.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // a single named local keeps the return eligible for NRVO
    auto op = Op!(Receiver)(receiver);
    return op;
  }
}
280 
/// A sender that always calls setDone
struct DoneSender {
  static assert (models!(typeof(this), isSender));
  alias Value = void;

  /// Operation state: signals `setDone` on the receiver when started.
  static struct DoneOp(Receiver) {
    Receiver receiver;

    void start() nothrow @trusted scope {
      receiver.setDone();
    }
  }

  /// Builds the operation state for `receiver`.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // a single named local keeps the return eligible for NRVO
    auto op = DoneOp!Receiver(receiver);
    return op;
  }
}
297 
/// A sender that always calls setValue with no args
struct VoidSender {
  static assert (models!(typeof(this), isSender));
  alias Value = void;

  /// Operation state: completes the receiver without a value when started.
  struct VoidOp(Receiver) {
    Receiver receiver;

    void start() nothrow @safe {
      import concurrency.receiver : setValueOrError;
      receiver.setValueOrError();
    }
  }

  /// Builds the operation state for `receiver`.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // a single named local keeps the return eligible for NRVO
    auto op = VoidOp!(Receiver)(receiver);
    return op;
  }
}
315 
/// A sender that always calls setError
struct ErrorSender {
  static assert (models!(typeof(this), isSender));
  alias Value = void;

  /// The Throwable handed to the receiver.
  Throwable exception;

  /// Operation state: forwards the stored Throwable to `setError` on start.
  static struct ErrorOp(Receiver) {
    Receiver receiver;
    Throwable exception;

    void start() nothrow @trusted scope {
      receiver.setError(exception);
    }
  }

  /// Builds the operation state for `receiver`.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // a single named local keeps the return eligible for NRVO
    auto op = ErrorOp!Receiver(receiver, exception);
    return op;
  }
}
334 
/// Resolves the operation-state type produced by connecting `Sender` to
/// `Receiver`.
///
/// A sender may declare it explicitly through a nested `Op` type; otherwise
/// it is the return type of the first `connect` overload (templated
/// overloads are instantiated with `Receiver`).
template OpType(Sender, Receiver) {
  static if (!is(Sender.Op)) {
    import std.traits : ReturnType;
    import std.meta : staticMap;

    // Return type of a single `connect` overload.
    template GetOpType(alias connect) {
      static if (!__traits(isTemplate, connect)) {
        alias GetOpType = ReturnType!(connect);
      } else {
        alias GetOpType = ReturnType!(connect!Receiver);
      }
    }

    alias overloads = __traits(getOverloads, Sender, "connect", true);
    alias candidates = staticMap!(GetOpType, overloads);
    alias OpType = candidates[0];
  } else {
    alias OpType = Sender.Op;
  }
}
353 
/// A sender that delays before calling setValue
///
/// The delay is delegated to the receiver's scheduler via
/// `scheduleAfter(dur)`.
struct DelaySender {
  alias Value = void;

  /// Duration passed to the scheduler's `scheduleAfter`.
  Duration dur;

  /// Connects `receiver` to the sender obtained from its own scheduler.
  auto connect(Receiver)(return Receiver receiver) @trusted scope return {
    // a single named local keeps the return eligible for NRVO
    auto op = receiver
      .getScheduler()
      .scheduleAfter(dur)
      .connect(receiver);
    return op;
  }
}
364 
/// Creates a DelaySender for the given duration.
auto delay(Duration dur) {
  auto sender = DelaySender(dur);
  return sender;
}
368 
/// Operation state connecting a Receiver to a Promise.
///
/// `start` registers `onValue` with the promise and `onStop` with the
/// receiver's stop token. Both callbacks may fire from other contexts, so an
/// atomic bitfield records which of them ran before setup finished; the
/// receiver must be completed exactly once.
struct PromiseSenderOp(T, Receiver) {
  import concurrency.stoptoken;
  import concurrency.bitfield;
  private enum Flags : size_t {
    locked = 0x0,
    setup = 0x1,
    value = 0x2,
    stop = 0x4
  }
  alias Sender = Promise!T;
  alias InternalValue = Sender.InternalValue;
  shared Sender parent;
  Receiver receiver;
  StopCallback cb;
  shared SharedBitField!Flags bitfield;

  /// Registers the callbacks and resolves any callback that fired during
  /// registration.
  void start() nothrow @trusted scope {
    // if already completed we can optimize
    if (parent.isCompleted) {
      bitfield.add(Flags.setup);
      parent.add(&(cast(shared)this).onValue);
      return;
    }
    // Otherwise we have to be a bit careful here,
    // both the onStop and the onValue we register
    // can be called from possibly different contexts.
    // We can't atomically connect both, so we have to
    // devise a scheme to handle one or both being called
    // before we are done here.

    // we use a simple atomic bitfield that we set after setup
    // is done. If `onValue` or `onStop` trigger before setup
    // is complete, they update the bitfield and return early.
    // After we setup both, we flip the setup bit and check
    // if any of the callbacks triggered in the meantime,
    // if they did we know we have to perform some cleanup
    // if they didn't the callbacks themselves will handle it

    bool triggeredInline = parent.add(&(cast(shared)this).onValue);
    // if triggeredInline there is no point in setting up the stop callback
    if (!triggeredInline)
      cb = receiver.getStopToken.onStop(&(cast(shared)this).onStop);

    with (bitfield.add(Flags.setup)) {
      if (has(Flags.stop)) {
        // it stopped before we finished setup.
        // Same rule as in `onStop`: only when `parent.remove` succeeds is
        // onValue guaranteed never to run, and only then may we signal
        // setDone ourselves. If it fails the promise has completed and
        // onValue delivers the result (either through the re-add below or
        // through the invocation that is still in flight); calling setDone
        // here as well would complete the receiver twice.
        if (parent.remove(&(cast(shared)this).onValue)) {
          receiver.setDone();
          return;
        }
      }
      if (has(Flags.value)) {
        // it fired before we finished setup
        // just add it again, it will fire again
        parent.add(&(cast(shared)this).onValue);
      }
    }
  }

  /// Stop-token callback.
  void onStop() nothrow @trusted shared {
    // we toggle the stop bit and return early if setup bit isn't set
    with (bitfield.add(Flags.stop))
      if (!has(Flags.setup))
        return;
    with(unshared) {
      // If `parent.remove` returns true, onValue will never be called,
      // so we can call setDone ourselves.
      // If it returns false onStop and onValue are in a race, and we
      // let onValue pass.
      if (parent.remove(&(cast(shared)this).onValue))
        receiver.setDone();
    }
  }

  /// Promise callback: forwards the promise's outcome to the receiver.
  void onValue(InternalValue value) nothrow @safe shared {
    import mir.algebraic : match;
    // we toggle the value bit and return early if setup bit isn't set
    with (bitfield.add(Flags.value))
      if (!has(Flags.setup))
        return;
    with(unshared) {
      // `cb.dispose` will ensure onStop will never be called
      // after it returns. It will also block if it is currently
      // being executed.
      // This means that when it completes we are the only one
      // calling the receiver's termination functions.
      if (cb)
        cb.dispose();
      value.match!((Sender.ValueRep v){
          try {
            static if (is(Sender.Value == void))
              receiver.setValue();
            else
              receiver.setValue(v);
          } catch (Exception e) {
            receiver.setError(e);
          }
        }, (Throwable e){
          receiver.setError(e);
        }, (Sender.Done d){
          receiver.setDone();
        });
    }
  }

  // Strips `shared` from `this` so the fields can be used directly.
  private auto ref unshared() @trusted nothrow shared {
    return cast()this;
  }
}
472 
/// A one-shot completion slot that can be fulfilled with a value, an error
/// or a cancellation, and that notifies every registered delegate with the
/// stored outcome.
class Promise(T) {
  import std.traits : ReturnType;
  import concurrency.slist;
  import concurrency.bitfield;
  import mir.algebraic : Algebraic, match, Nullable;

  alias Value = T;
  // void cannot be stored, so an empty struct stands in for it
  static if (!is(Value == void))
    alias ValueRep = Value;
  else {
    static struct ValueRep{}
  }
  static struct Done{}
  alias InternalValue = Algebraic!(Throwable, ValueRep, Done);
  alias DG = void delegate(InternalValue) nothrow @safe shared;

  private {
    shared SList!DG dgs;
    Nullable!InternalValue value;
    enum Flags {
      locked = 0x1,
      completed = 0x2
    }
    SharedBitField!Flags counter;

    // Registers `dg`. When the promise is already completed the delegate is
    // invoked right away (after releasing the lock) and true is returned;
    // otherwise it is queued and false is returned.
    bool add(DG dg) @safe nothrow shared {
      with(unshared) {
        with(counter.lock()) {
          if (!was(Flags.completed)) {
            dgs.pushBack(dg);
            return false;
          }
          auto val = value.get;
          release(); // release early
          dg(val);
          return true;
        }
      }
    }

    // Unregisters `dg`. Returns false when the promise is already
    // completed, true after removing the delegate from the queue.
    bool remove(DG dg) @safe nothrow shared {
      with (counter.lock()) {
        if (!was(Flags.completed)) {
          dgs.remove(dg);
          return true;
        }
        release(); // release early
        return false;
      }
    }

    private auto ref unshared() @trusted nothrow shared {
      return cast()this;
    }
  }

  // Stores the outcome and notifies all queued delegates. Only the first
  // call succeeds; later calls return false.
  private bool pushImpl(P)(P t) @safe shared nothrow {
    import std.exception : enforce;
    with (counter.lock(Flags.completed)) {
      if (was(Flags.completed))
        return false;
      InternalValue val = InternalValue(t);
      (cast()value) = val;
      auto localDgs = dgs.release();
      // the delegates are invoked after the lock has been released
      release();
      foreach(dg; localDgs)
        dg(val);
      return true;
    }
  }

  /// Completes the promise as cancelled. Returns false if already completed.
  bool cancel() @safe shared nothrow {
    return pushImpl(Done());
  }

  /// Completes the promise with an error. Returns false if already completed.
  bool error(Throwable e) @safe shared nothrow {
    return pushImpl(e);
  }

  static if (!is(Value == void)) {
    /// Completes the promise with a value. Returns false if already completed.
    bool fulfill(T t) @safe shared nothrow {
      return pushImpl(t);
    }
  } else {
    /// Completes the promise. Returns false if already completed.
    bool fulfill() @safe shared nothrow {
      return pushImpl(ValueRep());
    }
  }

  /// Whether the completed flag is set.
  bool isCompleted() @trusted shared nothrow {
    import core.atomic : MemoryOrder;
    auto bits = counter.load!(MemoryOrder.acq);
    return (bits & Flags.completed) != 0;
  }

  this() {
    this.dgs = new shared SList!DG;
  }

  /// Returns a sender that completes with this promise's outcome.
  auto sender() shared @safe nothrow {
    return shared PromiseSender!T(this);
  }
}
564 
/// Allocates a new shared Promise of T.
shared(Promise!T) promise(T)() {
  auto created = new shared Promise!T();
  return created;
}
568 
/// Sender view of a shared Promise: connecting yields a PromiseSenderOp
/// bound to that promise.
struct PromiseSender(T) {
  alias Value = T;
  static assert(models!(typeof(this), isSender));

  private shared Promise!T promise;

  /// Unshared entry point; forwards to the shared overload.
  auto connect(Receiver)(return Receiver receiver) @trusted scope {
    // a single named local keeps the return eligible for NRVO
    auto op = (cast(shared)this).connect(receiver);
    return op;
  }

  /// Builds the operation state for `receiver`.
  auto connect(Receiver)(return Receiver receiver) @safe shared scope return {
    alias State = PromiseSenderOp!(T, Receiver);
    // a single named local keeps the return eligible for NRVO
    auto op = State(promise, receiver);
    return op;
  }
}