module concurrency.operations.whenall;

import concurrency;
import concurrency.receiver;
import concurrency.sender;
import concurrency.stoptoken;
import concepts;
import std.traits;
import concurrency.utils : spin_yield, casWeak;

/// Combines two or more senders into a single sender that completes once
/// every one of them has completed.
///
/// The combined `Value` is:
/// - `void` when every sender's `Value` is `void`;
/// - the single non-void `Value` when exactly one sender produces a value;
/// - a `std.typecons.Tuple` of the non-void values, in sender order, otherwise.
///
/// Outcome delivered to the outer receiver (decided by the last child to complete):
/// - `setDone` if the outer receiver's stop token has been stopped;
/// - otherwise, if any child completed with done or error: the error when the
///   first such completion was a `setError`, else `setDone`;
/// - otherwise the combined values via `setValueOrError`.
/// The first child completing with done or error requests stop on the shared
/// stop source that the other children receive their stop tokens from.
WhenAllSender!(Senders) whenAll(Senders...)(Senders senders) if (Senders.length > 1) {
  return WhenAllSender!(Senders)(senders);
}

// Low bits of the shared state word. `locked` is presumably used by
// SharedBitField's lock/release (see concurrency.bitfield) — confirm there.
// The other two flags record which kinds of completion have been seen.
private enum Flags : size_t {
  locked = 0x1,
  value_produced = 0x2,
  doneOrError_produced = 0x4
}

// Every child completion adds one `tick`, so the bits above the flags count
// how many children have completed (see WhenAllReceiver.isLast).
private enum Counter : size_t {
  tick = 0x8
}

// Computes the storage type for the children's values: a struct with a
// `values` field plus `setValue`, or an empty struct when no child produces
// a value.
private template WhenAllResult(Senders...) {
  import std.meta;
  import std.typecons;
  import mir.algebraic : Algebraic, Nullable;
  import concurrency.utils : NoVoid;

  // Running sums of the arguments, offset by `count`.
  // Example: Cumulative!(0, 1, 0, 1) == AliasSeq!(1, 1, 2).
  template Cumulative(size_t count, Ts...) {
    static if (Ts.length > 0) {
      enum head = count + Ts[0];
      static if (Ts.length == 1)
        alias Cumulative = AliasSeq!(head);
      else static if (Ts.length > 1)
        alias Cumulative = AliasSeq!(head, Cumulative!(head, Ts[1..$]));
    } else {
      alias Cumulative = AliasSeq!();
    }
  }

  alias SenderValue(T) = T.Value;
  alias SenderValues = staticMap!(SenderValue, Senders);
  // Only senders that actually produce a value get a slot.
  alias ValueTypes = Filter!(NoVoid, SenderValues);
  static if (ValueTypes.length > 1)
    alias Values = Tuple!(ValueTypes);
  else static if (ValueTypes.length == 1)
    alias Values = ValueTypes[0];
  // For sender i: number of non-void senders among senders 0..i (inclusive).
  // For a non-void sender this is its 1-based slot in `Values`.
  alias Indexes = Cumulative!(0, staticMap!(NoVoid, SenderValues));

  static if (ValueTypes.length > 0) {
    struct WhenAllResult {
      Values values;

      // Stores `t` in the slot belonging to sender number `index`.
      // Only reached for senders with a non-void Value; the void `setValue`
      // of WhenAllReceiver never calls it.
      void setValue(T)(T t, size_t index) {
        switch (index) {
          // Unrolled at compile time: one `case` per sender.
          foreach (idx, I; Indexes) {
          case idx:
            static if (ValueTypes.length == 1)
              values = t;
            // `I-1` is the 0-based tuple slot. For a void sender, or when the
            // slot's type is not T, the guard is false and nothing is stored.
            else static if (is(typeof(values[I-1]) == T))
              values[I-1] = t;
            return;
          }
        default: assert(false, "out of bounds");
        }
      }
    }
  } else {
    struct WhenAllResult {
    }
  }
}

// Operation state produced by WhenAllSender.connect: owns the outer receiver,
// the shared WhenAllState and the connected operation of every child sender.
// Not copyable.
private struct WhenAllOp(Receiver, Senders...) {
  import std.meta : staticMap;
  alias R = WhenAllResult!(Senders);
  alias ElementReceiver(Sender) = WhenAllReceiver!(Receiver, Sender.Value, R);
  alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
  alias Ops = staticMap!(ConnectResult, Senders);
  Receiver receiver;
  WhenAllState!R state;
  Ops ops;
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);

  // Allocates the shared state and connects each child sender to its own
  // WhenAllReceiver, which knows its position and the total child count.
  this(Receiver receiver, Senders senders) {
    this.receiver = receiver;
    state = new WhenAllState!R();
    foreach (i, Sender; Senders) {
      // Use the same receiver type that `Ops` was computed from.
      ops[i] = senders[i].connect(ElementReceiver!Sender(receiver, state, i, Senders.length));
    }
  }

  // Completes with done immediately if stop was already requested;
  // otherwise links outer cancellation to the shared stop source and starts
  // every child operation in order.
  void start() @trusted nothrow scope {
    import concurrency.stoptoken : StopSource;
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    // Forward a stop request on the outer receiver to the children.
    // The explicit cast is needed to make overload resolution select the
    // intended `onStop` overload for `state.stop`.
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop);
    foreach (i, _; Senders) {
      ops[i].start();
    }
  }
}

import std.meta : allSatisfy, ApplyRight;

/// Sender returned by `whenAll`; every element of `Senders` must model
/// `isSender`. `Value` is the combined value type described on `whenAll`.
struct WhenAllSender(Senders...) if (allSatisfy!(ApplyRight!(models, isSender), Senders)) {
  alias Result = WhenAllResult!(Senders);
  static if (hasMember!(Result, "values"))
    alias Value = typeof(Result.values);
  else
    alias Value = void;
  Senders senders;

  /// Connects every child sender and returns the (non-copyable) operation state.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = WhenAllOp!(Receiver, Senders)(receiver, senders);
    return op;
  }
}

// State shared by all children of one whenAll operation. It is itself the
// StopSource that the children's stop tokens come from (see
// WhenAllReceiver.getStopToken), so `stop()` requests stop for all of them.
private class WhenAllState(Value) : StopSource {
  import concurrency.bitfield;
  // Registration on the outer receiver's stop token; disposed by the last child.
  StopCallback cb;
  // Children's values; only present when at least one sender has a value.
  static if (is(typeof(Value.values)))
    Value value;
  // Error to report; set only if it was the first done/error completion.
  Exception exception;
  // Flags plus completion counter (see Flags / Counter).
  shared SharedBitField!Flags bitfield;
}

// Receiver handed to child sender number `senderIndex` (of `senderCount`).
// Records the child's completion in the shared state; the child that
// completes last forwards the combined result to the outer receiver.
private struct WhenAllReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  Receiver receiver;
  WhenAllState!(Value) state;
  size_t senderIndex;
  size_t senderCount;

  // Children observe the shared stop source, not the outer receiver's token.
  auto getStopToken() {
    return StopToken(state);
  }
  private bool isValueProduced(size_t state) {
    return (state & Flags.value_produced) > 0;
  }
  private bool isDoneOrErrorProduced(size_t state) {
    return (state & Flags.doneOrError_produced) > 0;
  }
  // True when the completion counter has reached the number of children.
  private bool isLast(size_t state) {
    // Derived from Counter.tick rather than a hard-coded shift so the count
    // stays correct if the flag bits change.
    return (state / Counter.tick) == senderCount;
  }

  static if (!is(InnerValue == void))
    // Stores the value while holding the bitfield lock, releases it, then
    // checks whether this was the last child.
    void setValue(InnerValue value) @safe {
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        state.value.setValue(value, senderIndex);
        release();
        process(newState);
      }
    }
  else
    // Nothing to store; only record the completion.
    void setValue() @safe {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        process(newState);
      }
    }

  // The first done/error completion requests stop on the remaining children.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      if (!isDoneOrErrorProduced(oldState))
        state.stop();
      process(newState);
    }
  }

  // Keeps the exception only if no done/error completion happened before;
  // in that case it also requests stop on the remaining children.
  void setError(Exception exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      if (!isDoneOrErrorProduced(oldState)) {
        state.exception = exception;
        // must release before calling .stop (presumably because stop() can
        // synchronously drive other children into completion, which need
        // the bitfield — confirm against concurrency.bitfield)
        release();
        state.stop();
      } else
        release();
      process(newState);
    }
  }

  // Runs after each completion; only the last child completes the outer receiver.
  private void process(size_t newState) {
    if (!isLast(newState))
      return;

    state.cb.dispose();

    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isDoneOrErrorProduced(newState)) {
      if (state.exception)
        receiver.setError(state.exception);
      else
        receiver.setDone();
    } else {
      import concurrency.receiver : setValueOrError;
      static if (is(typeof(Value.values)))
        receiver.setValueOrError(state.value.values);
      else
        receiver.setValueOrError();
    }
  }

  // Presumably forwards the remaining receiver extension points to `receiver`.
  mixin ForwardExtensionPoints!receiver;
}