1 module concurrency.operations.whenall;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concepts;
8 import std.traits;
9 import concurrency.utils : spin_yield, casWeak;
10 
/// Builds a sender that completes once every one of `senders` has completed.
/// Accepts either two or more senders, or a single array of senders.
WhenAllSender!(Senders) whenAll(Senders...)(Senders senders) {
  auto combined = WhenAllSender!(Senders)(senders);
  return combined;
}
14 
/// Status bits stored in the low three bits of `WhenAllState.bitfield`.
/// `locked` is presumably the bit taken by `SharedBitField.lock` — confirm
/// against concurrency.bitfield.
private enum Flags : size_t {
  locked = 1 << 0,
  value_produced = 1 << 1,
  doneOrError_produced = 1 << 2
}
20 
/// Amount added to the bitfield for each inner completion; the count lives
/// just above the three `Flags` bits.
private enum Counter : size_t {
  tick = 1 << 3
}
24 
/// Maps each sender type to its `Value` member type, preserving order.
/// Senders whose `Value` is `void` map to `void`.
template GetSenderValues(Senders...) {
  import std.meta : staticMap;
  alias ValueOf(S) = S.Value;
  alias GetSenderValues = staticMap!(ValueOf, Senders);
}
30 
/// Result storage for `whenAll` over two or more distinct senders.
///
/// Only senders with a non-void `Value` get a slot. With exactly one such
/// sender `values` has that type; with several it is a `Tuple` of them in
/// sender order; with none the result is an empty struct.
private template WhenAllResult(Senders...) if (Senders.length > 1) {
  import std.meta;
  import std.typecons;
  import mir.algebraic : Algebraic, Nullable;
  import concurrency.utils : NoVoid;

  // Running sums of `Ts` starting from `running`:
  // Cumulative!(0, a, b, c) is AliasSeq!(a, a + b, a + b + c).
  template Cumulative(size_t running, Ts...) {
    static if (Ts.length == 0) {
      alias Cumulative = AliasSeq!();
    } else {
      enum total = running + Ts[0];
      alias Cumulative = AliasSeq!(total, Cumulative!(total, Ts[1 .. $]));
    }
  }

  alias SenderValues = GetSenderValues!(Senders);
  alias ValueTypes = Filter!(NoVoid, SenderValues);

  static if (ValueTypes.length == 1)
    alias Values = ValueTypes[0];
  else static if (ValueTypes.length > 1)
    alias Values = Tuple!(ValueTypes);

  // Indexes[i] counts the non-void senders among positions 0 .. i inclusive,
  // so a non-void sender at position i owns tuple slot Indexes[i] - 1.
  alias Indexes = Cumulative!(0, staticMap!(NoVoid, SenderValues));

  static if (ValueTypes.length == 0) {
    struct WhenAllResult {
    }
  } else {
    struct WhenAllResult {
      Values values;
      /// Stores `t`, produced by the sender at position `index`, in its slot.
      void setValue(T)(T t, size_t index) {
        switch (index) {
          // Unrolled at compile time: one case per sender position.
          foreach (senderPos, slotEnd; Indexes) {
          case senderPos:
            static if (ValueTypes.length == 1)
              values = t;
            // The type guard keeps every case compilable. A position whose
            // slot type differs from T, or that has no slot, just returns.
            // At runtime `index` is the position of the sender that produced the T.
            else static if (is(typeof(values[slotEnd - 1]) == T))
              values[slotEnd - 1] = t;
            return;
          }
        default: assert(false, "out of bounds");
        }
      }
    }
  }
}
77 
/// Element type of an array type: `ArrayElement!(T[])` is `T`.
template ArrayElement(T : P[], P) {
  alias ArrayElement = P;
}
79 
/// Result storage for `whenAll` over a single array of senders.
///
/// When the element senders produce `void` the result is an empty struct.
/// Otherwise `values` holds one entry per array element, and `setValue`
/// writes an element's value at that element's index.
private template WhenAllResult(Senders...) if (Senders.length == 1) {
  alias Element = ArrayElement!(Senders).Value;
  static if (is(Element : void)) {
    struct WhenAllResult {}
  } else {
    struct WhenAllResult {
      Element[] values;
      // The empty template parameter list keeps attribute inference. The
      // parameter uses the outer `Element`. The previous
      // `setValue(Element)(Element ...)` declared a new template parameter
      // that shadowed it and accepted any type.
      void setValue()(Element elem, size_t index) {
        values[index] = elem;
      }
    }
  }
}
93 
/// Operation state returned by `WhenAllSender.connect`.
///
/// Holds the outer `receiver`, the `WhenAllState` shared with every inner
/// receiver, and the connected inner operations. `ops` is a tuple with one
/// entry per sender when there are several senders. It is a dynamic array
/// with one entry per element when a single array of senders was given.
private struct WhenAllOp(Receiver, Senders...) {
  import std.meta : staticMap;
  alias R = WhenAllResult!(Senders);
  static if (Senders.length > 1) {
    alias ElementReceiver(Sender) = WhenAllReceiver!(Receiver, Sender.Value, R);
    alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
    alias Ops = staticMap!(ConnectResult, Senders);
  } else {
    alias ElementReceiver = WhenAllReceiver!(Receiver, ArrayElement!(Senders).Value, R);
    alias Ops = OpType!(ArrayElement!(Senders), ElementReceiver)[];
  }
  Receiver receiver;
  WhenAllState!R state;
  Ops ops;
  // Operation states are not copyable.
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);

  /// Allocates the shared state and connects every inner sender to a
  /// `WhenAllReceiver`. Each receiver is given its own index and the total
  /// number of senders.
  this(Receiver receiver, Senders senders) {
    this.receiver = receiver;
    state = new WhenAllState!R();
    static if (Senders.length > 1) {
      foreach(i, Sender; Senders) {
        ops[i] = senders[i].connect(WhenAllReceiver!(Receiver, Sender.Value, R)(receiver, state, i, Senders.length));
      }
    } else {
      // One result slot (for non-void elements) and one operation per array element.
      static if (!is(ArrayElement!(Senders).Value : void))
        state.value.values.length = senders[0].length;
      ops.length = senders[0].length;
      foreach(i; 0..senders[0].length) {
        ops[i] = senders[0][i].connect(WhenAllReceiver!(Receiver, ArrayElement!(Senders).Value, R)(receiver, state, i, senders[0].length));
      }
    }
  }

  /// Starts all inner operations.
  ///
  /// - If the outer receiver's stop token is already triggered, completes it
  ///   with `setDone` without starting anything.
  /// - If there are no inner operations (an empty sender array), completes it
  ///   immediately with an empty result.
  /// - Otherwise, registers `state.stop` on the outer stop token, so an outer
  ///   stop request reaches the inner senders through `state`, then starts
  ///   every operation.
  void start() @trusted nothrow scope {
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    static if (Senders.length == 1) {
      // Normally the last inner receiver to finish completes the outer one.
      // With zero elements there are no inner receivers, so without this the
      // outer receiver would never be completed.
      if (ops.length == 0) {
        import concurrency.receiver : setValueOrError;
        static if (is(typeof(R.values)))
          receiver.setValueOrError(state.value.values);
        else
          receiver.setValueOrError();
        return;
      }
    }
    // The explicit cast is required; without it onStop does not resolve to the intended overload.
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop);
    static if (Senders.length > 1) {
      foreach(i, _; Senders) {
        ops[i].start();
      }
    } else {
      foreach(i; 0..ops.length) {
        ops[i].start();
      }
    }
  }
}
144 
145 import std.meta : allSatisfy, ApplyRight;
146 
/// Sender produced by `whenAll`. It completes after all of `Senders` have
/// completed. It accepts either two or more senders, or a single array of
/// senders.
struct WhenAllSender(Senders...)
     if ((Senders.length > 1 && allSatisfy!(ApplyRight!(models, isSender), Senders)) ||
         (models!(ArrayElement!(Senders[0]), isSender))) {
  alias Result = WhenAllResult!(Senders);
  // `Value` is the type of `Result.values`, or void when the result has no
  // `values` member.
  static if (!hasMember!(Result, "values"))
    alias Value = void;
  else
    alias Value = typeof(Result.values);
  Senders senders;
  /// Connects all senders to `receiver` and returns the combined operation state.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // Build into a named local and return it, so NRVO applies to the
    // non-copyable operation state.
    auto operation = WhenAllOp!(Receiver, Senders)(receiver, senders);
    return operation;
  }
}
162 
/// State shared by all inner receivers of one `whenAll` operation. It is also
/// the stop source behind the stop token handed to each inner sender.
private class WhenAllState(Value) : StopSource {
  import concurrency.bitfield;
  // `Flags` bits plus the completion count, in units of `Counter.tick`.
  shared SharedBitField!Flags bitfield;
  // Error recorded by an inner receiver's setError, or null.
  Throwable exception;
  // Collected values; declared only when the result type has `values`.
  static if (is(typeof(Value.values)))
    Value value;
  // Registration of `stop` on the outer receiver's stop token.
  StopCallback cb;
}
171 
/// Receiver handed to each inner sender of a `whenAll`.
///
/// Every completion (value, done or error) adds one `Counter.tick` to the
/// shared `state.bitfield`. The completion that brings the count up to
/// `senderCount` calls `process`, which completes the outer `receiver`. The
/// first done or error report calls `state.stop()`, so the remaining inner
/// senders see a stop request.
private struct WhenAllReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  Receiver receiver;
  WhenAllState!(Value) state;
  size_t senderIndex;
  size_t senderCount;
  /// Inner senders observe `state` (a StopSource), not the outer token directly.
  auto getStopToken() {
    return StopToken(state);
  }
  private bool isValueProduced(size_t state) {
    return (state & Flags.value_produced) > 0;
  }
  private bool isDoneOrErrorProduced(size_t state) {
    return (state & Flags.doneOrError_produced) > 0;
  }
  /// True when the completion count encoded in `state` equals `senderCount`.
  private bool isLast(size_t state) {
    // The count sits above the flag bits in units of `Counter.tick`. It is
    // derived from that constant rather than a hard-coded shift, so the two
    // cannot drift apart. `senderCount` is a plain field that is never
    // reassigned after construction, so no atomic load is needed.
    return state / Counter.tick == senderCount;
  }
  static if (!is(InnerValue == void))
    void setValue(InnerValue value) @safe {
      // The value is stored while the bitfield lock is held, then the lock is released.
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        state.value.setValue(value, senderIndex);
        release();
        if (last)
          process(newState);
      }
    }
  else
    void setValue() @safe {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        if (last)
          process(newState);
      }
    }
  /// Records a done completion; the first done/error stops the siblings.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState))
        state.stop();
      if (last)
        process(newState);
    }
  }
  /// Records an error completion. `exception` is kept only if no done/error
  /// was reported before it.
  void setError(Throwable exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState)) {
        state.exception = exception;
        release(); // must release before calling .stop
        state.stop();
      } else
        release();
      if (last)
        process(newState);
    }
  }
  /// Completes the outer receiver once all inner senders are finished.
  /// - `setDone` if the outer stop token has been triggered.
  /// - Otherwise, if any inner sender reported done or error: the stored
  ///   error if there is one, else `setDone`.
  /// - Otherwise, the collected values (or a void value).
  private void process(size_t newState) {
    state.cb.dispose();

    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isDoneOrErrorProduced(newState)) {
      if (state.exception)
        receiver.setError(state.exception);
      else
        receiver.setDone();
    } else {
      import concurrency.receiver : setValueOrError;
      static if (is(typeof(Value.values)))
        receiver.setValueOrError(state.value.values);
      else
        receiver.setValueOrError();
    }
  }
  mixin ForwardExtensionPoints!receiver;
}