1 module concurrency.operations.whenall;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concepts;
8 import std.traits;
9 import concurrency.utils : spin_yield, casWeak;
10 
/// Combines two or more senders into one `WhenAllSender`.
///
/// The returned sender, when connected and started, starts every sender and
/// completes the downstream receiver only after all of them have completed.
WhenAllSender!(Senders) whenAll(Senders...)(Senders senders) if (Senders.length > 1) {
  // typeof(return) is WhenAllSender!(Senders); construct it from the senders.
  return typeof(return)(senders);
}
14 
/// Flag bits stored in the low three bits of `WhenAllState.bitfield`.
private enum Flags : size_t {
  /// Lock bit; presumably taken by `SharedBitField.lock` — TODO confirm.
  locked = 1 << 0,
  /// Set once any child has completed with a value.
  value_produced = 1 << 1,
  /// Set once any child has completed with done or an error.
  doneOrError_produced = 1 << 2
}
20 
/// Increment passed to the bitfield on every child completion. It is the first
/// bit above the three `Flags` bits, so `bitfield >> 3` counts completions.
private enum Counter : size_t {
  tick = 1 << 3
}
24 
/// Storage for the values produced by `Senders`.
///
/// Senders whose `Value` is `void` contribute nothing. If at least one sender
/// produces a value, the result is a struct with a `values` field. A single
/// non-void value is stored directly; several are stored in a `Tuple`, in
/// sender order. If no sender produces a value, the result is an empty struct.
private template WhenAllResult(Senders...) {
  import std.meta;
  import std.typecons;
  import mir.algebraic : Algebraic, Nullable;
  import concurrency.utils : NoVoid;

  /// Inclusive running sum of `Bits`, starting from `acc`.
  /// For example, PrefixSum!(0, true, false, true) is (1, 1, 2).
  template PrefixSum(size_t acc, Bits...) {
    static if (Bits.length == 0) {
      alias PrefixSum = AliasSeq!();
    } else {
      enum total = acc + Bits[0];
      alias PrefixSum = AliasSeq!(total, PrefixSum!(total, Bits[1 .. $]));
    }
  }

  alias SenderValue(T) = T.Value;
  alias SenderValues = staticMap!(SenderValue, Senders);
  alias ValueTypes = Filter!(NoVoid, SenderValues);

  static if (ValueTypes.length == 1)
    alias Values = ValueTypes[0];
  else static if (ValueTypes.length > 1)
    alias Values = Tuple!ValueTypes;

  // Slots[i] - 1 is the Tuple position of sender i's value. It is only
  // meaningful when sender i has a non-void value and `values` is a Tuple.
  alias Slots = PrefixSum!(0, staticMap!(NoVoid, SenderValues));

  static if (ValueTypes.length == 0) {
    struct WhenAllResult {
    }
  } else {
    struct WhenAllResult {
      Values values;

      /// Stores `t` as the value of the sender at position `index`.
      void setValue(T)(T t, size_t index) {
        switch (index) {
          foreach (senderPos, slot; Slots) {
          case senderPos:
            static if (ValueTypes.length == 1)
              values = t;
            else static if (is(typeof(values[slot - 1]) == T))
              values[slot - 1] = t;
            return;
          }
        default: assert(false, "out of bounds");
        }
      }
    }
  }
}
72 
/// Operation state created by connecting a `WhenAllSender`.
///
/// Holds the downstream `receiver`, one heap-allocated `WhenAllState` shared by
/// every child receiver, and one connected operation per child sender.
/// Copying is disabled (both postblit and copy constructor).
private struct WhenAllOp(Receiver, Senders...) {
  import std.meta : staticMap;
  alias R = WhenAllResult!(Senders);
  alias ElementReceiver(Sender) = WhenAllReceiver!(Receiver, Sender.Value, R);
  alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
  alias Ops = staticMap!(ConnectResult, Senders);
  Receiver receiver;
  WhenAllState!R state;
  Ops ops;
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);

  /// Connects each sender to its own `ElementReceiver`. All of them share one
  /// freshly allocated `WhenAllState`. Nothing is started here; see `start`.
  this(Receiver receiver, Senders senders) {
    this.receiver = receiver;
    state = new WhenAllState!R();
    foreach (i, Sender; Senders) {
      // Connect with exactly the receiver type `Ops` was computed from, so the
      // stored op types and the connected op types cannot drift apart.
      ops[i] = senders[i].connect(ElementReceiver!Sender(receiver, state, i, Senders.length));
    }
  }

  /// Starts every child operation.
  ///
  /// If the downstream receiver's stop token is already triggered, completes
  /// with `setDone` and starts nothing. Otherwise registers `state.stop` on
  /// that token, so an outer stop request reaches the children (they observe
  /// `state` as their stop source).
  void start() @trusted nothrow scope {
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    // The explicit cast is needed: without it the compiler does not pick the
    // second (shared delegate) overload of onStop.
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop);
    foreach (i, _; Senders) {
      ops[i].start();
    }
  }
}
103 
104 import std.meta : allSatisfy, ApplyRight;
105 
/// Sender returned by `whenAll`. It runs all `senders` and completes once every
/// one of them has completed.
///
/// `Value` is the type of `WhenAllResult!(Senders).values` when at least one
/// sender produces a value, and `void` otherwise.
struct WhenAllSender(Senders...) if (allSatisfy!(ApplyRight!(models, isSender), Senders)) {
  alias Result = WhenAllResult!(Senders);
  static if (__traits(hasMember, Result, "values")) {
    alias Value = typeof(Result.values);
  } else {
    alias Value = void;
  }
  Senders senders;

  /// Connects to `receiver` and returns a `WhenAllOp` that holds the connected
  /// child operations.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // The op is non-copyable, so build it in a named local that is returned
    // directly (NRVO).
    WhenAllOp!(Receiver, Senders) op = WhenAllOp!(Receiver, Senders)(receiver, senders);
    return op;
  }
}
119 
/// State shared by all child receivers of one `whenAll` operation.
///
/// It is itself a `StopSource`. Child receivers hand out stop tokens built from
/// it, so calling `stop` on it stops every child.
private class WhenAllState(Value) : StopSource {
  import concurrency.bitfield;

  /// Registration of `stop` on the downstream receiver's stop token.
  StopCallback cb;

  static if (is(typeof(Value.values))) {
    /// Collected child values. Present only when some sender produces a value.
    Value value;
  }

  /// Error passed to `setError`. It is stored only when that error is the
  /// first done/error completion; otherwise it stays null.
  Exception exception;

  /// `Flags` bits plus a completion counter, shared between child receivers.
  shared SharedBitField!Flags bitfield;
}
128 
/// Receiver handed to each child sender of a `whenAll` operation.
///
/// Every completion (value, done or error) sets a flag in the shared
/// `state.bitfield` and passes `Counter.tick` to it. The receiver whose update
/// brings the completion count to `senderCount` runs `process`, which
/// completes the downstream `receiver` exactly once.
private struct WhenAllReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  /// Downstream receiver that gets the combined result.
  Receiver receiver;
  /// State shared with the sibling receivers and the owning operation.
  WhenAllState!(Value) state;
  /// Position of this child among the senders; selects its slot in `state.value`.
  size_t senderIndex;
  /// Total number of child senders.
  size_t senderCount;
  /// Children observe the shared `state` as their stop source. The first
  /// done/error (see setDone/setError) or an outer stop request therefore
  /// stops all children.
  auto getStopToken() {
    return StopToken(state);
  }
  // In the helpers below, the `state` parameter is a raw bitfield value that
  // shadows the `state` field.
  // NOTE(review): isValueProduced is not used by any code visible here.
  private bool isValueProduced(size_t state) {
    return (state & Flags.value_produced) > 0;
  }
  private bool isDoneOrErrorProduced(size_t state) {
    return (state & Flags.doneOrError_produced) > 0;
  }
  // The completion counter sits above the three Flags bits (Counter.tick == 0x8),
  // so shifting by 3 yields the number of completed children.
  private bool isLast(size_t state) {
    return (state >> 3) == senderCount;
  }
  static if (!is(InnerValue == void))
    /// Stores `value` in this child's slot while holding the bitfield lock,
    /// releases the lock, then checks whether this was the last completion.
    /// The guard returned by lock/update supplies release/oldState/newState
    /// (SharedBitField API; not visible here).
    void setValue(InnerValue value) @safe {
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        state.value.setValue(value, senderIndex);
        release();
        process(newState);
      }
    }
  else
    /// Void value: there is nothing to store, so just record the completion
    /// and check whether this was the last one.
    void setValue() @safe {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        process(newState);
      }
    }
  /// Records a done completion. If it is the first done/error, stops the
  /// siblings via `state.stop()`.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      if (!isDoneOrErrorProduced(oldState))
        state.stop();
      process(newState);
    }
  }
  /// Records an error completion. Only the first done/error keeps its
  /// exception; a later error is discarded. The first one also stops the
  /// siblings.
  void setError(Exception exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      if (!isDoneOrErrorProduced(oldState)) {
        state.exception = exception;
        // NOTE(review): presumably stop() runs callbacks that may re-enter the
        // bitfield, hence releasing first — confirm.
        release(); // must release before calling .stop
        state.stop();
      } else
        release();
      process(newState);
    }
  }
  /// Completes the downstream receiver if `newState` shows every child done.
  /// Outcomes are checked in this order:
  ///  1. Outer stop requested: `setDone`.
  ///  2. A done/error occurred: `setError` with the stored exception if there
  ///     is one, else `setDone`.
  ///  3. Otherwise: deliver the collected values (or nothing, for void)
  ///     through `setValueOrError`.
  private void process(size_t newState) {
    if (!isLast(newState))
      return;

    // Unregister from the downstream stop token before completing.
    state.cb.dispose();

    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isDoneOrErrorProduced(newState)) {
      if (state.exception)
        receiver.setError(state.exception);
      else
        receiver.setDone();
    } else {
      import concurrency.receiver : setValueOrError;
      static if (is(typeof(Value.values)))
        receiver.setValueOrError(state.value.values);
      else
        receiver.setValueOrError();
    }
  }
  mixin ForwardExtensionPoints!receiver;
}