module concurrency.operations.stopwhen;

import concurrency;
import concurrency.receiver;
import concurrency.sender;
import concurrency.stoptoken;
import concurrency.utils : spin_yield, casWeak;
import concepts;
import std.traits;

/// stopWhen cancels the source when the trigger completes normally.
/// If either the source or the trigger completes with cancellation or with an
/// error, the first one is propagated after both have completed.
StopWhenSender!(Sender, Trigger) stopWhen(Sender, Trigger)(Sender source, Trigger trigger) {
  return StopWhenSender!(Sender, Trigger)(source, trigger);
}

/// Operation state of a stopWhen: holds the downstream receiver, the State
/// shared by both inner receivers, and the connected source/trigger operations.
private struct StopWhenOp(Receiver, Sender, Trigger) {
  alias SenderOp = OpType!(Sender, SourceReceiver!(Receiver, Sender.Value));
  alias TriggerOp = OpType!(Trigger, TriggerReceiver!(Receiver, Sender.Value));
  Receiver receiver;
  State!(Sender.Value) state;
  SenderOp sourceOp;
  TriggerOp triggerOp;
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);
  this(Receiver receiver, return Sender source, return Trigger trigger) @trusted scope {
    this.receiver = receiver;
    state = new State!(Sender.Value)();
    // Both inner receivers share the same State; it is also the StopSource
    // behind the StopToken they hand to the source and the trigger.
    sourceOp = source.connect(SourceReceiver!(Receiver, Sender.Value)(receiver, state));
    triggerOp = trigger.connect(TriggerReceiver!(Receiver, Sender.Value)(receiver, state));
  }
  /// Completes with done immediately if the downstream already requested stop;
  /// otherwise forwards downstream stop requests to `state` and starts the
  /// source followed by the trigger.
  void start() @trusted nothrow scope {
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    // The explicit cast is required: without it onStop does not resolve to the
    // intended overload.
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop);
    sourceOp.start;
    triggerOp.start;
  }
}

/// Sender returned by stopWhen; produces the source's Value.
struct StopWhenSender(Sender, Trigger) if (models!(Sender, isSender) && models!(Trigger, isSender)) {
  static assert(models!(typeof(this), isSender));
  alias Value = Sender.Value;
  Sender sender;
  Trigger trigger;
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = StopWhenOp!(Receiver, Sender, Trigger)(receiver, sender, trigger);
    return op;
  }
}

/// Shared state between the source and trigger receivers. Being a StopSource,
/// stopping it cancels whichever of the two is still running.
private class State(Value) : StopSource {
  import concurrency.bitfield;
  /// Registration of the downstream stop callback; disposed in `process`.
  StopCallback cb;
  /// Completion flags; see `Flags`.
  shared SharedBitField!Flags bitfield;
  static if (!is(Value == void))
    Value value;
  /// First error reported by either side (written under the bitfield lock).
  Throwable exception;
}

private enum Flags : size_t {
  locked = 0x1,
  value_produced = 0x2,
  doneOrError_produced = 0x4,
  /// OR'ed in exactly once by each side when it completes. A side that finds
  /// this bit already set in the pre-update state is the last to complete.
  tick = 0x8
}

/// Delivers the final result downstream. Only the side that completes last
/// calls this, so it runs once per operation.
/// `snapshot` is the flag word used to decide between value, error and done.
private void process(State, Receiver)(State state, Receiver receiver, size_t snapshot) {
  import concurrency.receiver : setValueOrError;

  state.cb.dispose();
  if (receiver.getStopToken().isStopRequested)
    receiver.setDone();
  else if (isValueProduced(snapshot)) {
    static if (__traits(compiles, state.value))
      receiver.setValueOrError(state.value);
    else
      receiver.setValueOrError();
  } else if (state.exception)
    receiver.setError(state.exception);
  else
    receiver.setDone();
}

private bool isValueProduced(size_t state) @safe nothrow pure {
  return (state & Flags.value_produced) > 0;
}
private bool isDoneOrErrorProduced(size_t state) @safe nothrow pure {
  return (state & Flags.doneOrError_produced) > 0;
}
/// Must be applied to the state observed *before* this side's own tick:
/// true means the other side has already completed.
private bool isLast(size_t state) @safe nothrow pure {
  return (state & Flags.tick) > 0;
}

/// Receiver connected to the trigger. Any completion of the trigger stops the
/// source unless the source has already completed, in which case the final
/// result is delivered.
private struct TriggerReceiver(Receiver, Value) {
  Receiver receiver;
  State!(Value) state;
  auto getStopToken() {
    return StopToken(state);
  }
  void setValue() @safe nothrow {
    with (state.bitfield.update(Flags.tick)) {
      if (!isLast(oldState))
        state.stop();
      else
        state.process(receiver, newState);
    }
  }
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced | Flags.tick)) {
      if (!isLast(oldState))
        state.stop();
      else
        state.process(receiver, newState);
    }
  }
  void setError(Throwable exception) @safe nothrow {
    // Lock so `exception` is written before the other side can observe the flag.
    with (state.bitfield.lock(Flags.doneOrError_produced | Flags.tick)) {
      if (!isDoneOrErrorProduced(oldState))
        state.exception = exception;
      release(); // release before stop/process
      if (!isLast(oldState))
        state.stop();
      else
        state.process(receiver, newState);
    }
  }
  mixin ForwardExtensionPoints!receiver;
}

/// Receiver connected to the source. A completion of the source stops the
/// trigger unless the trigger has already completed, in which case the final
/// result is delivered.
private struct SourceReceiver(Receiver, Value) {
  Receiver receiver;
  State!(Value) state;
  auto getStopToken() {
    return StopToken(state);
  }
  static if (!is(Value == void))
    void setValue(Value value) @safe nothrow {
      // Lock so `value` is written before the other side can observe value_produced.
      with (state.bitfield.lock(Flags.value_produced | Flags.tick)) {
        state.value = value;
        release(); // release before stop/process
        if (!isLast(oldState))
          state.stop();
        else if (isDoneOrErrorProduced(oldState))
          // A done/error came first: process the snapshot without value_produced.
          state.process(receiver, oldState);
        else
          state.process(receiver, newState);
      }
    }
  else
    void setValue() @safe nothrow {
      with (state.bitfield.update(Flags.value_produced | Flags.tick)) {
        if (!isLast(oldState))
          state.stop();
        else if (isDoneOrErrorProduced(oldState))
          // A done/error came first: process the snapshot without value_produced.
          state.process(receiver, oldState);
        else
          state.process(receiver, newState);
      }
    }
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced | Flags.tick)) {
      if (!isLast(oldState))
        state.stop();
      else
        state.process(receiver, newState);
    }
  }
  void setError(Throwable exception) @safe nothrow {
    // Lock so `exception` is written before the other side can observe the flag.
    with (state.bitfield.lock(Flags.doneOrError_produced | Flags.tick)) {
      if (!isDoneOrErrorProduced(oldState))
        state.exception = exception;
      release(); // release before stop/process
      if (!isLast(oldState))
        state.stop();
      else
        state.process(receiver, newState);
    }
  }
  mixin ForwardExtensionPoints!receiver;
}