module concurrency.stream.flatmapbase;

import concurrency.stream.stream;
import concurrency.sender : OpType, isSender;
import concurrency.receiver : ForwardExtensionPoints;
import concurrency.stoptoken : StopSource, StopToken;
import std.traits : ReturnType;
import concurrency.utils : isThreadSafeFunction;
import concepts;
import core.sync.semaphore : Semaphore;

/// Policy applied when a new stream element arrives while the inner sender
/// started for the previous element is still running.
enum OnOverlap {
  /// Block the new element until the running inner sender has completed.
  wait,
  /// Request the running inner sender to stop, wait for it to complete, then
  /// start an inner sender for the new element.
  latest
}

/// Operation shared by flat-map style stream operators.
///
/// For every element of `Stream`, `Fun` is called to produce an inner sender,
/// which is connected on the heap and started; the values it produces are
/// passed to the collect delegate. `overlap` selects how an element that
/// arrives while the previous inner sender is still running is handled.
template FlatMapBaseStreamOp(Stream, Fun, OnOverlap overlap) {
  static assert(isThreadSafeFunction!Fun);
  alias Properties = StreamProperties!Stream;
  alias InnerSender = ReturnType!Fun;
  static assert(models!(InnerSender, isSender), "Fun must produce a Sender");
  alias DG = CollectDelegate!(InnerSender.Value);

  struct FlatMapBaseStreamOp(Receiver) {
    alias State = .State!(Properties.Sender.Value, InnerSender.Value, Receiver, overlap);
    alias Op = OpType!(Properties.Sender, StreamReceiver!State);
    Fun fun;
    Op op;
    State state;

    @disable this(ref return scope typeof(this) rhs);
    @disable this(this);

    /// Allocates the shared `State` and connects `stream`'s collect operation,
    /// with `item` as the per-element callback, to a `StreamReceiver`.
    /// `start` starts that connected operation.
    this(Stream stream, Fun fun, return DG dg, return Receiver receiver) @trusted return scope {
      this.fun = fun;
      state = new State(dg, receiver);
      // TODO: would it be good to do the fun in a transform operation?
      op = stream.collect(cast(Properties.DG)&item).connect(StreamReceiver!State(state));
    }

    /// Per-element callback: waits for (and, with `OnOverlap.latest`, first
    /// cancels) the previous inner sender, then calls `fun` and starts the
    /// resulting sender. Does nothing once a stop was requested or a
    /// done/error has already been produced.
    static if (is(Properties.ElementType == void))
      void item() {
        if (state.isStopRequested)
          return;
        state.onItem();
        with (state.bitfield.lock()) {
          // NOTE(review): this early return does not notify the semaphore taken
          // by `onItem`; whoever set the done/error flag also calls `stop`, so
          // later items bail out on `isStopRequested` above.
          if (isDoneOrErrorProduced(oldState)) {
            return;
          }
          auto sender = fun();
          release(Counter.tick); // release early
          runInnerSender(sender);
        }
      }
    else
      void item(Properties.ElementType t) {
        if (state.isStopRequested)
          return;
        state.onItem();
        with (state.bitfield.lock()) {
          // NOTE(review): see the void overload regarding the semaphore.
          if (isDoneOrErrorProduced(oldState)) {
            return;
          }
          auto sender = fun(t);
          release(Counter.tick); // release early
          runInnerSender(sender);
        }
      }

    /// Connects `sender` on the heap to an `InnerSenderReceiver` and starts it.
    private void runInnerSender(ref InnerSender sender) {
      import concurrency.sender : connectHeap;
      auto innerOp = sender.connectHeap(InnerSenderReceiver!(State)(state));
      innerOp.start();
    }

    /// Starts the connected collect operation of the outer stream.
    void start() nothrow @safe scope {
      op.start();
    }
  }
}

/// Flag bits kept in the low bits of `State.bitfield`.
private enum Flags : size_t {
  /// Presumably the bit `SharedBitField.lock` uses as its lock — confirm
  /// against concurrency.bitfield.
  locked = 0x1,
  /// Set when the outer stream's sender completes with a value.
  value_produced = 0x2,
  /// Set by the first done or error from the outer stream or an inner sender.
  doneOrError_produced = 0x4
}

/// Unit of the completion counter stored above the three flag bits.
private enum Counter : size_t {
  tick = 0x8
}

/// True when the counter (the bits above the three flag bits, i.e. `state >> 3`)
/// equals 2. `State` starts the counter at one tick. Presumably it reaches 2
/// only once the outer stream has completed and no inner sender is in flight —
/// confirm against `SharedBitField`'s add/release semantics.
private bool isLast(size_t state) @safe @nogc nothrow pure {
  return (state >> 3) == 2;
}

/// True when `Flags.doneOrError_produced` is set in `state`.
private bool isDoneOrErrorProduced(size_t state) @safe @nogc nothrow pure {
  return (state & Flags.doneOrError_produced) > 0;
}

/// Heap-allocated state shared by the outer stream's receiver and every inner
/// sender's receiver. It is also the `StopSource` whose token the outer
/// stream receives.
final class State(TStreamSenderValue, TSenderValue, Receiver, OnOverlap overlap) : StopSource {
  import concurrency.bitfield;
  import concurrency.stoptoken;
  import std.exception : assumeWontThrow;
  alias DG = CollectDelegate!(SenderValue);
  alias StreamSenderValue = TStreamSenderValue;
  alias SenderValue = TSenderValue;
  alias onOverlap = overlap;
  /// Collect delegate receiving the inner senders' values.
  DG dg;
  /// Downstream receiver completed by `process`.
  Receiver receiver;
  static if (!is(StreamSenderValue == void))
    /// Value the outer stream's sender completed with; forwarded by `process`.
    StreamSenderValue value;
  /// First error reported by the outer stream or an inner sender.
  Throwable throwable;
  /// Created with count 1; taken in `onItem` and given back when an inner
  /// sender completes, so at most one inner sender runs at a time.
  Semaphore semaphore;
  /// Registration of `stop` on the downstream receiver's stop token.
  StopCallback cb;
  static if (overlap == OnOverlap.latest)
    /// Stop source whose token is handed to inner senders; stopped on every
    /// new item and by `stop`.
    StopSource innerStopSource;
  /// Flags (low 3 bits) plus completion counter (in `Counter.tick` units).
  shared SharedBitField!Flags bitfield;

  this(DG dg, Receiver receiver) {
    this.dg = dg;
    this.receiver = receiver;
    semaphore = new Semaphore(1);
    static if (overlap == OnOverlap.latest)
      innerStopSource = new StopSource();
    bitfield = SharedBitField!Flags(Counter.tick);
    cb = receiver.getStopToken.onStop(cast(void delegate() nothrow @safe shared)&stop);
  }

  override bool stop() nothrow @trusted {
    return (cast(shared)this).stop();
  }

  /// Stops this source and, in `OnOverlap.latest` mode, the inner senders'
  /// stop source as well.
  override bool stop() nothrow @trusted shared {
    static if (overlap == OnOverlap.latest) {
      auto r = super.stop();
      innerStopSource.stop();
      return r;
    } else {
      return super.stop();
    }
  }

  /// Waits until no inner sender is running, cancelling the running one first
  /// in `OnOverlap.latest` mode.
  private void onItem() @trusted {
    static if (overlap == OnOverlap.latest) {
      innerStopSource.stop();
      semaphore.wait();
      innerStopSource.reset();
      // `stop` may have run while we were blocked in `wait`; the `reset` above
      // would then have cleared its request on `innerStopSource`, letting the
      // next inner sender run uncancelled. Re-apply it.
      if (isStopRequested)
        innerStopSource.stop();
    } else {
      semaphore.wait();
    }
  }

  /// Completes the downstream receiver once the last event has been counted:
  /// done if its own stop token was triggered, otherwise the recorded error or
  /// done, otherwise the outer stream's value (if it has one).
  private void process(size_t newState) {
    cb.dispose();

    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isDoneOrErrorProduced(newState)) {
      if (throwable)
        receiver.setError(throwable);
      else
        receiver.setDone();
    } else {
      import concurrency.receiver : setValueOrError;
      // Forward the value the outer stream's sender completed with (stored by
      // `StreamReceiver.setValue`).
      static if (is(StreamSenderValue == void))
        receiver.setValueOrError();
      else
        receiver.setValueOrError(value);
    }
  }

  /// Stop token handed to inner senders: `innerStopSource` in
  /// `OnOverlap.latest` mode, this state otherwise.
  private StopToken getSenderStopToken() @safe nothrow {
    static if (overlap == OnOverlap.latest) {
      return StopToken(innerStopSource);
    } else {
      return StopToken(this);
    }
  }
}

/// Receiver for the outer stream's collect operation. Records its completion
/// in the shared bitfield and completes downstream if it is the last event.
struct StreamReceiver(State) {
  State state;
  static if (is(State.StreamSenderValue == void)) {
    void setValue() @safe {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        if (isLast(newState))
          state.process(newState);
      }
    }
  } else {
    void setValue(State.StreamSenderValue value) @safe {
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        state.value = value;
        release();
        if (last)
          state.process(newState);
      }
    }
  }

  /// Records `t` if it is the first done/error, then stops the shared state.
  void setError(Throwable t) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState)) {
        state.throwable = t;
        release(); // must release before calling .stop
        state.stop();
      } else
        release();
      if (last)
        state.process(newState);
    }
  }

  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState))
        state.stop();
      if (last)
        state.process(newState);
    }
  }

  StopToken getStopToken() @safe nothrow {
    return StopToken(state);
  }

  private auto receiver() {
    return state.receiver;
  }

  mixin ForwardExtensionPoints!(receiver);
}

/// Receiver connected to every inner sender. Passes its values to the collect
/// delegate, counts its completion and, unless it was the last outstanding
/// event, gives the overlap semaphore back.
struct InnerSenderReceiver(State) {
  State state;
  static if (is(State.SenderValue == void)) {
    void setValue() @safe {
      state.dg();
      onSenderValue();
    }
  } else {
    void setValue(State.SenderValue value) @safe {
      state.dg(value);
      onSenderValue();
    }
  }

  void setError(Throwable t) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState)) {
        state.throwable = t;
        release(); // must release before calling .stop
        state.stop();
      } else
        release();
      if (last)
        state.process(newState);
      else
        notify();
    }
  }

  void setDone() @safe nothrow {
    static if (State.onOverlap == OnOverlap.latest) {
      if (!state.isStopRequested) {
        // Done without a flat-map stop (e.g. superseded by a newer item via
        // `innerStopSource`): count it like a value completion. If it was the
        // last outstanding event, downstream must be completed here, otherwise
        // the operation would never finish.
        with (state.bitfield.update(0, Counter.tick)) {
          if (isLast(newState))
            state.process(newState);
          else
            notify();
        }
        return;
      }
    }
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState))
        state.stop();
      if (last)
        state.process(newState);
      else
        notify();
    }
  }

  auto getStopToken() @safe nothrow {
    return state.getSenderStopToken;
  }

  private auto receiver() {
    return state.receiver;
  }

  /// Counts a value completion; completes downstream if it was the last event,
  /// otherwise gives the overlap semaphore back.
  private void onSenderValue() @trusted {
    with (state.bitfield.update(0, Counter.tick)) {
      if (isLast(newState))
        state.process(newState);
      else
        notify();
    }
  }

  private void notify() @trusted {
    import std.exception : assumeWontThrow;
    state.semaphore.notify().assumeWontThrow;
  }

  mixin ForwardExtensionPoints!(receiver);
}