1 module concurrency.operations.toshared;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concepts;
8 import std.traits;
9 import mir.algebraic : Algebraic, Nullable, match;
10 
/// Wraps a Sender in a SharedSender.
///
/// A SharedSender allows many receivers to connect to the same underlying
/// Sender and forwards the same termination call to each of them. The
/// underlying Sender is connected and started only once. It can be explicitly
/// `reset`, so that the underlying Sender is connected and started again the
/// next time the SharedSender is started; calling `reset` while the underlying
/// Sender hasn't completed is a no-op.
///
/// When the last receiver triggers its stop token while the underlying Sender
/// is still running, the underlying Sender is cancelled and one of that
/// receiver's termination functions is called once the underlying Sender has
/// completed. This preserves structured concurrency; otherwise tasks could be
/// left running without anyone awaiting them.
///
/// If a receiver is connected after the underlying Sender has already
/// completed, one of its termination functions is called immediately.
///
/// This operation is useful when multiple tasks all depend on one shared task:
/// write the shared task as a regular Sender and simply apply `.toShared`.
///
/// Params:
///   sender = the Sender to share between receivers
///   scheduler = scheduler handed to the receiver that is connected to the
///               underlying Sender
/// Returns: a newly allocated `SharedSender!(Sender, Scheduler)`.
auto toShared(Sender, Scheduler)(Sender sender, Scheduler scheduler) {
  alias Shared = SharedSender!(Sender, Scheduler);
  auto result = new Shared(sender, scheduler);
  return result;
}
19 
/// Overload of `toShared` for callers that have no scheduler to supply.
///
/// The resulting SharedSender is instantiated with an empty placeholder
/// struct, `NullScheduler`, as its scheduler type.
///
/// Params:
///   sender = the Sender to share between receivers
/// Returns: a newly allocated `SharedSender!(Sender, NullScheduler)`.
auto toShared(Sender)(Sender sender) {
  // Empty stand-in type used where a scheduler would otherwise be passed.
  static struct NullScheduler {}
  alias Shared = SharedSender!(Sender, NullScheduler);
  auto result = new Shared(sender, NullScheduler());
  return result;
}
24 
/// Sender that lets any number of receivers share a single run of an
/// underlying `Sender`.
///
/// Each connected operation registers a delegate through `add`. When the
/// underlying sender terminates, `SharedSenderReceiver.process` forwards the
/// termination (value, error or done) to every registered delegate.
///
/// Bookkeeping lives in `counter`: the two lowest bits are `Flags.locked` and
/// `Flags.completed`; the bits above them are moved in steps of `Flags.tick`
/// by `add` and `remove` and appear to count the registered delegates
/// (NOTE(review): the exact `SharedBitField.lock`/`release` semantics are
/// defined in `concurrency.bitfield` and should be confirmed there).
///
/// Params:
///   Sender = type of the underlying sender; must model `isSender`
///   Scheduler = type of the scheduler exposed to the underlying sender
class SharedSender(Sender, Scheduler) if (models!(Sender, isSender)) {
  import std.traits : ReturnType;
  import concurrency.slist;
  import concurrency.bitfield;
  static assert(models!(typeof(this), isSender));
  /// Value type produced by the underlying sender.
  alias Value = Sender.Value;
  // `void` cannot be stored in an Algebraic, so a void-valued sender is
  // represented by an empty placeholder struct instead.
  static if (is(Value == void)) {
    static struct ValueRep{}
  } else
    alias ValueRep = Value;
  /// Marker stored when the underlying sender terminated with `setDone`.
  static struct Done{}
  /// The stored termination of the underlying sender: error, value or done.
  alias InternalValue = Algebraic!(Exception, ValueRep, Done);
  /// Signature of the per-operation callback that receives the termination.
  alias DG = void delegate(InternalValue) nothrow @safe shared;
  /// Operation state returned by `connect`; ties one `Receiver` to the shared
  /// run owned by `parent`.
  static struct SharedSenderOp(Receiver) {
    SharedSender parent;
    Receiver receiver;
    /// Stop callback registered in `start`. It is still unset while `onValue`
    /// runs synchronously from within `parent.add`.
    StopCallback cb;
    /// Registers `onValue` with the parent and then hooks `onStop` onto the
    /// receiver's stop token.
    void start() nothrow @trusted scope {
      // May invoke `onValue` immediately when the shared run already completed.
      parent.add(&(cast(shared)this).onValue);
      cb = receiver.getStopToken.onStop(&(cast(shared)this).onStop);
    }
    /// Invoked when the receiver's stop token is triggered.
    void onStop() nothrow @trusted shared {
      with(unshared) {
        /// If this is the last one connected, remove will return false,
        /// stop the underlying sender and we will receive the setDone via
        /// the onValue.
        /// This is to ensure we always await the underlying sender for
        /// completion.
        if (parent.remove(&(cast(shared)this).onValue))
          receiver.setDone();
      }
    }
    /// Forwards the shared termination to this operation's receiver.
    /// An exception thrown by `receiver.setValue` is redirected to
    /// `receiver.setError`.
    void onValue(InternalValue value) nothrow @safe shared {
      with(unshared) {
        value.match!((ValueRep v){
            try {
              static if (is(Value == void))
                receiver.setValue();
              else
                receiver.setValue(v);
            } catch (Exception e) {
              /// TODO: dispose needs to be called in all cases, except
              /// this onValue can sometimes be called immediately,
              /// leaving no room to set cb.dispose...
              // When onValue is called from within `start` (via `parent.add`)
              // `cb` has not been assigned yet, so only dispose a callback
              // that actually exists instead of dereferencing null.
              static if (is(typeof(cb is null))) {
                if (cb !is null)
                  cb.dispose();
              } else
                cb.dispose();
              receiver.setError(e);
            }
          }, (Exception e){
            receiver.setError(e);
          }, (Done d){
            receiver.setDone();
          });
      }
    }
    /// Strips the `shared` qualifier so the fields can be accessed directly.
    private auto ref unshared() @trusted nothrow shared {
      return cast()this;
    }
  }
  /// State of one run of the underlying sender. It is itself the `StopSource`
  /// whose token is handed to the underlying sender.
  static class SharedSenderState : StopSource {
    import std.traits : ReturnType;
    alias Op = OpType!(Sender, SharedSenderReceiver);
    SharedSender parent;
    /// Delegates of the operations currently attached to this run.
    shared SList!DG dgs;
    /// Termination of the underlying sender, set by `SharedSenderReceiver`.
    Nullable!InternalValue value;
    /// Operation state of the connected underlying sender.
    Op op;
    this(SharedSender parent) {
      this.dgs = new shared SList!DG;
      this.parent = parent;
    }
  }
  /// Receiver connected to the underlying sender. Each termination function
  /// stores the outcome in `state.value` and then calls `process`.
  static struct SharedSenderReceiver {
    SharedSenderState state;
    Scheduler scheduler;
    static if (is(Sender.Value == void))
      void setValue() @safe {
        state.value = InternalValue(ValueRep());
        process();
      }
    else
      void setValue(ValueRep v) @safe {
        state.value = InternalValue(v);
        process();
      }
    void setDone() @safe nothrow {
      state.value = InternalValue(Done());
      process();
    }
    void setError(Exception e) @safe nothrow {
      state.value = InternalValue(e);
      process();
    }
    /// Flags the shared run as completed and hands the termination to every
    /// registered delegate. If a stop was requested on `state`, the delegates
    /// receive `Done` instead of the stored value; `state.value` itself is
    /// left untouched.
    private void process() @trusted {
      with(state.parent.counter.lock(Flags.completed)) {
        release(oldState & (~0x3)); // release early and remove all ticks
        InternalValue v = state.value.get;
        if (state.isStopRequested)
          v = Done();
        foreach(dg; state.dgs[])
          dg(v);
      }
    }
    /// Stop token backed by the shared state.
    StopToken getStopToken() @safe nothrow {
      return StopToken(state);
    }
    Scheduler getScheduler() @safe nothrow scope {
      return scheduler;
    }
  }
  private {
    Sender sender;
    Scheduler scheduler;
    SharedSenderState state;
    /// Bit layout of `counter`.
    enum Flags {
      locked = 0x1,
      completed = 0x2,
      tick = 0x4
    }
    SharedBitField!Flags counter;
    /// Attaches `dg` to the shared run.
    ///
    /// If the run already completed, `dg` is called immediately with the
    /// stored value. If no delegate was attached before, a fresh state is
    /// created and the underlying sender is connected and started. Otherwise
    /// `dg` is appended to the current state's delegate list.
    void add(DG dg) @safe nothrow {
      with(counter.lock(0, Flags.tick)) {
        if (was(Flags.completed)) {
          InternalValue value = state.value.get;
          release(Flags.tick); // release early
          dg(value);
        } else {
          // The bits above the two flag bits are all zero: nothing attached yet.
          if ((oldState >> 2) == 0) {
            auto localState = new SharedSenderState(this);
            this.state = localState;
            release(); // release early
            localState.dgs.pushBack(dg);
            localState.op = sender.connect(SharedSenderReceiver(localState, scheduler));
            localState.op.start();
          } else {
            auto localState = state;
            release(); // release early
            localState.dgs.pushBack(dg);
          }
        }
      }
    }
    /// returns false if it is the last
    ///
    /// In that case `dg` is not removed; the shared state is asked to stop
    /// instead. When the run already completed, `true` is returned without
    /// touching the delegate list.
    bool remove(DG dg) @safe nothrow {
      with (counter.lock(0, 0, Flags.tick)) {
        if (was(Flags.completed)) {
          release(0-Flags.tick); // release early
          return true;
        }
        if ((newState >> 2) == 0) {
          auto localStopSource = state;
          release(); // release early
          localStopSource.stop();
          return false;
        } else {
          auto localReceiver = state;
          release(); // release early
          localReceiver.dgs.remove(dg);
          return true;
        }
      }
    }
  }
  /// Returns: `true` when `Flags.completed` is set in `counter`.
  bool isCompleted() @trusted {
    import core.atomic : MemoryOrder;
    return (counter.load!(MemoryOrder.acq) & Flags.completed) > 0;
  }
  /// Takes the lock and, only if `Flags.completed` was set, calls
  /// `release(Flags.completed)`. Nothing is released explicitly otherwise.
  /// NOTE(review): presumably this clears the completed flag so that the next
  /// `add` starts the underlying sender again; confirm against
  /// `SharedBitField`.
  void reset() @trusted {
    with (counter.lock()) {
      if (was(Flags.completed))
        release(Flags.completed);
    }
  }
  /// Params:
  ///   sender = the underlying sender to share
  ///   scheduler = scheduler passed to the receiver of the underlying sender
  this(Sender sender, Scheduler scheduler) {
    this.sender = sender;
    this.scheduler = scheduler;
  }
  /// Creates the operation state that ties `receiver` to this shared sender.
  /// Nothing is registered until the operation is started.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = SharedSenderOp!Receiver(this, receiver);
    return op;
  }
}