1 module concurrency.operations.toshared;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concepts;
8 import std.traits;
9 import mir.algebraic : Algebraic, Nullable, match;
10 
/// Wraps a Sender in a SharedSender. A SharedSender allows many receivers to connect to the same underlying Sender, forwarding the same termination call to each receiver.
/// The underlying Sender is connected and started only once. It can be explicitly `reset` so that the underlying Sender is connected and started again the next time the SharedSender is started. Calling `reset` while the underlying Sender hasn't completed is a no-op.
/// When the last receiver triggers its stop token while the underlying Sender is still running, the underlying Sender is cancelled, and one termination function of that receiver is called once the underlying Sender has completed. (This ensures structured concurrency; otherwise tasks could be left running without anyone awaiting them.)
/// If a receiver is started after the underlying Sender has already completed, one of that receiver's termination functions is called immediately.
/// This operation is useful when multiple tasks all depend on one shared task: write the shared task as a regular Sender and simply apply `.toShared`.
auto toShared(Sender, Scheduler)(Sender sender, Scheduler scheduler) {
  alias Shared = SharedSender!(Sender, Scheduler, ResetLogic.keepLatest);
  auto result = new Shared(sender, scheduler);
  return result;
}
19 
/// ditto
auto toShared(Sender)(Sender sender) {
  // No scheduler supplied: use the placeholder `NullScheduler`. This yields the
  // same `SharedSender!(Sender, NullScheduler, ResetLogic.keepLatest)` type.
  return toShared(sender, NullScheduler());
}
23 
/// Placeholder scheduler used by the single-argument `toShared` overload.
struct NullScheduler {
}
/// Controls what a `SharedSender` does once its underlying sender completes.
enum ResetLogic {
  /// Set the `completed` flag: the result is kept and delivered immediately
  /// to later receivers until `reset` is called.
  keepLatest = 0,
  /// Never set the `completed` flag: the next receiver added after completion
  /// connects and starts the underlying sender again.
  alwaysReset = 1
}
29 
/// Sender returned by `toShared`.
///
/// Every connected operation registers a delegate through `add`. The
/// underlying `sender` is connected and started only when a delegate is added
/// while the tick count is zero and the `completed` flag is clear. All
/// registered delegates receive the same termination (see `process`).
class SharedSender(Sender, Scheduler, ResetLogic resetLogic) if (models!(Sender, isSender)) {
  import std.traits : ReturnType;
  static assert(models!(typeof(this), isSender));
  alias Props = Properties!(Sender);
  alias Value = Props.Value;
  alias InternalValue = Props.InternalValue;
  private {
    Sender sender;
    Scheduler scheduler;
    SharedSenderState!(Sender) state;

    /// Registers `dg` to receive the termination of the underlying sender.
    ///
    /// - If the sender has already completed, `dg` is called immediately with
    ///   the stored value.
    /// - If no tick was registered before this one, a fresh instance state is
    ///   created and the underlying sender is connected and started.
    /// - Otherwise `dg` is appended to the running instance's delegate list.
    void add(Props.DG dg) @safe nothrow {
      // Lock the counter and add one tick for this registration.
      with(state.counter.lock(0, Flags.tick)) {
        if (was(Flags.completed)) {
          InternalValue value = state.inst.value.get;
          release(Flags.tick); // release early and undo the tick added above
          dg(value);
        } else {
          if ((oldState >> 2) == 0) {
            // No ticks before ours: start a new run.
            auto localState = new SharedSenderInstStateImpl!(Sender, Scheduler, resetLogic)();
            this.state.inst = localState;
            release(); // release early
            localState.dgs.pushBack(dg);
            try {
              localState.op = sender.connect(SharedSenderReceiver!(Sender, Scheduler, resetLogic)(&state, scheduler));
            } catch (Exception e) {
              // Pass a pointer so `process` updates this object's counter
              // (completed flag / ticks) instead of a by-value copy of `state`.
              process!(resetLogic)(&state, InternalValue(e));
              // `op` was never assigned, so it must not be started.
              return;
            }
            localState.op.start();
          } else {
            auto localState = state.inst;
            localState.dgs.pushBack(dg);
          }
        }
      }
    }

    /// Unregisters `dg` after its receiver requested stop.
    ///
    /// Returns `true` when the caller should signal done itself: either the
    /// sender already completed, or other registrations remain (then `dg` is
    /// removed from the list). Returns `false` when this was the last tick of
    /// a running instance. In that case the instance is stopped and `dg` is
    /// left registered, so it is still called when the underlying sender
    /// terminates.
    bool remove(Props.DG dg) @safe nothrow {
      with (state.counter.lock(0, 0, Flags.tick)) {
        if (was(Flags.completed)) {
          release(0-Flags.tick); // release early and undo the tick decrement
          return true;
        }
        if ((newState >> 2) == 0) {
          auto localStopSource = state.inst;
          release(); // release early
          localStopSource.stop();
          return false;
        } else {
          auto localReceiver = state.inst;
          release(); // release early
          localReceiver.dgs.remove(dg);
          return true;
        }
      }
    }
  }

  /// Returns whether the `completed` flag is set (acquire load).
  bool isCompleted() @trusted {
    import core.atomic : MemoryOrder;
    return (state.counter.load!(MemoryOrder.acq) & Flags.completed) > 0;
  }

  /// Clears the `completed` flag so that the next `add` starts a new run.
  /// Does nothing if the flag is not set.
  void reset() @trusted {
    with (state.counter.lock()) {
      if (was(Flags.completed))
        release(Flags.completed);
    }
  }

  this(Sender sender, Scheduler scheduler) {
    this.sender = sender;
    this.scheduler = scheduler;
  }

  /// Connects `receiver` to this shared sender. Nothing is registered until
  /// the returned operation is started.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = SharedSenderOp!(Sender, Scheduler, resetLogic, Receiver)(this, receiver);
    return op;
  }
}
106 
/// Bit layout of `SharedSenderState.counter`.
/// Bits above `completed` hold a count of registrations, in units of `tick`.
private enum Flags {
  /// Lock bit of the counter (presumably required by `SharedBitField`).
  locked = 1 << 0,
  /// Set when the underlying sender completed under `ResetLogic.keepLatest`.
  completed = 1 << 1,
  /// One registration; added by `add` and subtracted by `remove`.
  tick = 1 << 2
}
112 
/// Marker stored in an `InternalValue` when the underlying sender signalled
/// done, or delivered in place of the result when stop was requested.
private struct Done {
}
114 
/// Placeholder value stored for senders whose `Value` is `void`.
private struct ValueRep {
}
116 
/// Types derived from `Sender` that are shared by the types in this module.
private template Properties(Sender) {
  /// Value type produced by the sender; may be `void`.
  alias Value = Sender.Value;
  /// Storable form of `Value`: the module-level `ValueRep` placeholder when
  /// `Value` is `void`, otherwise `Value` itself.
  static if (is(Value == void))
    alias ValueRep = .ValueRep;
  else
    alias ValueRep = Value;
  /// One termination of the sender: error, value, or done.
  alias InternalValue = Algebraic!(Throwable, ValueRep, Done);
  /// Callback that receives an `InternalValue`.
  alias DG = void delegate(InternalValue) nothrow @safe shared;
}
126 
/// Operation state returned by `SharedSender.connect`.
///
/// `start` registers `onValue` with the parent shared sender and `onStop` with
/// the receiver's stop token. `onValue` forwards the shared termination to
/// `receiver`.
private struct SharedSenderOp(Sender, Scheduler, ResetLogic resetLogic, Receiver) {
  alias Props = Properties!(Sender);
  SharedSender!(Sender, Scheduler, resetLogic) parent;
  Receiver receiver;
  StopCallback cb;
  void start() nothrow @trusted scope {
    // NOTE: if the parent has already completed, `add` calls `onValue`
    // synchronously, i.e. before `cb` is assigned on the next line.
    // `onValue` must therefore tolerate a null `cb`.
    parent.add(&(cast(shared)this).onValue);
    cb = receiver.getStopToken.onStop(&(cast(shared)this).onStop);
  }
  /// Stop callback registered on the receiver's stop token.
  void onStop() nothrow @trusted shared {
    with(unshared) {
      /// If this is the last one connected, remove will return false,
      /// stop the underlying sender and we will receive the setDone via
      /// the onValue.
      /// This is to ensure we always await the underlying sender for
      /// completion.
      if (parent.remove(&(cast(shared)this).onValue))
        receiver.setDone();
    }
  }
  /// Delivers the underlying sender's termination to `receiver`. A value is
  /// forwarded to `setValue`, and an exception thrown by `setValue` is passed
  /// to `setError`.
  void onValue(Props.InternalValue value) nothrow @safe shared {
    with(unshared) {
      value.match!((Props.ValueRep v){
          try {
            static if (is(Props.Value == void))
              receiver.setValue();
            else
              receiver.setValue(v);
          } catch (Exception e) {
            /// TODO: dispose needs to be called in all cases, except
            /// this onValue can sometimes be called immediately,
            /// leaving no room to set cb.dispose...
            // `cb` is still null when we were called synchronously from `start`.
            if (cb !is null)
              cb.dispose();
            receiver.setError(e);
          }
        }, (Throwable e){
          receiver.setError(e);
        }, (Done d){
          receiver.setDone();
        });
    }
  }
  private auto ref unshared() @trusted nothrow shared {
    return cast()this;
  }
}
173 
/// Receiver passed to the underlying sender by `SharedSender.add`.
///
/// Each completion signal is wrapped in an `InternalValue` and handed to
/// `process!(resetLogic)`. The stop token it exposes is backed by the current
/// run's instance state.
private struct SharedSenderReceiver(Sender, Scheduler, ResetLogic resetLogic) {
  alias InternalValue = Properties!(Sender).InternalValue;
  alias ValueRep = Properties!(Sender).ValueRep;
  SharedSenderState!(Sender)* state;
  Scheduler scheduler;

  static if (is(Sender.Value == void)) {
    /// Value completion of a `void` sender, stored as the `ValueRep` placeholder.
    void setValue() @safe {
      process(InternalValue(ValueRep()));
    }
  } else {
    /// Value completion carrying the sender's result.
    void setValue(ValueRep v) @safe {
      process(InternalValue(v));
    }
  }

  /// Done completion.
  void setDone() @safe nothrow {
    process(InternalValue(Done()));
  }

  /// Error completion.
  void setError(Throwable e) @safe nothrow {
    process(InternalValue(e));
  }

  /// Stop token backed by the current run's instance state.
  StopToken getStopToken() @trusted nothrow {
    return StopToken(state.inst);
  }

  Scheduler getScheduler() @safe nothrow scope {
    return scheduler;
  }

  // Forwards to the module-level `process!(resetLogic)` through UFCS.
  private void process(InternalValue result) @safe {
    state.process!(resetLogic)(result);
  }
}
203 
/// State owned by a `SharedSender` and reached by its receivers through a
/// pointer: the current run's instance state plus the flag/tick counter.
private struct SharedSenderState(Sender) {
  import concurrency.bitfield;

  alias Props = Properties!Sender;

  /// Instance state of the current (or most recent) run.
  SharedSenderInstState!(Sender) inst;
  /// Lock bit, completed bit and tick count, laid out as described by `Flags`.
  SharedBitField!(Flags) counter;
}
212 
/// Records the underlying sender's termination and delivers it to every
/// registered delegate.
///
/// Under `ResetLogic.keepLatest` the `completed` flag is set, so later `add`
/// calls receive the stored value immediately. Under `ResetLogic.alwaysReset`
/// the flag is left clear. In both cases all ticks are removed. If stop was
/// requested on the instance, delegates receive `Done` instead of the stored
/// value.
///
/// `state` is taken by `auto ref`. A `SharedSenderState` lvalue is therefore
/// updated in place instead of through a by-value copy whose counter changes
/// would be lost. Pointers (as held by `SharedSenderReceiver`) work as before.
private template process(ResetLogic resetLogic) {
  void process(State, InternalValue)(auto ref State state, InternalValue value) @safe {
    state.inst.value = value;
    static if (resetLogic == ResetLogic.alwaysReset) {
      size_t updateFlag = 0;
    } else {
      size_t updateFlag = Flags.completed;
    }
    with(state.counter.lock(updateFlag)) {
      auto localState = state.inst;
      InternalValue v = localState.value.get;
      // Release early and remove all ticks: subtract every bit above
      // `locked`/`completed` that was set before locking.
      release(oldState & ~cast(size_t)(Flags.locked | Flags.completed));
      if (localState.isStopRequested)
        v = Done();
      foreach(dg; localState.dgs[])
        dg(v);
    }
  }
}
232 
/// Per-run state of a shared sender: the registered delegates, the stored
/// termination, and (as a `StopSource`) the stop source exposed through
/// `SharedSenderReceiver.getStopToken`.
private class SharedSenderInstState(Sender) : StopSource {
  import concurrency.slist;
  import std.traits : ReturnType;
  alias Props = Properties!(Sender);
  /// Delegates to call when the underlying sender terminates.
  shared SList!(Props.DG) dgs;
  /// Termination of the underlying sender; null until `process` stores it.
  Nullable!(Props.InternalValue) value;
  this() {
    dgs = new shared SList!(Props.DG)();
  }
}
243 
/// Per-run state that also owns the connected operation of the underlying sender.
///
/// NOTE: `op` lives in this subclass rather than in `SharedSenderInstState`
/// to break an apparent dependency cycle. `SharedSenderReceiver` refers to
/// `SharedSenderInstState` (through `SharedSenderState`), while the operation
/// type depends on `SharedSenderReceiver`. There is no real cycle, but it is
/// probably too complex for the compiler.
private class SharedSenderInstStateImpl(Sender, Scheduler, ResetLogic resetLogic)
    : SharedSenderInstState!(Sender) {
  private alias InnerReceiver = SharedSenderReceiver!(Sender, Scheduler, resetLogic);
  alias Op = OpType!(Sender, InnerReceiver);
  Op op;
}