module concurrency.operations.whenall;

import concurrency;
import concurrency.receiver;
import concurrency.sender;
import concurrency.stoptoken;
import concepts;
import std.traits;
import concurrency.utils : spin_yield, casWeak;

/// Combines two or more senders into a single sender that completes once
/// every one of them has completed.
///
/// The resulting value is a `Tuple` of the non-void values of the senders
/// (in sender order), the bare value when exactly one sender produces a
/// value, or `void` when none do.
///
/// When a sender completes with done or error, stop is requested on the
/// `StopSource` shared by all inner senders. The first done/error completion
/// decides the outcome: an error is forwarded via `setError`, a done via
/// `setDone`. If the downstream receiver's stop token is stop-requested by the
/// time the last sender completes, `setDone` is sent instead.
WhenAllSender!(Senders) whenAll(Senders...)(Senders senders) if (Senders.length > 1) {
  return WhenAllSender!(Senders)(senders);
}

/// Flag bits of the shared state word (`WhenAllState.bitfield`).
/// The bits above these three hold the completion counter (see `Counter`).
private enum Flags : size_t {
  locked = 0x1,               // presumably the bit used by `SharedBitField.lock` - confirm
  value_produced = 0x2,       // some sender completed with a value
  doneOrError_produced = 0x4  // some sender completed with done or error
}

/// Added to the state word for every completed sender, so the number of
/// completions is `state >> 3` (see `WhenAllReceiver.isLast`).
private enum Counter : size_t {
  tick = 0x8
}

/// Maps every sender type in `Senders` to its `Value` type (which may be `void`).
template SenderValues(Senders...) {
  import std.meta;
  alias SenderValue(T) = T.Value;
  alias SenderValues = staticMap!(SenderValue, Senders);
}

/// Storage for the values produced by the senders given to `whenAll`.
///
/// When at least one of `SenderValues` is non-void this is a struct with a
/// `values` field: a `Tuple` of all non-void types when there are several,
/// or the single non-void type itself. Otherwise it is an empty struct.
private template WhenAllResult(SenderValues...) {
  import std.meta;
  import std.typecons;
  import concurrency.utils : NoVoid;

  /// Running (prefix) sum of `Ts`, starting from `count`.
  /// E.g. `Cumulative!(0, 1, 0, 1)` is `AliasSeq!(1, 1, 2)`.
  template Cumulative(size_t count, Ts...) {
    static if (Ts.length > 0) {
      enum head = count + Ts[0];
      static if (Ts.length == 1)
        alias Cumulative = AliasSeq!(head);
      else
        alias Cumulative = AliasSeq!(head, Cumulative!(head, Ts[1..$]));
    } else {
      alias Cumulative = AliasSeq!();
    }
  }

  alias ValueTypes = Filter!(NoVoid, SenderValues);
  static if (ValueTypes.length > 1)
    alias Values = Tuple!(ValueTypes);
  else static if (ValueTypes.length == 1)
    alias Values = ValueTypes[0];

  // For a sender i with a non-void value, Indexes[i] is the 1-based position
  // of its slot in `Values` (assumes NoVoid yields true/1 for non-void types).
  alias Indexes = Cumulative!(0, staticMap!(NoVoid, SenderValues));

  static if (ValueTypes.length > 0) {
    struct WhenAllResult {
      Values values;

      /// Stores `t`, produced by the sender at position `index`, into that
      /// sender's slot of `values`.
      void setValue(T)(T t, size_t index) {
        switch (index) {
          foreach(idx, I; Indexes) {
            case idx:
              // Every case is compiled for every `T`, so only assign where
              // the slot's type matches.
              static if (ValueTypes.length == 1)
                values = t;
              else static if (is(typeof(values[I-1]) == T))
                values[I-1] = t;
              return;
          }
          default: assert(false, "out of bounds");
        }
      }
    }
  } else {
    struct WhenAllResult {
    }
  }
}

/// Operation state returned by `WhenAllSender.connect`.
///
/// Each sender is connected to its own `WhenAllReceiver`; all of them share a
/// single heap-allocated `WhenAllState`.
private struct WhenAllOp(Receiver, Senders...) {
  import std.meta : staticMap;
  alias R = WhenAllResult!(SenderValues!Senders);
  alias ElementReceiver(Sender) = WhenAllReceiver!(Receiver, Sender.Value, R);
  alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
  alias Ops = staticMap!(ConnectResult, Senders);
  Receiver receiver;
  WhenAllState!R state;
  Ops ops;
  // Operation states are neither copyable nor copy-constructible.
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);

  this(Receiver receiver, Senders senders) {
    this.receiver = receiver;
    state = new WhenAllState!R();
    foreach(i, Sender; Senders) {
      // Each inner receiver gets its position (for storing its value) and the
      // total sender count (for detecting the last completion).
      ops[i] = senders[i].connect(ElementReceiver!Sender(receiver, state, i, Senders.length));
    }
  }

  /// Starts all inner operations. If the downstream receiver's stop token is
  /// already stop-requested, completes it with `setDone` without starting any.
  void start() @trusted nothrow scope {
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    // Forward a downstream stop request to the shared state (and so to all
    // inner senders). The cast is needed; without it the compiler won't pick
    // the second `onStop` overload.
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop);
    foreach(i, _; Senders) {
      ops[i].start();
    }
  }
}

import std.meta : allSatisfy, ApplyRight;

/// Sender returned by `whenAll`; see `whenAll` for the completion semantics.
struct WhenAllSender(Senders...) if (allSatisfy!(ApplyRight!(models, isSender), Senders)) {
  alias Result = WhenAllResult!(SenderValues!Senders);
  static if (hasMember!(Result, "values"))
    alias Value = typeof(Result.values);
  else
    alias Value = void;
  Senders senders;
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = WhenAllOp!(Receiver, Senders)(receiver, senders);
    return op;
  }
}

/// State shared by all inner receivers of one `WhenAllOp`. It is itself the
/// `StopSource` behind the stop token handed to every inner sender.
private class WhenAllState(Value) : StopSource {
  import concurrency.bitfield;
  StopCallback cb;                       // registration on the downstream stop token; disposed in `process`
  static if (is(typeof(Value.values)))
    Value value;                         // collected values (absent when every sender is void)
  Exception exception;                   // set when the first done/error completion was an error
  shared SharedBitField!Flags bitfield;  // `Flags` bits plus the completion counter (`Counter`)
}

/// Receiver connected to each individual sender passed to `whenAll`.
///
/// Every completion sets a flag and adds `Counter.tick` to the shared state
/// word; the completion that brings the count to `senderCount` calls `process`
/// to complete the downstream `receiver`.
private struct WhenAllReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  Receiver receiver;
  WhenAllState!(Value) state;
  size_t senderIndex;
  size_t senderCount;

  /// Inner senders observe the shared `WhenAllState` stop source.
  auto getStopToken() {
    return StopToken(state);
  }
  private bool isValueProduced(size_t state) {
    return (state & Flags.value_produced) > 0;
  }
  private bool isDoneOrErrorProduced(size_t state) {
    return (state & Flags.doneOrError_produced) > 0;
  }
  /// True when the counter in `state` has reached the number of senders.
  private bool isLast(size_t state) {
    return (state >> 3) == atomicLoad(senderCount);
  }

  static if (!is(InnerValue == void))
    /// Stores the value in its slot while the bitfield lock is held, then
    /// completes downstream if this was the last sender.
    void setValue(InnerValue value) @safe {
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        state.value.setValue(value, senderIndex);
        release();
        if (last)
          process(newState);
      }
    }
  else
    /// Nothing to store; only counts the completion.
    void setValue() @safe {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        if (last)
          process(newState);
      }
    }

  /// Counts the completion and, if it is the first done/error, requests stop
  /// on the remaining senders.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState))
        state.stop();
      if (last)
        process(newState);
    }
  }

  /// Counts the completion. If it is the first done/error, the exception is
  /// kept and stop is requested on the remaining senders; later errors are dropped.
  void setError(Exception exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState)) {
        state.exception = exception;
        release(); // must release before calling .stop
        state.stop();
      } else
        release();
      if (last)
        process(newState);
    }
  }

  /// Completes the downstream receiver after the last sender has completed.
  /// Precedence: downstream stop requested -> setDone; else any done/error ->
  /// setError with the stored exception, or setDone; else the collected values.
  private void process(size_t newState) {
    state.cb.dispose();

    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isDoneOrErrorProduced(newState)) {
      if (state.exception)
        receiver.setError(state.exception);
      else
        receiver.setDone();
    } else {
      import concurrency.receiver : setValueOrError;
      static if (is(typeof(Value.values)))
        receiver.setValueOrError(state.value.values);
      else
        receiver.setValueOrError();
    }
  }
  mixin ForwardExtensionPoints!receiver;
}