module concurrency.operations.race;

import concurrency;
import concurrency.receiver;
import concurrency.sender;
import concurrency.stoptoken;
import concurrency.utils : spin_yield, casWeak;
import concepts;
import std.traits;

/// Runs all Senders concurrently and completes with the value of whichever
/// contender produces a value first; as soon as a value is produced the other
/// contenders are asked to stop. The race itself completes only after every
/// contender has completed.
///
/// If no contender produces a value, the exception of the first error is
/// propagated, unless a done completion was recorded before any error, in
/// which case the result is done. The result is also done when stop has been
/// requested on the outer receiver's stop token.
///
/// The Value type is computed by `Result`; mir.algebraic types are used when
/// the Senders' value types differ or mix void and non-void.
RaceSender!(Senders) race(Senders...)(Senders senders) {
  // noDropouts = false: a done/error completion does not stop the others.
  return RaceSender!(Senders)(senders, false);
}

/// Computes the value type of a race over `Senders`:
/// - `void` if no Sender has a non-void Value;
/// - `T` if every Sender's Value is `T`;
/// - `Nullable!T` if all non-void Values are `T` and at least one is void;
/// - `Algebraic!(Ts)` over the distinct non-void Values otherwise
///   (void contenders are not represented in that case).
// NOTE(review): in the Algebraic case a void contender that wins leaves
// State.value default-initialised, since the void setValue stores nothing.
private template Result(Senders...) {
  import concurrency.utils : NoVoid;
  import mir.algebraic : Algebraic, Nullable;
  import std.meta : staticMap, Filter, NoDuplicates;
  alias getValue(Sender) = Sender.Value;
  alias SenderValues = staticMap!(getValue, Senders);
  alias NoVoidValueTypes = Filter!(NoVoid, SenderValues);
  enum HasVoid = SenderValues.length != NoVoidValueTypes.length;
  alias ValueTypes = NoDuplicates!(NoVoidValueTypes);

  static if (ValueTypes.length == 0)
    alias Result = void;
  else static if (ValueTypes.length == 1)
    static if (HasVoid)
      alias Result = Nullable!(ValueTypes[0]);
    else
      alias Result = ValueTypes[0];
  else {
    alias Result = Algebraic!(ValueTypes);
  }
}

/// Operation state produced by `RaceSender.connect`.
/// Every contender is connected to a `RaceReceiver`; all of them share one
/// `State` object, which aggregates their completions. The last completion
/// forwards the final result to `receiver`.
private struct RaceOp(Receiver, Senders...) {
  import std.meta : staticMap;
  alias R = Result!(Senders);
  alias ElementReceiver(Sender) = RaceReceiver!(Receiver, Sender.Value, R);
  alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
  alias Ops = staticMap!(ConnectResult, Senders);
  Receiver receiver;
  State!R state;
  Ops ops;
  // Copying is disabled (both postblit and copy constructor).
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);

  /// Creates the shared state and connects every sender. Each RaceReceiver
  /// gets a copy of `receiver`, the shared state and the contender count.
  this(Receiver receiver, return Senders senders, bool noDropouts) @trusted scope {
    this.receiver = receiver;
    state = new State!(R)(noDropouts);
    foreach (i, Sender; Senders) {
      ops[i] = senders[i].connect(ElementReceiver!(Sender)(receiver, state, Senders.length));
    }
  }

  /// Completes with done immediately if stop was already requested on the
  /// outer receiver. Otherwise forwards outer stop requests to the shared
  /// state and starts every contender.
  void start() @trusted nothrow scope {
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    // The explicit cast is required: without it the call does not resolve to
    // the intended (shared delegate) overload of onStop.
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop);
    foreach (i, _; Senders) {
      ops[i].start();
    }
  }
}

import std.meta : allSatisfy, ApplyRight;

/// Sender returned by `race`; connecting it yields a `RaceOp`.
struct RaceSender(Senders...) if (allSatisfy!(ApplyRight!(models, isSender), Senders)) {
  static assert(models!(typeof(this), isSender));
  alias Value = Result!(Senders);
  Senders senders;
  /// If true, the first contender to complete with done or error triggers
  /// stop for all the others. Otherwise they keep running until one produces
  /// a value. Either way the race completes only after every contender has
  /// completed.
  bool noDropouts;

  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // Construct into a named local and return it, to ensure NRVO.
    auto op = RaceOp!(Receiver, Senders)(receiver, senders, noDropouts);
    return op;
  }
}

/// State shared by all contenders of one race. It is also the StopSource
/// whose tokens the contenders receive (see RaceReceiver.getStopToken).
private class State(Value) : StopSource {
  import concurrency.bitfield;
  StopCallback cb;                      // registration of stop() on the outer receiver's stop token
  static if (!is(Value == void))
    Value value;                        // value of the first contender to produce one
  Exception exception;                  // set only if the first done/error completion was an error
  shared SharedBitField!Flags bitfield; // Flags bits plus a completion counter (see Counter)
  bool noDropouts;                      // see RaceSender.noDropouts
  this(bool noDropouts) {
    this.noDropouts = noDropouts;
  }
}

/// Flag bits stored in State.bitfield.
private enum Flags : size_t {
  locked = 0x1,              // presumably the lock bit taken by bitfield.lock(...) — see concurrency.bitfield
  value_produced = 0x2,      // some contender completed via setValue
  doneOrError_produced = 0x4 // some contender completed via setDone or setError
}

/// Completion counter stored above the three flag bits; each completion adds
/// `tick`, and `RaceReceiver.isLast` reads it back with `>> 3`.
private enum Counter : size_t {
  tick = 0x8,
  mask = ~0x7
}

/// Receiver handed to each contender. All instances share one State. The
/// instance that delivers the last of `senderCount` completions forwards the
/// aggregated result to the outer receiver.
private struct RaceReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  Receiver receiver;
  State!(Value) state;
  size_t senderCount;

  /// Contenders observe the shared State, so a stop triggered on it (by a
  /// winner, a noDropouts failure, or the outer stop token) reaches all of them.
  auto getStopToken() {
    return StopToken(state);
  }
  private bool isValueProduced(size_t state) {
    return (state & Flags.value_produced) > 0;
  }
  private bool isDoneOrErrorProduced(size_t state) {
    return (state & Flags.doneOrError_produced) > 0;
  }
  // True when the completion counter (bits above the flags) equals the
  // number of contenders, i.e. this was the final completion.
  private bool isLast(size_t state) {
    return (state >> 3) == senderCount;
  }

  static if (!is(InnerValue == void))
    // Stores the first value while holding the bitfield lock, then stops the
    // other contenders. Later values are ignored.
    void setValue(InnerValue value) @safe nothrow {
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        if (!isValueProduced(oldState)) {
          static if (is(InnerValue == Value))
            state.value = value;
          else
            state.value = Value(value);
          release(); // must release before calling .stop
          state.stop();
        } else
          release();

        process(newState);
      }
    }
  else
    // Void contender: nothing to store, so no lock is needed. The first value
    // stops the other contenders.
    void setValue() @safe nothrow {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        if (!isValueProduced(oldState)) {
          state.stop();
        }
        process(newState);
      }
    }

  // Counts the completion. With noDropouts, the first done/error completion
  // stops the other contenders.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      if (state.noDropouts && !isDoneOrErrorProduced(oldState)) {
        state.stop();
      }
      process(newState);
    }
  }

  // Records the exception only if no done/error completion preceded it. With
  // noDropouts, that first failure also stops the other contenders.
  void setError(Exception exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      const bool isFirstFailure = !isDoneOrErrorProduced(oldState);
      if (isFirstFailure)
        state.exception = exception; // written while the lock is held
      // The lock must be released exactly once on every path, and before
      // calling .stop.
      release();
      if (isFirstFailure && state.noDropouts)
        state.stop();
      process(newState);
    }
  }

  // Runs on every completion, but only acts on the last one. It disposes the
  // outer stop callback and then completes the outer receiver, in this order
  // of precedence: done if outer stop was requested, else the value if one
  // was produced, else the recorded exception, else done.
  private void process(size_t newState) {
    import concurrency.receiver : setValueOrError;

    if (!isLast(newState))
      return;

    state.cb.dispose();
    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isValueProduced(newState)) {
      static if (is(Value == void))
        receiver.setValueOrError();
      else
        receiver.setValueOrError(state.value);
    } else if (state.exception)
      receiver.setError(state.exception);
    else
      receiver.setDone();
  }
  mixin ForwardExtensionPoints!receiver;
}