1 module concurrency.stream.flatmapbase;
2 
3 import concurrency.stream.stream;
4 import concurrency.sender : OpType, isSender;
5 import concurrency.receiver : ForwardExtensionPoints;
6 import concurrency.stoptoken : StopSource, StopToken;
7 import std.traits : ReturnType;
8 import concurrency.utils : isThreadSafeFunction;
9 import concepts;
10 import core.sync.semaphore : Semaphore;
11 
/// How a flat-map stream treats a new element that arrives while the inner
/// sender of the previous element is still running (see `State.onItem`):
///  - `wait`:   block until the running inner sender has completed;
///  - `latest`: request stop of the running inner sender, then wait for it.
enum OnOverlap { wait, latest }
16 
template FlatMapBaseStreamOp(Stream, Fun, OnOverlap overlap) {
  static assert(isThreadSafeFunction!Fun);
  alias Properties = StreamProperties!Stream;
  alias InnerSender = ReturnType!Fun;
  static assert(models!(InnerSender, isSender), "Fun must produce a Sender");
  alias DG = CollectDelegate!(InnerSender.Value);

  /// Operation state shared by the flat-map stream variants.
  ///
  /// Each element of `Stream` is mapped through `fun` to an `InnerSender`.
  /// That sender is heap-connected to an `InnerSenderReceiver` and started, and
  /// the receiver forwards inner values to `dg`. How consecutive inner senders
  /// may overlap is decided by `overlap`, via `State.onItem`.
  struct FlatMapBaseStreamOp(Receiver) {
    alias State = .State!(Properties.Sender.Value, InnerSender.Value, Receiver, overlap);
    alias Op = OpType!(Properties.Sender, StreamReceiver!State);

    Fun fun;
    Op op;
    State state;

    @disable this(ref return scope typeof(this) rhs);
    @disable this(this);

    /// Allocates the shared `State` and connects the outer stream's `collect`
    /// sender, with `item` as the per-element callback.
    this(Stream stream, Fun fun, return DG dg, return Receiver receiver) @trusted return scope {
      this.fun = fun;
      state = new State(dg, receiver);
      // TODO: would it be good to do the fun in a transform operation?
      op = stream.collect(cast(Properties.DG)&item).connect(StreamReceiver!State(state));
    }

    static if (!is(Properties.ElementType == void)) {
      /// Per-element callback given to `collect`.
      void item(Properties.ElementType t) {
        handleElement(t);
      }
    } else {
      /// Per-element callback given to `collect` (the stream has void elements).
      void item() {
        handleElement();
      }
    }

    /// Common body of both `item` variants. It does nothing once stop has been
    /// requested. Otherwise it waits for its turn (`State.onItem`), builds the
    /// inner sender from `fun`, and starts it. No inner sender is started once
    /// done/error has been recorded in the bitfield.
    private void handleElement(Args...)(auto ref Args args) {
      if (state.isStopRequested)
        return;
      state.onItem();
      with (state.bitfield.lock()) {
        if (isDoneOrErrorProduced(oldState))
          return;
        auto inner = fun(args);
        // Release early: the bitfield lock must not be held while the inner
        // sender starts (it may complete synchronously).
        release(Counter.tick);
        runInnerSender(inner);
      }
    }

    /// Heap-connects `sender` to an `InnerSenderReceiver` and starts it.
    private void runInnerSender(ref InnerSender sender) {
      import concurrency.sender : connectHeap;
      auto heapOp = sender.connectHeap(InnerSenderReceiver!(State)(state));
      heapOp.start();
    }

    /// Starts the connected outer stream operation.
    void start() nothrow @safe scope {
      op.start();
    }
  }
}
75 
/// Flag bits kept in the low three bits of `State.bitfield`; the bits above
/// them hold the tick counter (see `Counter`).
///  - `locked`: presumably the bit the `SharedBitField` lock spins on; it is
///    not referenced directly in this module (TODO confirm);
///  - `value_produced`: the outer stream completed with a value;
///  - `doneOrError_produced`: a done or error completion was recorded.
private enum Flags : size_t {
  locked = 1 << 0,
  value_produced = 1 << 1,
  doneOrError_produced = 1 << 2,
}
81 
/// One unit of the completion counter. It sits directly above the three
/// `Flags` bits in `State.bitfield`.
private enum Counter : size_t {
  tick = 1 << 3,
}
85 
/// True when the tick counter (the bits above the three flag bits) equals 2.
/// The receivers treat this as the final completion, the one that must call
/// `State.process`.
private bool isLast(size_t state) @safe @nogc nothrow pure {
  enum counterShift = 3; // number of low bits reserved for Flags
  immutable ticks = state >> counterShift;
  return ticks == 2;
}
89 
/// True when `state` has the `doneOrError_produced` flag set.
private bool isDoneOrErrorProduced(size_t state) @safe @nogc nothrow pure {
  immutable masked = state & Flags.doneOrError_produced;
  return masked != 0;
}
93 
/// Heap-allocated state shared by the flat-map operation, its outer-stream
/// receiver (`StreamReceiver`) and every inner-sender receiver
/// (`InnerSenderReceiver`).
///
/// It is itself a `StopSource`: stopping it stops the whole operation, and in
/// `OnOverlap.latest` mode it also stops `innerStopSource`. The downstream
/// receiver's stop token is wired to `stop` in the constructor.
final class State(TStreamSenderValue, TSenderValue, Receiver, OnOverlap overlap) : StopSource {
  import concurrency.bitfield;
  import concurrency.stoptoken;
  import std.exception : assumeWontThrow;
  alias DG = CollectDelegate!(SenderValue);
  alias StreamSenderValue = TStreamSenderValue;
  alias SenderValue = TSenderValue;
  alias onOverlap = overlap;
  /// Downstream collect delegate; receives the inner senders' values.
  DG dg;
  /// Downstream receiver; completed exactly once, by `process`.
  Receiver receiver;
  static if (!is(StreamSenderValue == void))
    StreamSenderValue value; // stored by StreamReceiver.setValue
  /// First error recorded by either receiver kind.
  Throwable throwable;
  /// Created with a count of 1. `onItem` waits on it before each element,
  /// and the inner-sender receiver notifies it when it completes.
  Semaphore semaphore;
  /// Registration of `stop` on the downstream receiver's stop token.
  StopCallback cb;
  static if (overlap == OnOverlap.latest)
    StopSource innerStopSource; // lets onItem cancel the running inner sender
  /// Flags in the low bits and the completion counter above them;
  /// seeded with one `Counter.tick`.
  shared SharedBitField!Flags bitfield;

  this(DG dg, Receiver receiver) {
    this.dg = dg;
    this.receiver = receiver;
    semaphore = new Semaphore(1);
    static if (overlap == OnOverlap.latest)
      innerStopSource = new StopSource();
    bitfield = SharedBitField!Flags(Counter.tick);
    cb = receiver.getStopToken.onStop(cast(void delegate() nothrow @safe shared)&stop);
  }

  override bool stop() nothrow @trusted {
    return (cast(shared)this).stop();
  }

  /// Stops this source and, in `latest` mode, the inner stop source as well.
  /// Returns the result of `StopSource.stop` for this source.
  override bool stop() nothrow @trusted shared {
    static if (overlap == OnOverlap.latest) {
      auto r = super.stop();
      innerStopSource.stop();
      return r;
    } else {
      return super.stop();
    }
  }

  /// Called before handling each outer element. It blocks on `semaphore`
  /// until the previous inner sender has completed. In `latest` mode it first
  /// requests stop of that inner sender, and after acquiring the semaphore it
  /// resets the inner stop source for the next inner sender.
  private void onItem() @trusted {
    static if (overlap == OnOverlap.latest) {
      innerStopSource.stop();
      semaphore.wait();
      innerStopSource.reset();
    } else {
      semaphore.wait();
    }
  }

  /// Completes the downstream receiver once the final tick has been counted.
  /// The result is `setDone` if downstream requested stop, `setError`/`setDone`
  /// if a terminal completion was recorded, and otherwise the value.
  private void process(size_t newState) {
    cb.dispose();

    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isDoneOrErrorProduced(newState)) {
      if (throwable)
        receiver.setError(throwable);
      else
        receiver.setDone();
    } else {
      import concurrency.receiver : setValueOrError;
      // Previously this tested `is(typeof(Value.values))`. No `Value` exists
      // here, so the test was always false and the stored stream value was
      // never forwarded.
      static if (!is(StreamSenderValue == void))
        receiver.setValueOrError(value);
      else
        receiver.setValueOrError();
    }
  }

  /// Stop token given to inner senders. It is the inner stop source in
  /// `latest` mode (cancellable per element) and this object otherwise.
  private StopToken getSenderStopToken() @safe nothrow {
    static if (overlap == OnOverlap.latest) {
      return StopToken(innerStopSource);
    } else {
      return StopToken(this);
    }
  }
}
168 
/// Receiver connected to the outer stream's `collect` sender.
///
/// Each completion adds one `Counter.tick` to `state.bitfield`. When `isLast`
/// reports the final tick, the receiver completes the whole operation through
/// `state.process`.
struct StreamReceiver(State) {
  State state;

  static if (!is(State.StreamSenderValue == void)) {
    /// Stream completed with a value, which is stored while the bitfield is locked.
    void setValue(State.StreamSenderValue value) @safe {
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        immutable finished = isLast(newState);
        state.value = value;
        release();
        if (finished)
          state.process(newState);
      }
    }
  } else {
    /// Stream completed without a value.
    void setValue() @safe {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        immutable finished = isLast(newState);
        if (finished)
          state.process(newState);
      }
    }
  }

  /// Stream failed. The first terminal completion records `t` and stops the
  /// operation. The lock is always released before `stop` is called.
  void setError(Throwable t) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      immutable finished = isLast(newState);
      immutable firstTerminal = !isDoneOrErrorProduced(oldState);
      if (firstTerminal)
        state.throwable = t;
      release(); // must release before calling .stop
      if (firstTerminal)
        state.stop();
      if (finished)
        state.process(newState);
    }
  }

  /// Stream completed with done. The first terminal completion stops the operation.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      immutable finished = isLast(newState);
      if (!isDoneOrErrorProduced(oldState))
        state.stop();
      if (finished)
        state.process(newState);
    }
  }

  /// The outer stream observes stop of the whole operation.
  StopToken getStopToken() @safe nothrow {
    return StopToken(state);
  }

  private auto receiver() {
    return state.receiver;
  }

  mixin ForwardExtensionPoints!(receiver);
}
219 
/// Receiver given to every inner sender produced by `fun`.
///
/// Every completion path gives one `Counter.tick` back to `state.bitfield`.
/// If `isLast` reports the final tick, it completes the operation via
/// `state.process`. Otherwise it notifies `state.semaphore` so the next outer
/// element can proceed.
struct InnerSenderReceiver(State) {
  State state;

  static if (is(State.SenderValue == void)) {
    /// Forwards the (void) inner value to the downstream collect delegate.
    void setValue() @safe {
      state.dg();
      onSenderValue();
    }
  } else {
    /// Forwards the inner value to the downstream collect delegate.
    void setValue(State.SenderValue value) @safe {
      state.dg(value);
      onSenderValue();
    }
  }

  /// Inner sender failed. The first terminal completion records `t` and stops
  /// the operation. The lock is released before `stop` is called.
  void setError(Throwable t) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState)) {
        state.throwable = t;
        release(); // must release before calling .stop
        state.stop();
      } else
        release();
      if (last)
        state.process(newState);
      else
        notify();
    }
  }

  /// Inner sender completed with done.
  ///
  /// In `OnOverlap.latest` mode, a done that arrives while the operation itself
  /// is not stopped does not count as terminal. For example, `State.onItem`
  /// cancelled the inner sender because a newer element arrived.
  ///
  /// Such a done still returns its tick. If that tick is the final one (the
  /// outer stream already completed), the operation must complete here. The
  /// old code only added the tick and notified, so the downstream receiver was
  /// never completed in that case.
  void setDone() @safe nothrow {
    static if (State.onOverlap == OnOverlap.latest) {
      if (!state.isStopRequested) {
        with (state.bitfield.update(0, Counter.tick)) {
          if (isLast(newState))
            state.process(newState);
          else
            notify();
        }
        return;
      }
    }
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState))
        state.stop();
      if (last)
        state.process(newState);
      else
        notify();
    }
  }

  /// Inner senders observe the stop token chosen by `State.getSenderStopToken`.
  auto getStopToken() @safe nothrow {
    return state.getSenderStopToken;
  }

  private auto receiver() {
    return state.receiver;
  }

  /// Gives back the tick after a value was forwarded. It then finishes the
  /// operation, or lets the next element proceed.
  private void onSenderValue() @trusted {
    with (state.bitfield.update(0, Counter.tick)) {
      if (isLast(newState))
        state.process(newState);
      else
        notify();
    }
  }

  /// Releases the semaphore that `State.onItem` waits on.
  private void notify() @trusted {
    import std.exception : assumeWontThrow;
    state.semaphore.notify().assumeWontThrow;
  }

  mixin ForwardExtensionPoints!(receiver);
}