module concurrency.stream.flatmapbase;

import concurrency.stream.stream;
import concurrency.sender : OpType, isSender;
import concurrency.receiver : ForwardExtensionPoints;
import concurrency.stoptoken : StopSource, StopToken;
import std.traits : ReturnType;
import concurrency.utils : isThreadSafeFunction;
import concepts;
import core.sync.semaphore : Semaphore;

/// What to do when the outer stream produces an item while the inner sender
/// started for an earlier item may still be running.
enum OnOverlap {
    /// Only wait on the state's semaphore before starting the next inner sender.
    wait,
    /// Request stop on the inner stop source first, then wait on the semaphore.
    latest
}

/// Operation state for a flat-mapping stream: every item of `Stream` is passed
/// to `Fun`, which must return a Sender; that sender is connected and started,
/// and the values it produces are handed to the collect delegate.
///
/// Params:
///   Stream  = the outer stream type
///   Fun     = thread-safe callable returning a Sender
///   overlap = policy applied when a new item arrives (see `OnOverlap`)
template FlatMapBaseStreamOp(Stream, Fun, OnOverlap overlap) {
    static assert(isThreadSafeFunction!Fun);
    alias Properties = StreamProperties!Stream;
    alias InnerSender = ReturnType!Fun;
    static assert(models!(InnerSender, isSender), "Fun must produce a Sender");
    alias DG = CollectDelegate!(InnerSender.Value);

    struct FlatMapBaseStreamOp(Receiver) {
        alias State = .State!(Properties.Sender.Value, InnerSender.Value, Receiver, overlap);
        alias Op = OpType!(Properties.Sender, StreamReceiver!State);
        alias InnerOp = OpType!(InnerSender, InnerSenderReceiver!State);

        Fun fun;
        Op op;
        InnerOp innerOp;
        State state;

        // The delegate passed to `stream.collect` below is `&item`, which
        // captures the address of this struct, so copying is disabled.
        @disable this(ref return scope typeof(this) rhs);
        @disable this(this);

        /// Allocates the shared state and connects the outer stream so that
        /// each item it produces is delivered to `item`.
        this(Stream stream, Fun fun, return DG dg, return Receiver receiver) @trusted scope return {
            this.fun = fun;
            state = new State(dg, receiver);
            // TODO: would it be good to do the fun in a transform operation?
            op = stream.collect(cast(Properties.DG)&item).connect(StreamReceiver!State(state));
        }

        static if (is(Properties.ElementType == void))
            /// Called for every item of the outer stream (void element type).
            void item() {
                if (state.isStopRequested)
                    return;
                // Blocks on the semaphore (after stopping the inner stop
                // source when the policy is `latest`).
                state.onItem();
                with (state.bitfield.lock()) {
                    if (isDoneOrErrorProduced(oldState)) {
                        return;
                    }
                    release(Counter.tick); // release early
                    auto sender = fun();
                    runInnerSender(sender);
                }
            }
        else
            /// Called for every item of the outer stream; `t` is passed to `fun`.
            void item(Properties.ElementType t) {
                if (state.isStopRequested)
                    return;
                // Blocks on the semaphore (after stopping the inner stop
                // source when the policy is `latest`).
                state.onItem();
                with (state.bitfield.lock()) {
                    if (isDoneOrErrorProduced(oldState)) {
                        return;
                    }
                    release(Counter.tick); // release early
                    auto sender = fun(t);
                    runInnerSender(sender);
                }
            }

        /// Connects `sender` to an `InnerSenderReceiver`, stores the resulting
        /// operation in `innerOp` and starts it.
        private void runInnerSender(ref InnerSender sender) {
            innerOp = sender.connect(InnerSenderReceiver!(State)(state));
            innerOp.start();
        }

        /// Starts the outer stream operation.
        void start() nothrow @safe {
            op.start();
        }
    }
}

/// Flag bits kept in the low three bits of the shared bitfield.
private enum Flags : size_t {
    locked = 0x1,
    value_produced = 0x2,
    doneOrError_produced = 0x4
}

/// Increment for the counter stored above the three flag bits.
private enum Counter : size_t {
    tick = 0x8
}

/// True when the counter part of `state` (everything above the three flag
/// bits) equals 2. Callers use this to decide whether to call `process`.
private bool isLast(size_t state) @safe @nogc nothrow pure {
    return (state >> 3) == 2;
}

/// True when the `doneOrError_produced` flag is set in `state`.
private bool isDoneOrErrorProduced(size_t state) @safe @nogc nothrow pure {
    return (state & Flags.doneOrError_produced) > 0;
}

/// State shared between the outer stream receiver and the inner sender
/// receiver. It is itself a `StopSource`; a stop callback registered on the
/// downstream receiver's stop token calls `stop` on it.
final class State(TStreamSenderValue, TSenderValue, Receiver, OnOverlap overlap) : StopSource {
    import concurrency.bitfield;
    import concurrency.stoptoken;
    import std.exception : assumeWontThrow;

    alias DG = CollectDelegate!(SenderValue);
    alias StreamSenderValue = TStreamSenderValue;
    alias SenderValue = TSenderValue;
    alias onOverlap = overlap;

    DG dg;
    Receiver receiver;
    static if (!is(StreamSenderValue == void))
        StreamSenderValue value; // value delivered by the outer stream's sender
    Throwable throwable; // first error recorded by either receiver
    Semaphore semaphore; // waited on in onItem, notified by the inner receiver
    StopCallback cb;
    static if (overlap == OnOverlap.latest)
        StopSource innerStopSource; // stop source handed to inner senders
    shared SharedBitField!Flags bitfield;

    this(DG dg, Receiver receiver) {
        this.dg = dg;
        this.receiver = receiver;
        semaphore = new Semaphore(1);
        static if (overlap == OnOverlap.latest)
            innerStopSource = new StopSource();
        bitfield = SharedBitField!Flags(Counter.tick);
        cb = receiver.getStopToken.onStop(cast(void delegate() nothrow @safe shared)&stop);
    }

    override bool stop() nothrow @trusted {
        return (cast(shared)this).stop();
    }

    /// Requests stop on this source and, for `OnOverlap.latest`, on the inner
    /// stop source as well. Returns the result of `super.stop()`.
    override bool stop() nothrow @trusted shared {
        static if (overlap == OnOverlap.latest) {
            auto r = super.stop();
            innerStopSource.stop();
            return r;
        } else {
            return super.stop();
        }
    }

    /// Waits on the semaphore before a new inner sender is started. With
    /// `OnOverlap.latest` the inner stop source is stopped before the wait
    /// and reset after it.
    private void onItem() @trusted {
        static if (overlap == OnOverlap.latest) {
            innerStopSource.stop();
            semaphore.wait();
            innerStopSource.reset();
        } else {
            semaphore.wait();
        }
    }

    /// Completes the downstream receiver: done when its stop token reports a
    /// stop request, error/done when the done-or-error flag is set in
    /// `newState`, otherwise a value completion.
    private void process(size_t newState) {
        cb.dispose();

        if (receiver.getStopToken().isStopRequested)
            receiver.setDone();
        else if (isDoneOrErrorProduced(newState)) {
            if (throwable)
                receiver.setError(throwable);
            else
                receiver.setDone();
        } else {
            import concurrency.receiver : setValueOrError;
            // Forward the stored stream value only when there is one and the
            // receiver's setValue accepts it; otherwise complete without
            // arguments. The previous check referenced `Value` and `state`,
            // neither of which is declared in this class, so the value branch
            // could never be selected.
            static if (!is(StreamSenderValue == void) && is(typeof(receiver.setValue(value))))
                receiver.setValueOrError(value);
            else
                receiver.setValueOrError();
        }
    }

    /// Stop token given to inner senders: the inner stop source for
    /// `OnOverlap.latest`, this state otherwise.
    private StopToken getSenderStopToken() @safe nothrow {
        static if (overlap == OnOverlap.latest) {
            return StopToken(innerStopSource);
        } else {
            return StopToken(this);
        }
    }
}

/// Receiver connected to the outer stream's collect sender.
/// Each completion signal updates the shared bitfield with its flag plus one
/// `Counter.tick`, and calls `state.process` when `isLast` holds for the
/// resulting state.
struct StreamReceiver(State) {
    State state;

    static if (is(State.StreamSenderValue == void)) {
        void setValue() @safe {
            with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
                if (isLast(newState))
                    state.process(newState);
            }
        }
    } else {
        void setValue(State.StreamSenderValue value) @safe {
            with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
                bool last = isLast(newState);
                // Store the value while the bitfield is still locked.
                state.value = value;
                release();
                if (last)
                    state.process(newState);
            }
        }
    }

    void setError(Throwable t) @safe nothrow {
        with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
            bool last = isLast(newState);
            // Only the first done/error signal records the throwable and
            // triggers stop.
            if (!isDoneOrErrorProduced(oldState)) {
                state.throwable = t;
                release(); // must release before calling .stop
                state.stop();
            } else
                release();
            if (last)
                state.process(newState);
        }
    }

    void setDone() @safe nothrow {
        with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
            bool last = isLast(newState);
            if (!isDoneOrErrorProduced(oldState))
                state.stop();
            if (last)
                state.process(newState);
        }
    }

    StopToken getStopToken() @safe nothrow {
        return StopToken(state);
    }

    private auto receiver() {
        return state.receiver;
    }

    mixin ForwardExtensionPoints!(receiver);
}

/// Receiver connected to each inner sender returned by `fun`.
/// Values are passed to the collect delegate; after every completion it
/// either calls `state.process` (when `isLast` holds) or notifies the
/// semaphore that `State.onItem` waits on.
struct InnerSenderReceiver(State) {
    State state;

    static if (is(State.SenderValue == void)) {
        void setValue() @safe {
            state.dg();
            onSenderValue();
        }
    } else {
        void setValue(State.SenderValue value) @safe {
            state.dg(value);
            onSenderValue();
        }
    }

    void setError(Throwable t) @safe nothrow {
        with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
            bool last = isLast(newState);
            // Only the first done/error signal records the throwable and
            // triggers stop.
            if (!isDoneOrErrorProduced(oldState)) {
                state.throwable = t;
                release(); // must release before calling .stop
                state.stop();
            } else
                release();
            if (last)
                state.process(newState);
            else
                notify();
        }
    }

    void setDone() @safe nothrow {
        static if (State.onOverlap == OnOverlap.latest) {
            // With `latest`, a done signal while no stop was requested on the
            // state is not propagated as done: only a tick is added and the
            // semaphore is notified.
            // NOTE(review): presumably this is the inner sender being
            // cancelled by onItem for a newer item - confirm.
            if (!state.isStopRequested) {
                state.bitfield.add(Counter.tick);
                notify();
                return;
            }
        }
        with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
            bool last = isLast(newState);
            if (!isDoneOrErrorProduced(oldState))
                state.stop();
            if (last)
                state.process(newState);
            else
                notify();
        }
    }

    auto getStopToken() @safe nothrow {
        return state.getSenderStopToken;
    }

    private auto receiver() {
        return state.receiver;
    }

    /// Bookkeeping after the collect delegate returned: adds one tick without
    /// setting a flag, then processes or notifies.
    private void onSenderValue() @trusted {
        with (state.bitfield.update(0, Counter.tick)) {
            if (isLast(newState))
                state.process(newState);
            else
                notify();
        }
    }

    /// Notifies the state's semaphore.
    private void notify() @trusted {
        import std.exception : assumeWontThrow;
        state.semaphore.notify().assumeWontThrow;
    }

    mixin ForwardExtensionPoints!(receiver);
}