1 module concurrency.operations.stopwhen;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concurrency.utils : spin_yield, casWeak;
8 import concepts;
9 import std.traits;
10 
/// Builds a sender that starts both `source` and `trigger` and cancels
/// `source` once `trigger` completes normally. If either the source or the
/// trigger completes with cancellation or with an error, the first such
/// outcome is propagated after both have completed.
StopWhenSender!(Sender, Trigger) stopWhen(Sender, Trigger)(Sender source, Trigger trigger) {
  return typeof(return)(source, trigger);
}
15 
/// Operation state produced by `StopWhenSender.connect`.
///
/// Holds the downstream receiver, the `State` shared by both child receivers,
/// and the connected operation states of the source and of the trigger.
private struct StopWhenOp(Receiver, Sender, Trigger) {
  alias SenderOp = OpType!(Sender, SourceReceiver!(Receiver, Sender.Value));
  alias TriggerOp = OpType!(Trigger, TriggerReceiver!(Receiver, Sender.Value));
  Receiver receiver;
  State!(Sender.Value) state;
  SenderOp sourceOp;
  TriggerOp triggerOp;
  // Non-copyable: both postblit and copy construction are disabled.
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);
  /// Allocates a fresh shared `State` and connects `source` and `trigger` to
  /// receivers that both forward to `receiver` and share that state.
  this(Receiver receiver, return Sender source, return Trigger trigger) @trusted scope {
    this.receiver = receiver;
    this.state = new typeof(this.state)();
    this.sourceOp = source.connect(SourceReceiver!(Receiver, Sender.Value)(receiver, this.state));
    this.triggerOp = trigger.connect(TriggerReceiver!(Receiver, Sender.Value)(receiver, this.state));
  }
  /// Starts the source and then the trigger. If downstream cancellation was
  /// already requested, nothing is started and the receiver gets `setDone`.
  void start() @trusted nothrow scope {
    if (!receiver.getStopToken().isStopRequested) {
      // Route downstream cancellation into the shared state; both child
      // receivers hand out stop tokens backed by it. The explicit cast is
      // required because `onStop` will not select the shared-delegate overload
      // on its own.
      state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop);
      sourceOp.start();
      triggerOp.start();
    } else {
      receiver.setDone();
    }
  }
}
41 
/// Sender returned by `stopWhen`; its value type is that of the source sender.
struct StopWhenSender(Sender, Trigger) if (models!(Sender, isSender) && models!(Trigger, isSender)) {
  static assert(models!(typeof(this), isSender));
  alias Value = Sender.Value;
  Sender sender;
  Trigger trigger;
  /// Connects `sender` and `trigger` to `receiver`, yielding a `StopWhenOp`.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    alias Operation = StopWhenOp!(Receiver, Sender, Trigger);
    // Build into a named local and return it so the non-copyable operation
    // state is returned via NRVO instead of being copied.
    auto operation = Operation(receiver, sender, trigger);
    return operation;
  }
}
53 
/// State shared by the source receiver and the trigger receiver of one
/// `stopWhen` operation. Being a `StopSource`, stopping it cancels every
/// child that obtained its stop token from it.
private class State(Value) : StopSource {
  import concurrency.bitfield;

  /// Registration of `stop` on the downstream receiver's stop token (assigned in `StopWhenOp.start`).
  StopCallback cb;
  /// `Flags` bits in the low bits plus a completion counter in units of `Counter.tick`.
  shared SharedBitField!Flags bitfield;
  static if (!is(Value == void)) {
    /// Value delivered by the source, if any.
    Value value;
  }
  /// Error recorded by whichever child first reported done/error with an exception.
  Throwable exception;
}
62 
/// Bit layout of `State.bitfield`: three flag bits, with bit 3 marking the
/// start of the completion counter.
private enum Flags : size_t {
  locked = 1 << 0,               /// lock bit; presumably managed by `SharedBitField.lock`/`release`
  value_produced = 1 << 1,       /// the source delivered a value
  doneOrError_produced = 1 << 2, /// a child completed with done or error
  tick = 1 << 3,                 /// lowest bit of the completion counter
}
69 
/// Unit of the completion counter that lives in bits 3 and up of `State.bitfield`.
private enum Counter : size_t {
  tick = 1 << 3,
}
73 
/// Delivers the final signal of a `stopWhen` operation to `receiver`.
///
/// Disposes the downstream stop callback first. The outcome is chosen in this
/// order: `setDone` if downstream cancellation was requested; otherwise the
/// source's value if `newState` has `value_produced`; otherwise the recorded
/// exception; otherwise `setDone`.
private void process(State, Receiver)(State state, Receiver receiver, size_t newState) {
  import concurrency.receiver : setValueOrError;

  state.cb.dispose();

  if (receiver.getStopToken().isStopRequested) {
    receiver.setDone();
    return;
  }

  if (isValueProduced(newState)) {
    // `State.value` only exists for non-void value types.
    static if (__traits(compiles, state.value))
      receiver.setValueOrError(state.value);
    else
      receiver.setValueOrError();
    return;
  }

  if (state.exception !is null) {
    receiver.setError(state.exception);
    return;
  }

  receiver.setDone();
}
90 
/// True when the `value_produced` bit is set in the bitfield snapshot `state`.
private bool isValueProduced(size_t state) @safe nothrow pure {
  return (state & Flags.value_produced) != 0;
}
/// True when the `doneOrError_produced` bit is set in the bitfield snapshot `state`.
private bool isDoneOrErrorProduced(size_t state) @safe nothrow pure {
  return (state & Flags.doneOrError_produced) != 0;
}
/// True when the bitfield snapshot `state` records that both the source and
/// the trigger have completed.
///
/// Every completion adds `Counter.tick` (bit 3) to the bitfield, so the number
/// of completions is `state / Counter.tick`. The flag bits below bit 3 sum to
/// at most 7 and never affect that quotient. The operation is finished after
/// exactly two completions: the source and the trigger.
private bool isLast(size_t state) @safe nothrow pure {
  return (state / Counter.tick) == 2;
}

/// Receiver connected to the trigger sender.
///
/// Each completion adds one `Counter.tick` to the shared counter. If the
/// source is still running, the shared `State` is stopped, which cancels the
/// source. Otherwise this is the last completion and the final signal is sent
/// downstream through `process`.
private struct TriggerReceiver(Receiver, Value) {
  Receiver receiver;
  State!(Value) state;
  /// The trigger observes cancellation through the shared state.
  auto getStopToken() {
    return StopToken(state);
  }
  /// The trigger fired normally. Counts the completion without setting any
  /// result flag, then cancels the source or finishes the operation.
  void setValue() @safe nothrow {
    with (state.bitfield.update(0, Counter.tick)) {
      if (!isLast(newState))
        state.stop();
      else
        state.process(receiver, newState);
    }
  }
  /// The trigger was cancelled. Records done/error and counts the completion.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      if (!isLast(newState))
        state.stop();
      else
        state.process(receiver, newState);
    }
  }
  /// The trigger failed. The exception is kept only if no done/error was
  /// recorded earlier, so the first one wins.
  void setError(Throwable exception) @safe nothrow {
    // Locked so the exception is stored before the other side can act on the
    // updated counter (assumes `update` waits while the lock bit is held).
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      const last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState))
        state.exception = exception;
      // Release before stop/process: stopping may synchronously complete the
      // source, which must then be able to update the bitfield.
      release();
      if (!last)
        state.stop();
      else
        state.process(receiver, newState);
    }
  }
  mixin ForwardExtensionPoints!receiver;
}

/// Receiver connected to the source sender.
///
/// Each completion adds one `Counter.tick` to the shared counter. If the
/// trigger is still running, the shared `State` is stopped, which cancels the
/// trigger. Otherwise this is the last completion and the final signal is sent
/// downstream through `process`.
private struct SourceReceiver(Receiver, Value) {
  Receiver receiver;
  State!(Value) state;
  /// The source observes cancellation through the shared state. The trigger
  /// or downstream cancellation can stop that state.
  auto getStopToken() {
    return StopToken(state);
  }
  static if (!is(Value == void)) {
    /// Stores the value and counts the completion.
    void setValue(Value value) @safe nothrow {
      // Hold the lock while the value is written, so a trigger that completes
      // concurrently cannot reach `process` and read `state.value` early.
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        const last = isLast(newState);
        state.value = value;
        release(); // publish the value before stopping or finishing
        if (!last)
          state.stop();
        else if (isDoneOrErrorProduced(oldState))
          // A done/error was recorded before this value; propagate that instead.
          state.process(receiver, oldState);
        else
          state.process(receiver, newState);
      }
    }
  } else {
    /// Records a void value and counts the completion.
    void setValue() @safe nothrow {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        if (!isLast(newState))
          state.stop();
        else if (isDoneOrErrorProduced(oldState))
          // A done/error was recorded before this value; propagate that instead.
          state.process(receiver, oldState);
        else
          state.process(receiver, newState);
      }
    }
  }
  /// The source was cancelled. Records done/error and counts the completion.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      if (!isLast(newState))
        state.stop();
      else
        state.process(receiver, newState);
    }
  }
  /// The source failed. The exception is kept only if no done/error was
  /// recorded earlier, so the first one wins.
  void setError(Throwable exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      const last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState))
        state.exception = exception;
      // Release before stop/process: stopping may synchronously complete the
      // trigger, which must then be able to update the bitfield.
      release();
      if (!last)
        state.stop();
      else
        state.process(receiver, newState);
    }
  }
  mixin ForwardExtensionPoints!receiver;
}