1 module concurrency.operations.toshared;
2 
3 import concurrency.receiver;
4 import concurrency.sender;
5 import concurrency.stoptoken;
6 import concurrency.scheduler : NullScheduler;
7 import concepts;
8 import std.traits;
9 import mir.algebraic : Algebraic, Nullable, match;
10 
/// Wraps `sender` in a `SharedSender`, allowing many receivers to connect to
/// the same underlying Sender; every receiver is sent the same termination
/// call.
///
/// The underlying Sender is connected and started only once. `reset` can be
/// called explicitly so that it is connected and started again the next time
/// the SharedSender is started; calling `reset` before the underlying Sender
/// has completed is a no-op.
///
/// When the last receiver triggers its stop token while the underlying Sender
/// is still running, the underlying Sender is cancelled, and that receiver has
/// one of its termination functions called once the underlying Sender has
/// completed. (This preserves structured concurrency; otherwise tasks could be
/// left running without anyone awaiting them.)
///
/// A receiver that connects after the underlying Sender has already completed
/// has one of its termination functions called immediately.
///
/// `scheduler` is handed to the underlying Sender through the shared
/// receiver's `getScheduler`.
///
/// This is useful when several tasks depend on one shared task: write the
/// shared task as a regular Sender and apply `.toShared` to it.
auto toShared(Sender, Scheduler)(Sender sender, Scheduler scheduler) {
  // The result of a completed run is cached until `reset` is called.
  alias Result = SharedSender!(Sender, Scheduler, ResetLogic.keepLatest);
  return new Result(sender, scheduler);
}
19 
/// Same as `toShared(sender, scheduler)`, but uses a `NullScheduler`.
auto toShared(Sender)(Sender sender) {
  // Leading dot: resolve against the module-level overload set, not this
  // template instance.
  return .toShared(sender, NullScheduler());
}
23 
/// Controls what a `SharedSender` does with the result of the underlying
/// Sender once it has completed.
enum ResetLogic {
  /// Keep the result: later receivers get it immediately, until `reset`
  /// is called.
  keepLatest = 0,
  /// Do not keep the result: the next receiver to connect connects and
  /// starts the underlying Sender again.
  alwaysReset = 1
}
28 
/// Sender that forwards the termination of one underlying `Sender` to any
/// number of connected receivers (see `toShared` for the full contract).
///
/// `state.counter` packs a `Flags.locked` bit, a `Flags.completed` bit and the
/// number of active subscribers in units of `Flags.tick`, so `counter >> 2` is
/// the subscriber count.
class SharedSender(Sender, Scheduler, ResetLogic resetLogic) if (models!(Sender, isSender)) {
  import std.traits : ReturnType;
  static assert(models!(typeof(this), isSender));
  alias Props = Properties!(Sender);
  alias Value = Props.Value;
  alias InternalValue = Props.InternalValue;
  private {
    Sender sender;
    Scheduler scheduler;
    SharedSenderState!(Sender) state;
    /// Subscribes `dg` to the result of the underlying Sender.
    ///
    /// - If a completed result is being kept, `dg` is called immediately.
    /// - If `dg` is the first subscriber, a new instance state is created and
    ///   the underlying Sender is connected and started.
    /// - Otherwise `dg` joins the subscribers of the running instance.
    void add(Props.DG dg) @safe nothrow {
      // Lock and count one more subscriber.
      with(state.counter.lock(0, Flags.tick)) {
        if (was(Flags.completed)) {
          InternalValue value = state.inst.value.get;
          // A completed state keeps no subscriber count, so undo our tick.
          release(Flags.tick); // release early
          dg(value);
        } else {
          if ((oldState >> 2) == 0) {
            auto localState = new SharedSenderInstStateImpl!(Sender, Scheduler, resetLogic)();
            this.state.inst = localState;
            release(); // release early
            localState.dgs.pushBack(dg);
            try {
              localState.op = sender.connect(SharedSenderReceiver!(Sender, Scheduler, resetLogic)(&state, scheduler));
            } catch (Exception e) {
              // connect failed, so there is no operation to start: deliver the
              // error to all subscribers and stop here. Pass a pointer so that
              // `process` updates this object's counter, not a temporary copy.
              process!(resetLogic)(&state, InternalValue(e));
              return;
            }
            localState.op.start();
          } else {
            // The underlying Sender is already running; just subscribe.
            auto localState = state.inst;
            localState.dgs.pushBack(dg);
          }
        }
      }
    }
    /// Unsubscribes `dg`.
    ///
    /// Returns: false if `dg` belonged to the last subscriber of a running
    /// Sender. In that case the Sender is asked to stop, and `dg` stays
    /// registered so that it still receives the final result. Returns true
    /// otherwise.
    bool remove(Props.DG dg) @safe nothrow {
      // Lock and count one subscriber less.
      with (state.counter.lock(0, 0, Flags.tick)) {
        if (was(Flags.completed)) {
          // Completed states carry no subscriber count; give the tick back.
          // NOTE(review): `0-Flags.tick` relies on unsigned wrap-around inside
          // `release` to re-add the tick - confirm against concurrency.bitfield.
          release(0-Flags.tick); // release early
          return true;
        }
        if ((newState >> 2) == 0) {
          auto localStopSource = state.inst;
          release(); // release early
          localStopSource.stop();
          return false;
        } else {
          auto localReceiver = state.inst;
          release(); // release early
          localReceiver.dgs.remove(dg);
          return true;
        }
      }
    }
  }
  /// Returns: whether the underlying Sender has completed and its result is
  /// being kept (the completed flag is only set with `ResetLogic.keepLatest`).
  bool isCompleted() @trusted {
    import core.atomic : MemoryOrder;
    return (state.counter.load!(MemoryOrder.acq) & Flags.completed) > 0;
  }
  /// Clears the completed state so that the next receiver connects and starts
  /// the underlying Sender again. Does nothing if it has not completed.
  void reset() @trusted {
    with (state.counter.lock()) {
      if (was(Flags.completed))
        release(Flags.completed);
    }
  }
  /// Params:
  ///   sender = the underlying Sender, connected at most once per run
  ///   scheduler = exposed to the underlying Sender via the shared receiver
  this(Sender sender, Scheduler scheduler) {
    this.sender = sender;
    this.scheduler = scheduler;
  }
  /// Returns an operation state that subscribes `receiver` when started.
  auto connect(Receiver)(return Receiver receiver) @safe return scope {
    // ensure NRVO
    auto op = SharedSenderOp!(Sender, Scheduler, resetLogic, Receiver)(this, receiver);
    return op;
  }
}
105 
/// Bit layout of `SharedSenderState.counter`: bit 0 is the spin-lock bit,
/// bit 1 marks a kept (completed) result, and the remaining bits count
/// subscribers in steps of `tick`.
private enum Flags {
  locked = 1 << 0,
  completed = 1 << 1,
  tick = 1 << 2
}
111 
/// Marker stored in an `InternalValue` when the underlying Sender called `setDone`.
private struct Done {
}

/// Marker standing in for the value of a Sender whose `Value` is `void`.
private struct ValueRep {
}
115 
/// Type helpers shared by all the pieces of a `SharedSender` over `Sender`.
private template Properties(Sender) {
  /// Value type produced by `Sender` (possibly `void`).
  alias Value = Sender.Value;
  // A `void` value has nothing to store, so the empty module-level marker
  // struct stands in for it.
  static if (is(Value == void))
    alias ValueRep = .ValueRep;
  else
    alias ValueRep = Value;
  /// Any termination of the underlying Sender: an error, a value, or `Done`.
  alias InternalValue = Algebraic!(Throwable, ValueRep, Done);
  /// Subscriber callback receiving the shared result.
  alias DG = void delegate(InternalValue) nothrow @safe shared;
}
125 
/// Operation state returned by `SharedSender.connect`. When started it
/// subscribes to the shared Sender and later forwards the shared result to
/// `receiver`.
private struct SharedSenderOp(Sender, Scheduler, ResetLogic resetLogic, Receiver) {
  alias Props = Properties!(Sender);
  // The SharedSender this operation subscribes to.
  SharedSender!(Sender, Scheduler, resetLogic) parent;
  // Downstream receiver that gets the termination forwarded by this op.
  Receiver receiver;
  // Registration of `onStop` on the receiver's stop token (set in `start`).
  StopCallback cb;
  // Copying is disabled; presumably because `start` hands out delegates
  // bound to this struct's address.
  @disable this(ref return scope typeof(this) rhs);
  @disable this(this);
  /// Subscribes `onValue` with the parent (which may invoke it immediately
  /// when a completed result is kept), then hooks `onStop` into the
  /// receiver's stop token.
  void start() nothrow @trusted scope {
    parent.add(&(cast(shared)this).onValue);
    cb = receiver.getStopToken.onStop(&(cast(shared)this).onStop);
  }
  /// Called when the receiver requests a stop.
  void onStop() nothrow @trusted shared {
    with(unshared) {
      /// If this is the last one connected, remove will return false,
      /// stop the underlying sender and we will receive the setDone via
      /// the onValue.
      /// This is to ensure we always await the underlying sender for
      /// completion.
      // Otherwise this op has been unsubscribed and completes with setDone
      // right away.
      // NOTE(review): `cb` is only disposed on the setValue-threw path of
      // `onValue`, and with `ResetLogic.keepLatest` `remove` returns true once
      // the shared Sender completed, so a stop request arriving after
      // `onValue` already ran may call `receiver.setDone()` a second time -
      // confirm.
      if (parent.remove(&(cast(shared)this).onValue))
        receiver.setDone();
    }
  }
  /// Forwards the shared result to `receiver`. A value goes to `setValue`
  /// (falling back to `setError` if that throws), a Throwable goes to
  /// `setError`, and `Done` goes to `setDone`.
  void onValue(Props.InternalValue value) nothrow @safe shared {
    with(unshared) {
      value.match!((Props.ValueRep v){
          try {
            static if (is(Props.Value == void))
              receiver.setValue();
            else
              receiver.setValue(v);
          } catch (Exception e) {
            /// TODO: dispose needs to be called in all cases, except
            /// this onValue can sometimes be called immediately,
            /// leaving no room to set cb.dispose...
            cb.dispose();
            receiver.setError(e);
          }
        }, (Throwable e){
          receiver.setError(e);
        }, (Done d){
          receiver.setDone();
        });
    }
  }
  // Casts away `shared` so the non-shared members can be used.
  private auto ref unshared() @trusted nothrow shared {
    return cast()this;
  }
}
174 
/// Receiver that `SharedSender` connects to the underlying Sender. Each
/// termination signal is wrapped in an `InternalValue` and handed to
/// `process`, which completes the current run and notifies all subscribers.
private struct SharedSenderReceiver(Sender, Scheduler, ResetLogic resetLogic) {
  alias InternalValue = Properties!(Sender).InternalValue;
  alias ValueRep = Properties!(Sender).ValueRep;
  SharedSenderState!(Sender)* state;
  Scheduler scheduler;

  /// Stop requests are observed on the current run's state, which is a `StopSource`.
  StopToken getStopToken() @trusted nothrow {
    return StopToken(state.inst);
  }

  /// The scheduler given to `toShared`.
  Scheduler getScheduler() @safe nothrow scope {
    return scheduler;
  }

  static if (is(Sender.Value == void)) {
    // A void result is represented by the empty `ValueRep` marker.
    void setValue() @safe {
      process(InternalValue(ValueRep()));
    }
  } else {
    void setValue(ValueRep v) @safe {
      process(InternalValue(v));
    }
  }

  void setError(Throwable e) @safe nothrow {
    process(InternalValue(e));
  }

  void setDone() @safe nothrow {
    process(InternalValue(Done()));
  }

  private void process(InternalValue v) @safe {
    state.process!(resetLogic)(v);
  }
}
204 
/// State shared between a `SharedSender` and the receiver it connects to the
/// underlying Sender.
private struct SharedSenderState(Sender) {
  import concurrency.bitfield;

  alias Props = Properties!(Sender);

  /// Current run of the underlying Sender: subscribers, result, stop source.
  SharedSenderInstState!(Sender) inst;
  /// Lock bit, completed bit and subscriber count (in units of `Flags.tick`).
  SharedBitField!(Flags) counter;
}
213 
/// Completes the current run with `value` and notifies all of its subscribers.
///
/// With `ResetLogic.keepLatest` the completed flag is set, so later
/// subscribers receive the kept value. With `ResetLogic.alwaysReset` it is
/// not set, so the next subscriber starts the underlying Sender afresh. In
/// both cases the subscriber count is cleared. If a stop was requested on the
/// run, subscribers receive `Done` instead of `value`.
///
/// `state` may be a pointer to the shared state or the state itself. It is
/// taken by `auto ref`, so an lvalue state is updated in place rather than
/// through a temporary copy whose counter changes would be lost.
private template process(ResetLogic resetLogic) {
  void process(State, InternalValue)(auto ref State state, InternalValue value) @safe {
    // Store the result before publishing completion under the lock.
    state.inst.value = value;
    static if (resetLogic == ResetLogic.alwaysReset) {
      size_t updateFlag = 0;
    } else {
      size_t updateFlag = Flags.completed;
    }
    with(state.counter.lock(updateFlag)) {
      auto localState = state.inst;
      InternalValue v = localState.value.get;
      // release early and remove all ticks (everything but the locked and
      // completed bits)
      release(oldState & ~cast(size_t)(Flags.locked | Flags.completed));
      if (localState.isStopRequested)
        v = Done();
      foreach(dg; localState.dgs[])
        dg(v);
    }
  }
}
233 
/// State of one run of the underlying Sender. As a `StopSource`, it is also
/// what the stop token handed to the underlying Sender observes.
private class SharedSenderInstState(Sender) : StopSource {
  import concurrency.slist;
  import std.traits : ReturnType;
  alias Props = Properties!(Sender);
  /// Callbacks of the receivers subscribed to this run.
  shared SList!(Props.DG) dgs;
  /// Result of this run, set when the underlying Sender terminates.
  Nullable!(Props.InternalValue) value;
  this() {
    dgs = new shared SList!(Props.DG)();
  }
}
244 
/// Run state that additionally holds the connected operation of the
/// underlying Sender.
/// NOTE: this is split off from its super class `SharedSenderInstState`
/// (which does not mention the operation type) to break a dependency cycle of
/// SharedSenderReceiver on itself: the receiver's state refers only to the
/// super class. (The cycle technically doesn't exist, but is probably too
/// complex for the compiler.)
private class SharedSenderInstStateImpl(Sender, Scheduler, ResetLogic resetLogic) : SharedSenderInstState!(Sender) {
  /// Type of the operation `Sender.connect` yields for the shared receiver.
  alias Op = OpType!(Sender, SharedSenderReceiver!(Sender, Scheduler, resetLogic));
  /// The connected operation; assigned and started by `SharedSender.add`.
  Op op;
}