module concurrency.operations.race;

import concurrency;
import concurrency.receiver;
import concurrency.sender;
import concurrency.stoptoken;
import concurrency.utils : spin_yield, casWeak;
import concepts;
import std.traits;

/// Runs all Senders concurrently and races them for a value.
///
/// The value of the first Sender to call `setValue` is the one propagated;
/// as soon as it arrives, the remaining Senders are asked to stop. The
/// downstream Receiver is only completed once every Sender has completed.
///
/// If no Sender produces a value, the first non-value completion decides the
/// outcome: a leading error is propagated via `setError`, a leading done via
/// `setDone` (an error arriving after a done is not recorded). If the
/// downstream stop token has been triggered, the result is always `setDone`.
///
/// The value type is:
/// - `void` when every Sender is void;
/// - the single shared value type when all non-void Senders agree, wrapped in
///   `mir.algebraic.Nullable` when at least one Sender is void;
/// - otherwise a `mir.algebraic.Algebraic` of the distinct value types.
RaceSender!(Senders) race(Senders...)(Senders senders) {
  return RaceSender!(Senders)(senders, false);
}

/// Computes the value type of a race over `Senders` (see `race`).
private template Result(Senders...) {
  import concurrency.utils : NoVoid;
  import mir.algebraic : Algebraic, Nullable;
  import std.meta : staticMap, Filter, NoDuplicates;
  alias getValue(Sender) = Sender.Value;
  alias SenderValues = staticMap!(getValue, Senders);
  alias NoVoidValueTypes = Filter!(NoVoid, SenderValues);
  enum HasVoid = SenderValues.length != NoVoidValueTypes.length;
  alias ValueTypes = NoDuplicates!(NoVoidValueTypes);

  static if (ValueTypes.length == 0)
    alias Result = void;
  else static if (ValueTypes.length == 1)
    static if (HasVoid)
      alias Result = Nullable!(ValueTypes[0]);
    else
      alias Result = ValueTypes[0];
  else {
    // NOTE(review): HasVoid is not considered here, so if a void Sender wins
    // the race the propagated value is a default-initialized Algebraic.
    // Confirm that this is intended.
    alias Result = Algebraic!(ValueTypes);
  }
}

/// Operation state for `RaceSender`. It holds the downstream Receiver, the
/// connected child operations, and the `State` shared by all of them.
private struct RaceOp(Receiver, Senders...) {
  import std.meta : staticMap;
  alias R = Result!(Senders);
  alias ElementReceiver(Sender) = RaceReceiver!(Receiver, Sender.Value, R);
  alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
  alias Ops = staticMap!(ConnectResult, Senders);
  Receiver receiver;
  State!R state;
  Ops ops;
  // Non-copyable: both postblit and copy construction are disabled.
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);

  /// Connects every Sender to its own RaceReceiver. All of them share one
  /// freshly allocated `State` and know the total number of contenders.
  this(Receiver receiver, return Senders senders, bool noDropouts) @trusted scope {
    this.receiver = receiver;
    state = new State!(R)(noDropouts);
    foreach (i, Sender; Senders) {
      ops[i] = senders[i].connect(ElementReceiver!(Sender)(receiver, state, Senders.length));
    }
  }

  /// Starts every child operation. If the downstream stop token has already
  /// been triggered, completes the Receiver with `setDone` instead and starts
  /// nothing.
  void start() @trusted nothrow scope {
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    // Register `state.stop` as a stop callback on the downstream token before
    // any child starts. The explicit cast is needed because otherwise the
    // compiler won't pick the second `onStop` overload.
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop);
    foreach (i, _; Senders) {
      ops[i].start();
    }
  }
}

import std.meta : allSatisfy, ApplyRight;

/// Sender returned by `race`; see `race` for the completion semantics.
struct RaceSender(Senders...)
  if (allSatisfy!(ApplyRight!(models, isSender), Senders)) {
  static assert(models!(typeof(this), isSender));
  alias Value = Result!(Senders);
  Senders senders;
  // If true, the first done/error stops all other contenders. Otherwise
  // they keep running until one of them produces a value.
  bool noDropouts;
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = RaceOp!(Receiver, Senders)(receiver, senders, noDropouts);
    return op;
  }
}

/// State shared by all RaceReceivers of one race. It is also the StopSource
/// that the contenders observe through `RaceReceiver.getStopToken`.
private class State(Value) : StopSource {
  import concurrency.bitfield;
  StopCallback cb;                      // set in RaceOp.start, disposed in RaceReceiver.process
  shared SharedBitField!Flags bitfield; // Flags bits plus a completion counter (see Counter)
  static if (!is(Value == void))
    Value value;                        // written only by the first setValue
  Exception exception;                  // written only by an error that is the first done/error
  bool noDropouts;
  this(bool noDropouts) {
    this.noDropouts = noDropouts;
  }
}

/// Low bits of `State.bitfield`.
private enum Flags : size_t {
  locked = 0x1,
  value_produced = 0x2,
  doneOrError_produced = 0x4
}

/// The bits above the three Flags count completed contenders (one tick each).
private enum Counter : size_t {
  tick = 0x8,
  mask = ~0x7
}

/// Receiver handed to each contender. It records the contender's completion
/// in the shared `State`. The last contender to complete then completes the
/// downstream Receiver.
private struct RaceReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  Receiver receiver;
  State!(Value) state;
  size_t senderCount;
  /// Contenders observe the race's own StopSource, not the downstream token.
  auto getStopToken() {
    return StopToken(state);
  }
  private bool isValueProduced(size_t bits) {
    return (bits & Flags.value_produced) > 0;
  }
  private bool isDoneOrErrorProduced(size_t bits) {
    return (bits & Flags.doneOrError_produced) > 0;
  }
  /// True when the completion counter (bits above Flags) equals the number
  /// of contenders.
  private bool isLast(size_t bits) {
    return (bits >> 3) == atomicLoad(senderCount);
  }
  static if (!is(InnerValue == void))
    void setValue(InnerValue value) @safe nothrow {
      // Locked, because the winning value is written into the shared State.
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        if (!isValueProduced(oldState)) {
          static if (is(InnerValue == Value))
            state.value = value;
          else
            state.value = Value(value);
          release(); // must release before calling .stop
          state.stop();
        } else
          release();

        if (last)
          process(newState);
      }
    }
  else
    void setValue() @safe nothrow {
      // Nothing is written to State here, so a plain update suffices.
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        if (!isValueProduced(oldState)) {
          state.stop();
        }
        if (last)
          process(newState);
      }
    }
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      // With noDropouts, the first done/error stops the other contenders.
      if (state.noDropouts && !isDoneOrErrorProduced(oldState)) {
        state.stop();
      }
      if (last)
        process(newState);
    }
  }
  void setError(Exception exception) @safe nothrow {
    // Locked, because the exception is written into the shared State.
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState)) {
        // Only an error that is the first done/error is recorded.
        state.exception = exception;
        release(); // must release before calling .stop; released exactly once
        if (state.noDropouts)
          state.stop();
      } else
        release();

      if (last)
        process(newState);
    }
  }
  /// Runs in the last contender to complete. It disposes the downstream stop
  /// callback and completes the downstream Receiver. The order of precedence
  /// is: downstream stop requested → done; a value was produced → value;
  /// an error was recorded → error; otherwise → done.
  private void process(size_t newState) {
    import concurrency.receiver : setValueOrError;

    state.cb.dispose();
    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isValueProduced(newState)) {
      static if (is(Value == void))
        receiver.setValueOrError();
      else
        receiver.setValueOrError(state.value);
    } else if (state.exception)
      receiver.setError(state.exception);
    else
      receiver.setDone();
  }
  mixin ForwardExtensionPoints!receiver;
}