module concurrency.operations.whenall;

import concurrency;
import concurrency.receiver;
import concurrency.sender;
import concurrency.stoptoken;
import concepts;
import std.traits;
import concurrency.utils : spin_yield, casWeak;

/// Creates a sender that starts all `senders` and completes once every one of
/// them has completed.
///
/// Accepts either two or more individual senders, or a single array of senders
/// (see the constraint on `WhenAllSender`).
///
/// How the combined operation completes (see `WhenAllReceiver.process`):
/// - if the downstream receiver's stop token has been triggered, `setDone`;
/// - else, if any inner sender completed with done or an error: `setError` when
///   the first such completion was an error, `setDone` otherwise;
/// - else the collected values are forwarded (`setValue()` when no sender
///   produces a value).
/// An empty array of senders completes immediately with an empty result.
WhenAllSender!(Senders) whenAll(Senders...)(Senders senders) {
  return WhenAllSender!(Senders)(senders);
}

/// Flag bits kept in the low three bits of `WhenAllState.bitfield`.
private enum Flags : size_t {
  locked = 0x1, /// presumably the lock bit used by `SharedBitField.lock` — confirm in concurrency.bitfield
  value_produced = 0x2, /// set once any inner sender completed with a value
  doneOrError_produced = 0x4 /// set once any inner sender completed with done or an error
}

/// Completion counter stored above the three flag bits: every inner completion
/// adds one `tick`, so `bitfield >> 3` is the number of completed inner senders.
private enum Counter : size_t {
  tick = 0x8
}

/// Maps each sender type in `Senders` to its `Value` member type.
template GetSenderValues(Senders...) {
  import std.meta;
  alias SenderValue(T) = T.Value;
  alias GetSenderValues = staticMap!(SenderValue, Senders);
}

/// Result storage for `whenAll` over two or more individual senders.
///
/// Values of non-void senders are kept in sender order; void senders are
/// skipped. A single non-void value is stored directly, several are stored in
/// a `Tuple`. When no sender produces a value the struct is empty.
private template WhenAllResult(Senders...) if (Senders.length > 1) {
  import std.meta;
  import std.typecons;
  import mir.algebraic : Algebraic, Nullable;
  import concurrency.utils : NoVoid;

  /// Running sum of `Ts`, starting from `count`.
  /// E.g. `Cumulative!(0, 1, 0, 1)` is `AliasSeq!(1, 1, 2)`.
  template Cumulative(size_t count, Ts...) {
    static if (Ts.length > 0) {
      enum head = count + Ts[0];
      static if (Ts.length == 1)
        alias Cumulative = AliasSeq!(head);
      else
        alias Cumulative = AliasSeq!(head, Cumulative!(head, Ts[1..$]));
    } else {
      alias Cumulative = AliasSeq!();
    }
  }

  alias SenderValues = GetSenderValues!(Senders);
  alias ValueTypes = Filter!(NoVoid, SenderValues);
  static if (ValueTypes.length > 1)
    alias Values = Tuple!(Filter!(NoVoid, SenderValues));
  else static if (ValueTypes.length == 1)
    alias Values = ValueTypes[0];
  // Assuming `NoVoid` is the boolean predicate used with `Filter` above, this
  // counts non-void senders up to and including each position, so for a
  // non-void sender `i` its slot in `values` is `Indexes[i] - 1`.
  alias Indexes = Cumulative!(0, staticMap!(NoVoid, SenderValues));

  static if (ValueTypes.length > 0) {
    struct WhenAllResult {
      Values values;
      /// Stores `t`, produced by the sender at position `index`, in its slot.
      void setValue(T)(T t, size_t index) {
        switch (index) {
          // Unrolled at compile time: one `case` per sender.
          foreach(idx, I; Indexes) {
          case idx:
            static if (ValueTypes.length == 1)
              values = t;
            else static if (is(typeof(values[I-1]) == T))
              values[I-1] = t;
            return;
          }
        default: assert(false, "out of bounds");
        }
      }
    }
  } else {
    struct WhenAllResult {
    }
  }
}

/// The element type `P` of an array type `P[]`.
alias ArrayElement(T : P[], P) = P;

/// Result storage for `whenAll` over an array of senders: one value per
/// sender, in array order. Empty when the element sender's value is void.
private template WhenAllResult(Senders...) if (Senders.length == 1) {
  alias Element = ArrayElement!(Senders).Value;
  static if (is(Element : void)) {
    struct WhenAllResult {}
  } else {
    struct WhenAllResult {
      Element[] values;
      /// Stores `elem`, produced by the sender at array position `index`.
      void setValue(T)(T elem, size_t index) {
        values[index] = elem;
      }
    }
  }
}

/// Operation state produced by `WhenAllSender.connect`: owns the downstream
/// receiver, the shared `WhenAllState`, and one connected inner operation per
/// sender. Not copyable.
private struct WhenAllOp(Receiver, Senders...) {
  import std.meta : staticMap;
  alias R = WhenAllResult!(Senders);
  static if (Senders.length > 1) {
    alias ElementReceiver(Sender) = WhenAllReceiver!(Receiver, Sender.Value, R);
    alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
    alias Ops = staticMap!(ConnectResult, Senders);
  } else {
    alias ElementReceiver = WhenAllReceiver!(Receiver, ArrayElement!(Senders).Value, R);
    alias Ops = OpType!(ArrayElement!(Senders), ElementReceiver)[];
  }
  Receiver receiver;
  WhenAllState!R state;
  Ops ops;
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);

  /// Allocates the shared state and connects every sender to a
  /// `WhenAllReceiver` that knows its own index and the total sender count.
  this(Receiver receiver, Senders senders) {
    this.receiver = receiver;
    state = new WhenAllState!R();
    static if (Senders.length > 1) {
      foreach(i, Sender; Senders) {
        ops[i] = senders[i].connect(WhenAllReceiver!(Receiver, Sender.Value, R)(receiver, state, i, Senders.length));
      }
    } else {
      static if (!is(ArrayElement!(Senders).Value : void))
        state.value.values.length = senders[0].length;
      ops.length = senders[0].length;
      foreach(i; 0..senders[0].length) {
        ops[i] = senders[0][i].connect(WhenAllReceiver!(Receiver, ArrayElement!(Senders).Value, R)(receiver, state, i, senders[0].length));
      }
    }
  }

  /// Starts all inner operations.
  ///
  /// Completes with `setDone` right away if the downstream stop token is
  /// already triggered. Otherwise a downstream stop request is forwarded to the
  /// shared stop source before the inner operations are started.
  void start() @trusted nothrow scope {
    import concurrency.stoptoken : StopSource;
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    static if (Senders.length == 1) {
      // With no inner operations, no WhenAllReceiver would ever observe itself
      // as the last to complete, so the downstream receiver would never be
      // signalled. Complete immediately with the (empty) result instead.
      if (ops.length == 0) {
        import concurrency.receiver : setValueOrError;
        static if (is(typeof(R.values)))
          receiver.setValueOrError(state.value.values);
        else
          receiver.setValueOrError();
        return;
      }
    }
    // Explicit cast to a shared nothrow @safe delegate; without it the
    // compiler does not select the intended `onStop` overload.
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop);
    static if (Senders.length > 1) {
      foreach(i, _; Senders) {
        ops[i].start();
      }
    } else {
      foreach(i; 0..ops.length) {
        ops[i].start();
      }
    }
  }
}

import std.meta : allSatisfy, ApplyRight;

/// Sender returned by `whenAll`. Holds either two or more senders, or exactly
/// one array of senders. `Value` is the type of `WhenAllResult.values`, or
/// `void` when no inner sender produces a value.
struct WhenAllSender(Senders...)
  if ((Senders.length > 1 && allSatisfy!(ApplyRight!(models, isSender), Senders)) ||
      (Senders.length == 1 && models!(ArrayElement!(Senders[0]), isSender))) {
  alias Result = WhenAllResult!(Senders);
  static if (hasMember!(Result, "values"))
    alias Value = typeof(Result.values);
  else
    alias Value = void;
  Senders senders;
  /// Connects all inner senders and returns the resulting operation state.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = WhenAllOp!(Receiver, Senders)(receiver, senders);
    return op;
  }
}

/// State shared by all `WhenAllReceiver`s of one operation. It is itself the
/// `StopSource` whose token is handed to every inner sender.
private class WhenAllState(Value) : StopSource {
  import concurrency.bitfield;
  /// Registration forwarding the downstream stop request; disposed on completion.
  StopCallback cb;
  static if (is(typeof(Value.values)))
    Value value;
  /// Error from the first done/error completion, if that completion was an error.
  Throwable exception;
  /// `Flags` bits plus the completion counter (see `Counter`).
  shared SharedBitField!Flags bitfield;
}

/// Receiver attached to one inner sender. It records that sender's completion
/// in the shared state; whichever receiver completes last signals the
/// downstream `receiver`.
private struct WhenAllReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  Receiver receiver;
  WhenAllState!(Value) state;
  size_t senderIndex;
  size_t senderCount;
  /// Inner senders observe the shared stop source, not the downstream token.
  auto getStopToken() {
    return StopToken(state);
  }
  private bool isValueProduced(size_t state) {
    return (state & Flags.value_produced) > 0;
  }
  private bool isDoneOrErrorProduced(size_t state) {
    return (state & Flags.doneOrError_produced) > 0;
  }
  /// True when the counter above the three flag bits equals the sender count.
  private bool isLast(size_t state) {
    return (state >> 3) == atomicLoad(senderCount);
  }
  static if (!is(InnerValue == void))
    /// Stores this sender's value while holding the bitfield lock, counts the
    /// completion, and finishes the operation if this was the last sender.
    void setValue(InnerValue value) @safe {
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        state.value.setValue(value, senderIndex);
        release();
        if (last)
          process(newState);
      }
    }
  else
    /// Counts the completion and finishes the operation if this was the last sender.
    void setValue() @safe {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        if (last)
          process(newState);
      }
    }
  /// Counts the completion; the first done/error triggers the shared stop
  /// source so that the sibling senders are asked to stop.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState))
        state.stop();
      if (last)
        process(newState);
    }
  }
  /// Counts the completion. If this is the first done/error, it records
  /// `exception` and triggers the shared stop source. Later errors are dropped.
  void setError(Throwable exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState)) {
        state.exception = exception;
        // Must release before calling .stop, presumably because stopping can
        // synchronously re-enter this state via sibling completions.
        release();
        state.stop();
      } else
        release();
      if (last)
        process(newState);
    }
  }
  /// Runs once, in the last receiver to complete. It drops the downstream
  /// stop registration and signals the downstream receiver.
  private void process(size_t newState) {
    state.cb.dispose();

    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isDoneOrErrorProduced(newState)) {
      if (state.exception)
        receiver.setError(state.exception);
      else
        receiver.setDone();
    } else {
      import concurrency.receiver : setValueOrError;
      static if (is(typeof(Value.values)))
        receiver.setValueOrError(state.value.values);
      else
        receiver.setValueOrError();
    }
  }
  mixin ForwardExtensionPoints!receiver;
}