1 module concurrency.operations.whenall;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concepts;
8 import std.traits;
9 import concurrency.utils : spin_yield, casWeak;
10 
/// Combines two or more senders into a single `WhenAllSender`.
///
/// Params:
///   senders = the senders to combine; at least two are required by the
///             template constraint.
/// Returns: a `WhenAllSender` holding all of `senders`.
WhenAllSender!(Senders) whenAll(Senders...)(Senders senders) if (Senders.length > 1) {
  alias Combined = WhenAllSender!(Senders);
  return Combined(senders);
}
14 
/// Flag bits stored in the low three bits of the shared state word.
/// `value_produced` and `doneOrError_produced` are set by the element
/// receivers in this module; `locked` is not referenced directly here
/// (presumably consumed by `SharedBitField.lock` -- confirm in concurrency.bitfield).
private enum Flags : size_t {
  locked               = 1 << 0,
  value_produced       = 1 << 1,
  doneOrError_produced = 1 << 2
}

/// Increment passed alongside a flag on every child completion. It equals
/// `1 << 3`, which matches the `>> 3` used by `WhenAllReceiver.isLast`
/// to extract the completion count from the state word.
private enum Counter : size_t {
  tick = 1 << 3
}
24 
/// Maps a sequence of sender types to the sequence of their `Value` types
/// (one entry per sender, in the same order).
template SenderValues(Senders...) {
  import std.meta : staticMap;
  // Extracts the `Value` alias of a single sender type.
  alias ValueOf(S) = S.Value;
  alias SenderValues = staticMap!(ValueOf, Senders);
}
30 
/// Storage for the values produced by the child senders.
///
/// `SenderValues` is the list of every child's `Value` type. Entries
/// rejected by the `NoVoid` filter are dropped; the remaining types are
/// stored either directly (exactly one) or in a `Tuple` (more than one).
/// When nothing remains the result is an empty struct without a `values`
/// member.
private template WhenAllResult(SenderValues...) {
  import std.meta;
  import std.typecons;
  import mir.algebraic : Algebraic, Nullable;
  import concurrency.utils : NoVoid;

  // Running prefix sum over `Bits`, seeded with `sum`. Yields one entry per
  // element of `Bits`; an empty input yields an empty sequence.
  template Cumulative(size_t sum, Bits...) {
    static if (Bits.length == 0) {
      alias Cumulative = AliasSeq!();
    } else {
      enum running = sum + Bits[0];
      alias Cumulative = AliasSeq!(running, Cumulative!(running, Bits[1 .. $]));
    }
  }

  alias ValueTypes = Filter!(NoVoid, SenderValues);
  static if (ValueTypes.length > 1)
    alias Values = Tuple!(ValueTypes);
  else static if (ValueTypes.length == 1)
    alias Values = ValueTypes[0];

  // For sender number `i`, Indexes[i] is the count of senders up to and
  // including `i` that pass `NoVoid`; `Indexes[i] - 1` is therefore the
  // slot in `values` used for that sender.
  alias Indexes = Cumulative!(0, staticMap!(NoVoid, SenderValues));

  static if (ValueTypes.length > 0) {
    struct WhenAllResult {
      Values values;

      /// Stores `t` in the slot that belongs to the sender at position
      /// `index`. Asserts when `index` does not name a sender.
      void setValue(T)(T t, size_t index) {
        switch (index) {
          // Unrolled at compile time: one `case` per sender position.
          foreach (senderPos, slotCount; Indexes) {
          case senderPos:
            static if (ValueTypes.length == 1)
              values = t;
            else static if (is(typeof(values[slotCount - 1]) == T))
              // Only emitted when the slot's type matches T.
              values[slotCount - 1] = t;
            return;
          }
        default:
          assert(false, "out of bounds");
        }
      }
    }
  } else {
    struct WhenAllResult {
    }
  }
}
76 
/// Operation state for `whenAll`: owns the shared `WhenAllState` and one
/// connected child operation per sender. Copying is disabled.
private struct WhenAllOp(Receiver, Senders...) {
  import std.meta : staticMap;
  alias R = WhenAllResult!(SenderValues!Senders);
  alias ElementReceiver(Sender) = WhenAllReceiver!(Receiver, Sender.Value, R);
  alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
  alias Ops = staticMap!(ConnectResult, Senders);
  Receiver receiver;
  WhenAllState!R state;
  Ops ops;
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);

  /// Allocates the shared state and connects every sender to its own
  /// element receiver, which is handed the sender's position and the
  /// total number of senders.
  this(Receiver receiver, Senders senders) {
    this.receiver = receiver;
    this.state = new WhenAllState!R();
    foreach (pos, Child; Senders) {
      ops[pos] = senders[pos].connect(ElementReceiver!Child(receiver, state, pos, Senders.length));
    }
  }

  /// Starts all child operations, unless the downstream receiver's stop
  /// token is already triggered, in which case `setDone` is delivered
  /// immediately and no child is started.
  void start() @trusted nothrow scope {
    import concurrency.stoptoken : StopSource;
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    // Register `state.stop` as a callback on the downstream stop token.
    // The explicit cast is needed because overload resolution would
    // otherwise not pick the intended `onStop` overload.
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop);
    foreach (pos, Unused; Senders) {
      ops[pos].start();
    }
  }
}
107 
108 import std.meta : allSatisfy, ApplyRight;
109 
/// Sender produced by `whenAll`. Every element of `Senders` must satisfy
/// `models!(T, isSender)`.
///
/// `Value` is the type of `WhenAllResult.values` when the result struct
/// has such a member, and `void` otherwise.
struct WhenAllSender(Senders...) if (allSatisfy!(ApplyRight!(models, isSender), Senders)) {
  alias Result = WhenAllResult!(SenderValues!Senders);
  static if (__traits(hasMember, Result, "values")) {
    alias Value = typeof(Result.values);
  } else {
    alias Value = void;
  }
  Senders senders;

  /// Connects the held senders to `receiver`, returning the operation state.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // Return a named local so the result is constructed in place (NRVO).
    auto operation = WhenAllOp!(Receiver, Senders)(receiver, senders);
    return operation;
  }
}
123 
/// State shared by all element receivers of one `whenAll` operation.
/// It derives from `StopSource`, so the element receivers can build their
/// stop token from it and call `stop()` on it.
private class WhenAllState(Value) : StopSource {
  import concurrency.bitfield;
  /// Callback registration on the downstream stop token (set in `WhenAllOp.start`).
  StopCallback cb;
  // The result storage only exists when `Value` has a `values` member.
  static if (is(typeof(Value.values))) {
    Value value;
  }
  /// Exception recorded by `WhenAllReceiver.setError`; null if none was recorded.
  Exception exception;
  /// Shared word combining the `Flags` bits with the completion counter.
  shared SharedBitField!Flags bitfield;
}
132 
/// Receiver connected to a single child sender of `whenAll`.
///
/// Each completion (value, done or error) sets a flag on the shared
/// bitfield and adds `Counter.tick`. The completion that brings the counter
/// up to `senderCount` forwards the combined outcome to the downstream
/// `receiver` via `process`.
private struct WhenAllReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  Receiver receiver;
  WhenAllState!(Value) state;
  size_t senderIndex;
  size_t senderCount;

  /// The stop token handed to the child is backed by the shared state.
  auto getStopToken() {
    return StopToken(state);
  }

  // True when the `value_produced` flag is set in `bits`.
  private bool isValueProduced(size_t bits) {
    return (bits & Flags.value_produced) != 0;
  }

  // True when the `doneOrError_produced` flag is set in `bits`.
  private bool isDoneOrErrorProduced(size_t bits) {
    return (bits & Flags.doneOrError_produced) != 0;
  }

  // True when the counter stored above the three flag bits equals the
  // number of senders, i.e. this completion is the final one.
  private bool isLast(size_t bits) {
    return atomicLoad(senderCount) == (bits >> 3);
  }

  static if (!is(InnerValue == void)) {
    /// Stores `value` in this sender's slot while holding the bitfield
    /// lock, then forwards the outcome if this was the final completion.
    void setValue(InnerValue value) @safe {
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        const bool isFinal = isLast(newState);
        state.value.setValue(value, senderIndex);
        release();
        if (isFinal)
          process(newState);
      }
    }
  } else {
    /// Nothing to store for a void sender: just record the completion and
    /// forward the outcome if this was the final one.
    void setValue() @safe {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        if (isLast(newState))
          process(newState);
      }
    }
  }

  /// Records a done completion. The first done/error completion triggers
  /// `state.stop()`.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      const bool isFinal = isLast(newState);
      const bool firstFailure = !isDoneOrErrorProduced(oldState);
      if (firstFailure)
        state.stop();
      if (isFinal)
        process(newState);
    }
  }

  /// Records an error completion. Only the first done/error completion
  /// stores its exception and triggers `state.stop()`.
  void setError(Exception exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      const bool isFinal = isLast(newState);
      const bool firstFailure = !isDoneOrErrorProduced(oldState);
      if (firstFailure)
        state.exception = exception;
      // The lock must be released before calling .stop
      release();
      if (firstFailure)
        state.stop();
      if (isFinal)
        process(newState);
    }
  }

  // Delivers the combined outcome downstream. Order of precedence:
  // downstream stop requested -> done; any child done/error -> the stored
  // exception if there is one, else done; otherwise the collected values.
  private void process(size_t bits) {
    state.cb.dispose();

    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }

    if (isDoneOrErrorProduced(bits)) {
      if (state.exception)
        receiver.setError(state.exception);
      else
        receiver.setDone();
      return;
    }

    import concurrency.receiver : setValueOrError;
    static if (is(typeof(Value.values)))
      receiver.setValueOrError(state.value.values);
    else
      receiver.setValueOrError();
  }

  mixin ForwardExtensionPoints!receiver;
}