1 module concurrency.operations.toshared;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concepts;
8 import std.traits;
9 import mir.algebraic : Algebraic, Nullable, match;
10 
11 /// Wraps a Sender in a SharedSender. A SharedSender allows many receivers to connect to the same underlying Sender, forwarding the same termination call to each receiver.
/// The underlying Sender is connected and started only once. It can be explicitly `reset` so that it connects and starts the underlying Sender the next time it is started. Calling `reset` when the underlying Sender hasn't completed is a no-op.
13 /// When the last receiver triggers its stoptoken while the underlying Sender is still running, the latter will be cancelled and one termination function of the former will be called after the latter is completed. (This is to ensure structured concurrency, otherwise tasks could be left running without anyone awaiting them).
/// If a receiver is connected after the underlying Sender has already been completed, that receiver will have one of its termination functions called immediately.
15 /// This operation is useful when you have multiple tasks that all depend on one shared task. It allows you to write the shared task as a regular Sender and simply apply a `.toShared`.
auto toShared(Sender, Scheduler)(Sender sender, Scheduler scheduler) {
  // Name the concrete instantiation once, then build it on the GC heap.
  alias Shared = SharedSender!(Sender, Scheduler);
  auto result = new Shared(sender, scheduler);
  return result;
}
19 
/// Overload of `toShared` for callers that have no scheduler to supply:
/// the resulting SharedSender is parameterized with `NullScheduler`.
auto toShared(Sender)(Sender sender) {
  alias Shared = SharedSender!(Sender, NullScheduler);
  auto placeholder = NullScheduler();
  return new Shared(sender, placeholder);
}
23 
private {
  /// Placeholder scheduler type used by the single-argument `toShared`.
  struct NullScheduler {
  }

  /// Marker stored in the internal value to represent a `setDone` termination.
  struct Done {
  }

  /// Marker standing in for the value when the wrapped Sender produces `void`.
  struct ValueRep {
  }
}
27 
/// Sender that lets many receivers observe a single run of an underlying
/// `Sender`.
///
/// Every connected operation registers a delegate through `add`. The first
/// registration connects and starts the underlying Sender; its termination
/// (value, error or done) is stored and forwarded to each registered delegate
/// by `process`. Operations whose stop token fires deregister through
/// `remove`; when the last one does so the underlying Sender is asked to stop.
///
/// NOTE(review): the exact semantics of `SharedBitField.lock` / `release`
/// are defined in `concurrency.bitfield`, which is not visible here. The
/// comments below describe them as the call sites suggest (a lock bit, a
/// completed bit, and a count of registered operations in "ticks" above the
/// two flag bits) - confirm against that module.
class SharedSender(Sender, Scheduler) if (models!(Sender, isSender)) {
  import std.traits : ReturnType;
  import concurrency.slist;
  import concurrency.bitfield;
  static assert(models!(typeof(this), isSender));
  alias Value = Sender.Value;
  // For a non-void Sender the value type itself is stored; for a void Sender
  // this alias is not declared and `ValueRep` resolves to the module-level
  // marker struct instead.
  static if (!is(Value == void))
    alias ValueRep = Value;
  /// The stored termination: an error, a value (or its marker) or done.
  alias InternalValue = Algebraic!(Exception, ValueRep, Done);
  /// Callback through which a connected operation receives the termination.
  alias DG = void delegate(InternalValue) nothrow @safe shared;

  /// Operation state returned by `connect`; ties one receiver to the parent.
  static struct SharedSenderOp(Receiver) {
    SharedSender parent;
    Receiver receiver;
    StopCallback cb;
    /// Registers `onValue` with the parent and then hooks `onStop` onto the
    /// receiver's stop token.
    void start() nothrow @trusted scope {
      parent.add(&(cast(shared)this).onValue);
      cb = receiver.getStopToken.onStop(&(cast(shared)this).onStop);
    }
    /// Invoked when the receiver's stop token is triggered.
    void onStop() nothrow @trusted shared {
      with(unshared) {
        /// If this is the last one connected, remove will return false,
        /// stop the underlying sender and we will receive the setDone via
        /// the onValue.
        /// This is to ensure we always await the underlying sender for
        /// completion.
        if (parent.remove(&(cast(shared)this).onValue))
          receiver.setDone();
      }
    }
    /// Forwards the shared termination to this operation's receiver.
    void onValue(InternalValue value) nothrow @safe shared {
      with(unshared) {
        value.match!((ValueRep v){
            try {
              static if (is(Value == void))
                receiver.setValue();
              else
                receiver.setValue(v);
            } catch (Exception e) {
              /// TODO: dispose needs to be called in all cases, except
              /// this onValue can sometimes be called immediately,
              /// leaving no room to set cb.dispose...
              /// NOTE(review): when `parent.add` invokes this delegate
              /// synchronously from `start`, `cb` has not been assigned yet
              /// at this point - confirm `dispose` is safe in that state.
              cb.dispose();
              receiver.setError(e);
            }
          }, (Exception e){
            receiver.setError(e);
          }, (Done d){
            receiver.setDone();
          });
      }
    }
    // Strips the `shared` qualifier so members can be used inside `with`.
    private auto ref unshared() @trusted nothrow shared {
      return cast()this;
    }
  }

  /// State of one run of the underlying Sender. It is also the StopSource
  /// whose token is handed to the underlying Sender.
  static class SharedSenderState : StopSource {
    import std.traits : ReturnType;
    alias Op = OpType!(Sender, SharedSenderReceiver);
    SharedSender parent;
    /// Delegates of the currently registered operations.
    shared SList!DG dgs;
    /// Termination of the underlying Sender, set by SharedSenderReceiver.
    Nullable!InternalValue value;
    /// Operation state of the underlying Sender.
    Op op;
    this(SharedSender parent) {
      this.dgs = new shared SList!DG;
      this.parent = parent;
    }
  }

  /// Receiver connected to the underlying Sender: records the termination in
  /// the shared state and asks the parent to distribute it.
  static struct SharedSenderReceiver {
    SharedSenderState state;
    Scheduler scheduler;
    static if (is(Sender.Value == void))
      void setValue() @safe {
        state.value = InternalValue(ValueRep());
        process();
      }
    else
      void setValue(ValueRep v) @safe {
        state.value = InternalValue(v);
        process();
      }
    void setDone() @safe nothrow {
      state.value = InternalValue(Done());
      process();
    }
    void setError(Exception e) @safe nothrow {
      state.value = InternalValue(e);
      process();
    }
    private void process() @trusted {
      state.parent.process();
    }
    /// The stop token is backed by the shared state itself.
    StopToken getStopToken() @safe nothrow {
      return StopToken(state);
    }
    Scheduler getScheduler() @safe nothrow scope {
      return scheduler;
    }
  }

  private {
    Sender sender;
    Scheduler scheduler;
    SharedSenderState state;
    enum Flags {
      locked = 0x1,
      completed = 0x2,
      tick = 0x4
    }
    SharedBitField!Flags counter;

    /// Registers `dg` to receive the termination of the underlying Sender.
    ///
    /// If the run has already completed, `dg` is called immediately with the
    /// stored value. If this is the first registration, a fresh state is
    /// created and the underlying Sender is connected and started. Otherwise
    /// `dg` is appended to the list of the run in progress.
    void add(DG dg) @safe nothrow {
      with(counter.lock(0, Flags.tick)) {
        if (was(Flags.completed)) {
          InternalValue value = state.value.get;
          release(Flags.tick); // release early
          dg(value);
        } else {
          // No ticks before this call: this is the first registration.
          if ((oldState >> 2) == 0) {
            auto localState = new SharedSenderState(this);
            this.state = localState;
            release(); // release early
            localState.dgs.pushBack(dg);
            try {
              localState.op = sender.connect(SharedSenderReceiver(localState, scheduler));
            } catch (Exception e) {
              state.value = InternalValue(e);
              process();
              // connect failed: the error has been delivered and there is
              // no operation state to start, so stop here.
              return;
            }
            localState.op.start();
          } else {
            auto localState = state;
            release(); // release early
            localState.dgs.pushBack(dg);
          }
        }
      }
    }

    /// Deregisters `dg`.
    ///
    /// returns false if it is the last; in that case the underlying Sender is
    /// asked to stop and `dg` stays registered so that it still receives the
    /// final termination. Returns true when the run had already completed or
    /// when `dg` was removed from the list.
    bool remove(DG dg) @safe nothrow {
      with (counter.lock(0, 0, Flags.tick)) {
        if (was(Flags.completed)) {
          release(0-Flags.tick); // release early
          return true;
        }
        if ((newState >> 2) == 0) {
          auto localStopSource = state;
          release(); // release early
          localStopSource.stop();
          return false;
        } else {
          auto localReceiver = state;
          release(); // release early
          localReceiver.dgs.remove(dg);
          return true;
        }
      }
    }
  }

  /// Marks the run as completed and forwards the stored termination to every
  /// registered delegate. When a stop was requested on the shared state the
  /// delegates receive `Done` instead of the stored value.
  private void process() {
    with(counter.lock(Flags.completed)) {
      release(oldState & (~0x3)); // release early and remove all ticks
      InternalValue v = state.value.get;
      if (state.isStopRequested)
        v = Done();
      foreach(dg; state.dgs[])
        dg(v);
    }
  }

  /// True once the `completed` flag is set on the counter.
  bool isCompleted() @trusted {
    import core.atomic : MemoryOrder;
    return (counter.load!(MemoryOrder.acq) & Flags.completed) > 0;
  }

  /// Clears the `completed` flag, if set, so that the next `add` connects and
  /// starts the underlying Sender again. Does nothing otherwise.
  void reset() @trusted {
    with (counter.lock()) {
      if (was(Flags.completed))
        release(Flags.completed);
    }
  }

  this(Sender sender, Scheduler scheduler) {
    this.sender = sender;
    this.scheduler = scheduler;
  }

  /// Creates the operation state binding `receiver` to this SharedSender.
  /// Nothing runs until `start` is called on the returned operation.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = SharedSenderOp!Receiver(this, receiver);
    return op;
  }
}