module concurrency.operations.race;

import concurrency;
import concurrency.receiver;
import concurrency.sender;
import concurrency.stoptoken;
import concurrency.utils : spin_yield, casWeak;
import concepts;
import std.traits;

/// Races several Senders (or a single array of Senders) against each other.
///
/// The value of whichever Sender first completes with a value is propagated;
/// as soon as a value is produced, stop is requested on the remaining Senders.
/// The outer receiver is only completed once every Sender has completed.
///
/// If no Sender produces a value, the outcome is decided by the first Sender
/// that completed with done or error: if that was an error, its exception is
/// propagated, otherwise the race completes with done. If the outer stop token
/// is triggered, the race completes with done. An empty array of Senders
/// completes with done immediately.
///
/// The value type is mir.algebraic's Algebraic if the Sender value types
/// differ, or its Nullable when exactly one non-void value type is mixed with
/// void Senders.
RaceSender!(Senders) race(Senders...)(Senders senders) {
  return RaceSender!(Senders)(senders, false);
}

/// Value type of a race over several distinct Senders:
/// void if every Sender is void, the single value type if there is only one
/// (wrapped in Nullable when some Senders are void), otherwise an Algebraic
/// over all distinct non-void value types.
private template Result(Senders...) if (Senders.length > 1) {
  import concurrency.utils : NoVoid;
  import mir.algebraic : Algebraic, Nullable;
  import std.meta : staticMap, Filter, NoDuplicates;
  alias getValue(Sender) = Sender.Value;
  alias SenderValues = staticMap!(getValue, Senders);
  alias NoVoidValueTypes = Filter!(NoVoid, SenderValues);
  enum HasVoid = SenderValues.length != NoVoidValueTypes.length;
  alias ValueTypes = NoDuplicates!(NoVoidValueTypes);

  // static if blocks do not introduce a scope, so the braces below only make
  // the if/else pairing explicit.
  static if (ValueTypes.length == 0) {
    alias Result = void;
  } else static if (ValueTypes.length == 1) {
    static if (HasVoid)
      alias Result = Nullable!(ValueTypes[0]);
    else
      alias Result = ValueTypes[0];
  } else {
    alias Result = Algebraic!(ValueTypes);
  }
}

/// Element type P of a dynamic array type P[].
alias ArrayElement(T : P[], P) = P;

import std.traits;
/// Value type of a race over an array of Senders: the element Sender's Value.
private template Result(Senders...) if (Senders.length == 1) {
  alias Result = ArrayElement!(Senders).Value;
}

/// Operation state of a race: owns the connected inner operations and the
/// shared State they report into.
private struct RaceOp(Receiver, Senders...) {
  import std.meta : staticMap;
  alias R = Result!(Senders);
  static if (Senders.length > 1) {
    // One operation per (possibly differently typed) Sender.
    alias ElementReceiver(Sender) = RaceReceiver!(Receiver, Sender.Value, R);
    alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
    alias Ops = staticMap!(ConnectResult, Senders);
  } else {
    // Single array of Senders: a dynamic array of identical operations.
    alias ElementReceiver = RaceReceiver!(Receiver, R, R);
    alias Ops = OpType!(ArrayElement!(Senders[0]), ElementReceiver)[];
  }
  Receiver receiver;
  State!R state;
  Ops ops;
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);
  /// Connects every contender to a RaceReceiver sharing one State;
  /// each receiver is told the total number of contenders.
  this(Receiver receiver, return Senders senders, bool noDropouts) @trusted scope {
    this.receiver = receiver;
    state = new State!(R)(noDropouts);
    static if (Senders.length > 1) {
      foreach(i, Sender; Senders) {
        ops[i] = senders[i].connect(ElementReceiver!(Sender)(receiver, state, Senders.length));
      }
    } else {
      ops.length = senders[0].length;
      foreach(i; 0..senders[0].length) {
        ops[i] = senders[0][i].connect(ElementReceiver(receiver, state, senders[0].length));
      }
    }
  }
  /// Starts all contenders, after forwarding outer stop requests to the
  /// shared State. Completes with done right away if stop was already
  /// requested or if there are no contenders at all.
  void start() @trusted nothrow scope {
    import concurrency.stoptoken : StopSource;
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    static if (Senders.length == 1) {
      // With an empty array no inner operation would ever complete, so
      // `process` would never run and the outer receiver would hang forever.
      // No value and no error was produced, so complete with done - the same
      // outcome `process` picks in that situation.
      if (ops.length == 0) {
        receiver.setDone();
        return;
      }
    }
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop); // ugly cast, but without it onStop won't take the second overload
    static if (Senders.length > 1) {
      foreach(i, _; Senders) {
        ops[i].start();
      }
    } else {
      foreach(i; 0..ops.length) {
        ops[i].start();
      }
    }
  }
}

import std.meta : allSatisfy, ApplyRight;

/// Sender returned by `race`: either several Senders, or a single array of Senders.
struct RaceSender(Senders...)
  if ((Senders.length > 1 && allSatisfy!(ApplyRight!(models, isSender), Senders)) ||
      (models!(ArrayElement!(Senders[0]), isSender))) {
  static assert(models!(typeof(this), isSender));
  alias Value = Result!(Senders);
  Senders senders;
  bool noDropouts; // if true, the first done/error of any contender also requests stop on all others; otherwise only a produced value does
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = RaceOp!(Receiver, Senders)(receiver, senders, noDropouts);
    return op;
  }
}

/// Shared state of one race. It is also the StopSource whose token the
/// contenders observe (see RaceReceiver.getStopToken).
private class State(Value) : StopSource {
  import concurrency.bitfield;
  StopCallback cb;                    // registration on the outer stop token, disposed in `process`
  shared SharedBitField!Flags bitfield; // Flags in bits 0-2, count of completed contenders from bit 3 up
  static if (!is(Value == void))
    Value value;                      // value of the first contender to produce one
  Throwable exception;                // set only if the first done/error completion was an error
  bool noDropouts;
  this(bool noDropouts) {
    this.noDropouts = noDropouts;
  }
}

/// Flag bits stored in State.bitfield.
private enum Flags : size_t {
  locked = 0x1,
  value_produced = 0x2,
  doneOrError_produced = 0x4
}

/// Completion counter stored above the Flags bits in State.bitfield.
private enum Counter : size_t {
  tick = 0x8,
  mask = ~0x7
}

/// Receiver handed to each contender. It records the contender's completion
/// in the shared State. The contender that completes last (counter reaches
/// senderCount) completes the outer receiver via `process`.
private struct RaceReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  Receiver receiver;
  State!(Value) state;
  size_t senderCount;
  /// Contenders observe the race's own StopSource, not the outer token directly.
  auto getStopToken() {
    return StopToken(state);
  }
  private bool isValueProduced(size_t state) {
    return (state & Flags.value_produced) > 0;
  }
  private bool isDoneOrErrorProduced(size_t state) {
    return (state & Flags.doneOrError_produced) > 0;
  }
  /// True when the completion counter (bits 3 and up) equals the number of contenders.
  private bool isLast(size_t state) {
    return (state >> 3) == atomicLoad(senderCount);
  }
  static if (!is(InnerValue == void))
    /// Stores the value if it is the first one, then requests stop on the others.
    /// NOTE(review): lock() presumably sets the flag and bumps the counter while
    /// holding Flags.locked, so the write to state.value is exclusive - confirm
    /// in concurrency.bitfield.
    void setValue(InnerValue value) @safe nothrow {
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        if (!isValueProduced(oldState)) {
          static if (is(InnerValue == Value))
            state.value = value;
          else
            state.value = Value(value);
          release(); // must release before calling .stop
          state.stop();
        } else
          release();

        if (last)
          process(newState);
      }
    }
  else
    /// Void value: nothing to store, so no lock is taken; the first value
    /// requests stop on the others.
    void setValue() @safe nothrow {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        if (!isValueProduced(oldState)) {
          state.stop();
        }
        if (last)
          process(newState);
      }
    }
  /// Records a done completion. With noDropouts, the first done/error
  /// requests stop on the others.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (state.noDropouts && !isDoneOrErrorProduced(oldState)) {
        state.stop();
      }
      if (last)
        process(newState);
    }
  }
  /// Records the exception only if this is the first done/error completion;
  /// if a done already arrived, this exception is not stored.
  void setError(Throwable exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState)) {
        state.exception = exception;
        if (state.noDropouts) {
          release(); // release before stop
          state.stop();
        }
      }
      // NOTE(review): with noDropouts this is a second release(); this relies on
      // the guard's release being idempotent (setValue also releases explicitly
      // inside the same `with`) - confirm in concurrency.bitfield.
      release();
      if (last)
        process(newState);
    }
  }
  /// Completes the outer receiver once all contenders are done. Priority:
  /// outer stop requested -> done; any value -> value; recorded exception ->
  /// error; otherwise done.
  private void process(size_t newState) {
    import concurrency.receiver : setValueOrError;

    state.cb.dispose();
    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isValueProduced(newState)) {
      static if (is(Value == void))
        receiver.setValueOrError();
      else
        receiver.setValueOrError(state.value);
    } else if (state.exception)
      receiver.setError(state.exception);
    else
      receiver.setDone();
  }
  mixin ForwardExtensionPoints!receiver;
}