1 module concurrency.stream.flatmapbase;
2 
3 import concurrency.stream.stream;
4 import concurrency.sender : OpType, isSender;
5 import concurrency.receiver : ForwardExtensionPoints;
6 import concurrency.stoptoken : StopSource, StopToken;
7 import std.traits : ReturnType;
8 import concurrency.utils : isThreadSafeFunction;
9 import concepts;
10 import core.sync.semaphore : Semaphore;
11 
/// Policy for an item that arrives while the inner Sender started for a
/// previous item is still running (see `State.onItem`):
///  - `wait`:   the new item blocks until the running inner Sender finishes;
///  - `latest`: the running inner Sender is asked to stop first, then the new
///              item waits for it to finish.
enum OnOverlap { wait = 0, latest = 1 }
16 
/// Operation state shared by the flatMap-style stream operators.
///
/// Every item emitted by `Stream` is handed to `fun`, which must return a
/// Sender. That inner Sender is connected to an `InnerSenderReceiver` and
/// started; the values it produces are passed to the collect delegate `dg`.
/// `overlap` selects what `State.onItem` does when an item arrives while a
/// previous inner Sender is still running.
template FlatMapBaseStreamOp(Stream, Fun, OnOverlap overlap) {
  static assert(isThreadSafeFunction!Fun);
  alias Properties = StreamProperties!Stream;
  alias InnerSender = ReturnType!Fun;
  static assert(models!(InnerSender, isSender), "Fun must produce a Sender");
  alias DG = CollectDelegate!(InnerSender.Value);
  struct FlatMapBaseStreamOp(Receiver) {
    alias State = .State!(Properties.Sender.Value, InnerSender.Value, Receiver, overlap);
    alias Op = OpType!(Properties.Sender, StreamReceiver!State);
    alias InnerOp = OpType!(InnerSender, InnerSenderReceiver!State);
    Fun fun;
    Op op;
    InnerOp innerOp;
    State state;
    @disable this(ref return scope typeof(this) rhs);
    @disable this(this);
    /// Allocates the shared `State` and connects the outer stream, collecting
    /// its items through `item`.
    this(Stream stream, Fun fun, return DG dg, return Receiver receiver) @trusted scope return {
      this.fun = fun;
      state = new State(dg, receiver);
      // TODO: would it be good to do the fun in a transform operation?
      op = stream.collect(cast(Properties.DG)&item).connect(StreamReceiver!State(state));
    }
    // `item` must remain a plain (non-template) function so `&item` above can
    // be taken; both variants forward to the shared `handleItem`.
    static if (is(Properties.ElementType == void))
      void item() {
        handleItem();
      }
    else
      void item(Properties.ElementType t) {
        handleItem(t);
      }
    /// Common body of both `item` variants: `element` is empty for streams
    /// with a void element type and holds the single element otherwise.
    private void handleItem(Args...)(auto ref Args element) {
      if (state.isStopRequested)
        return;
      // Acquire the semaphore (with OnOverlap.latest, after asking the
      // running inner sender to stop).
      state.onItem();
      with (state.bitfield.lock()) {
        if (isDoneOrErrorProduced(oldState)) {
          return;
        }
        release(Counter.tick); // release early
        auto sender = fun(element);
        runInnerSender(sender);
      }
    }
    /// Connects `sender` to a receiver sharing `state`, keeps the resulting
    /// operation in `innerOp`, and starts it.
    private void runInnerSender(ref InnerSender sender) {
      innerOp = sender.connect(InnerSenderReceiver!(State)(state));
      innerOp.start();
    }
    /// Starts the outer stream operation.
    void start() nothrow @safe {
      op.start();
    }
  }
}
76 
/// Flag bits kept in the low three bits of `State.bitfield`.
private enum Flags : size_t {
  locked = 1 << 0,
  value_produced = 1 << 1,
  doneOrError_produced = 1 << 2,
}
82 
/// One unit of the tick counter stored above the three `Flags` bits.
private enum Counter : size_t { tick = size_t(1) << 3 }
86 
/// True when the counter above the three flag bits of `state` equals 2,
/// i.e. when `state` lies in the half-open range [2 << 3, 3 << 3).
private bool isLast(size_t state) @safe @nogc nothrow pure {
  enum size_t lower = size_t(2) << 3;
  enum size_t upper = size_t(3) << 3;
  return lower <= state && state < upper;
}
90 
/// Reports whether the `doneOrError_produced` bit is set in `state`.
private bool isDoneOrErrorProduced(size_t state) @safe @nogc nothrow pure {
  immutable size_t masked = state & Flags.doneOrError_produced;
  return masked != 0;
}
94 
/// Heap-allocated state shared by the outer stream receiver and every inner
/// sender receiver of one FlatMapBaseStreamOp.
///
/// It is itself the StopSource used by the outer stream (and, with
/// OnOverlap.wait, by the inner senders). `bitfield` holds the `Flags` bits
/// in its low three bits and a tick counter above them; the party that sees
/// the counter reach its final value (`isLast`) calls `process`.
final class State(TStreamSenderValue, TSenderValue, Receiver, OnOverlap overlap) : StopSource {
  import concurrency.bitfield;
  import concurrency.stoptoken;
  import std.exception : assumeWontThrow;
  alias DG = CollectDelegate!(SenderValue);
  alias StreamSenderValue = TStreamSenderValue;
  alias SenderValue = TSenderValue;
  alias onOverlap = overlap;
  DG dg;
  Receiver receiver;
  static if (!is(StreamSenderValue == void))
    StreamSenderValue value; // completion value of the outer stream
  Throwable throwable;       // first error reported, if any
  Semaphore semaphore;       // created with count 1; taken by onItem
  StopCallback cb;
  static if (overlap == OnOverlap.latest)
    StopSource innerStopSource; // stop source handed to inner senders
  shared SharedBitField!Flags bitfield;
  /// Stores `dg` and `receiver`, initialises the counter to one tick and
  /// registers `stop` with the receiver's stop token.
  this(DG dg, Receiver receiver) {
    this.dg = dg;
    this.receiver = receiver;
    semaphore = new Semaphore(1);
    static if (overlap == OnOverlap.latest)
      innerStopSource = new StopSource();
    bitfield = SharedBitField!Flags(Counter.tick);
    cb = receiver.getStopToken.onStop(cast(void delegate() nothrow @safe shared)&stop);
  }
  override bool stop() nothrow @trusted {
    return (cast(shared)this).stop();
  }
  /// Stops this source; with OnOverlap.latest also stops `innerStopSource`.
  /// Returns the result of `StopSource.stop` for this source.
  override bool stop() nothrow @trusted shared {
    static if (overlap == OnOverlap.latest) {
      auto r = super.stop();
      innerStopSource.stop();
      return r;
    } else {
      return super.stop();
    }
  }
  /// Called before starting an inner sender for a new item. Blocks on the
  /// semaphore; with OnOverlap.latest it first stops the running inner
  /// sender and afterwards resets `innerStopSource` for the next one.
  private void onItem() @trusted {
    static if (overlap == OnOverlap.latest) {
      innerStopSource.stop();
      semaphore.wait();
      innerStopSource.reset();
    } else {
      semaphore.wait();
    }
  }
  /// Delivers the final signal to `receiver`: done if its stop token was
  /// triggered, the recorded error/done if one was produced, otherwise the
  /// outer stream's completion value.
  private void process(size_t newState) {
    cb.dispose();

    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isDoneOrErrorProduced(newState)) {
      if (throwable)
        receiver.setError(throwable);
      else
        receiver.setDone();
    } else {
      import concurrency.receiver : setValueOrError;
      // Forward the value the outer stream completed with, as stored by
      // StreamReceiver.setValue.
      static if (is(StreamSenderValue == void))
        receiver.setValueOrError();
      else
        receiver.setValueOrError(value);
    }
  }
  /// Stop token given to inner senders: `innerStopSource` for
  /// OnOverlap.latest, otherwise this State.
  private StopToken getSenderStopToken() @safe nothrow {
    static if (overlap == OnOverlap.latest) {
      return StopToken(innerStopSource);
    } else {
      return StopToken(this);
    }
  }
}
169 
/// Receiver connected to the outer stream's `collect` sender.
///
/// Every completion signal adds one tick to `state.bitfield` together with
/// the matching flag; the call that brings the counter to its final value
/// (`isLast`) hands over to `state.process`.
struct StreamReceiver(State) {
  State state;
  static if (!is(State.StreamSenderValue == void)) {
    /// Stores the stream's completion value while the bitfield is locked.
    void setValue(State.StreamSenderValue value) @safe {
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        const finished = isLast(newState);
        state.value = value;
        release();
        if (finished)
          state.process(newState);
      }
    }
  } else {
    void setValue() @safe {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        if (!isLast(newState))
          return;
        state.process(newState);
      }
    }
  }
  /// Records `t` and stops `state` only if no done/error was recorded before.
  void setError(Throwable t) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      const first = !isDoneOrErrorProduced(oldState);
      const finished = isLast(newState);
      if (first)
        state.throwable = t;
      release(); // must release before calling .stop
      if (first)
        state.stop();
      if (finished)
        state.process(newState);
    }
  }
  /// Marks done/error produced and stops `state` if it is the first such signal.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      const first = !isDoneOrErrorProduced(oldState);
      if (first)
        state.stop();
      if (isLast(newState))
        state.process(newState);
    }
  }
  /// The outer stream is cancelled through the shared State.
  StopToken getStopToken() @safe nothrow {
    return StopToken(state);
  }
  private auto receiver() {
    return state.receiver;
  }
  mixin ForwardExtensionPoints!(receiver);
}
220 
/// Receiver connected to each inner Sender produced by `fun`.
///
/// Values are forwarded to the collect delegate `state.dg`. Every completion
/// adds one tick to `state.bitfield` and then either finishes the whole
/// operation (when `isLast`) or returns the semaphore so the next item can
/// start its inner sender.
struct InnerSenderReceiver(State) {
  State state;
  static if (is(State.SenderValue == void)) {
    void setValue() @safe {
      state.dg();
      onSenderValue();
    }
  } else {
    void setValue(State.SenderValue value) @safe {
      state.dg(value);
      onSenderValue();
    }
  }
  /// Records `t` and stops `state` if no done/error was recorded before.
  void setError(Throwable t) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState)) {
        state.throwable = t;
        release(); // must release before calling .stop
        state.stop();
      } else
        release();
      if (last)
        state.process(newState);
      else
        notify();
    }
  }
  void setDone() @safe nothrow {
    static if (State.onOverlap == OnOverlap.latest) {
      if (!state.isStopRequested) {
        // The outer operation is not stopped, so this done is treated as the
        // inner sender being superseded: return its tick without flagging
        // done. The counter must still be checked; if the outer stream has
        // already completed this is the final tick, and nobody else would
        // call process.
        with (state.bitfield.update(0, Counter.tick)) {
          if (isLast(newState))
            state.process(newState);
          else
            notify();
        }
        return;
      }
    }
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState))
        state.stop();
      if (last)
        state.process(newState);
      else
        notify();
    }
  }
  auto getStopToken() @safe nothrow {
    return state.getSenderStopToken;
  }
  private auto receiver() {
    return state.receiver;
  }
  /// Returns the tick for a successfully completed inner sender.
  private void onSenderValue() @trusted {
    with (state.bitfield.update(0, Counter.tick)) {
      if (isLast(newState))
        state.process(newState);
      else
        notify();
    }
  }
  /// Releases the semaphore taken by `State.onItem`.
  private void notify() @trusted {
    import std.exception : assumeWontThrow;
    state.semaphore.notify().assumeWontThrow;
  }
  mixin ForwardExtensionPoints!(receiver);
}