1 module concurrency.operations.race;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concurrency.utils : spin_yield, casWeak;
8 import concepts;
9 import std.traits;
10 
11 /// Runs both Senders and propagates the value of whoever completes first
12 /// if both error out the first exception is propagated,
13 /// uses mir.algebraic if the Sender value types differ
14 RaceSender!(Senders) race(Senders...)(Senders senders) {
15   return RaceSender!(Senders)(senders, false);
16 }
17 
/// Value type of a race over two or more Senders.
/// void Senders contribute no value type, and duplicate value types are merged:
///  - no non-void value types       -> void
///  - one value type, no void       -> that type
///  - one value type plus void      -> Nullable!(T)   (null when a void Sender wins)
///  - several value types, no void  -> Algebraic!(Ts)
///  - several value types plus void -> Nullable!(Ts)  (null when a void Sender wins)
private template Result(Senders...) if (Senders.length > 1) {
  import concurrency.utils : NoVoid;
  import mir.algebraic : Algebraic, Nullable;
  import std.meta : staticMap, Filter, NoDuplicates;
  alias getValue(Sender) = Sender.Value;
  alias SenderValues = staticMap!(getValue, Senders);
  alias NoVoidValueTypes = Filter!(NoVoid, SenderValues);
  enum HasVoid = SenderValues.length != NoVoidValueTypes.length;
  alias ValueTypes = NoDuplicates!(NoVoidValueTypes);

  static if (ValueTypes.length == 0) {
    alias Result = void;
  } else static if (ValueTypes.length == 1) {
    static if (HasVoid)
      alias Result = Nullable!(ValueTypes[0]);
    else
      alias Result = ValueTypes[0];
  } else {
    // A winning void Sender stores nothing into the result. Without a null
    // alternative the receiver would get a default-initialised Algebraic
    // (presumably its first type's .init), indistinguishable from a real value.
    static if (HasVoid)
      alias Result = Nullable!(ValueTypes);
    else
      alias Result = Algebraic!(ValueTypes);
  }
}
39 
/// Element type of the dynamic array type `T`, e.g. `ArrayElement!(int[])` is `int`.
template ArrayElement(T : P[], P) {
  alias ArrayElement = P;
}
41 
42 import std.traits;
/// Value type of a race over a single array of Senders:
/// the `Value` of the array's element Sender type.
private template Result(Senders...) if (Senders.length == 1) {
  alias Result = ArrayElement!(Senders[0]).Value;
}
46 
/// Operation state of a race. It owns the downstream receiver, the `State`
/// shared by all contenders, and one connected operation per contender.
private struct RaceOp(Receiver, Senders...) {
  import std.meta : staticMap;
  alias R = Result!(Senders);
  static if (Senders.length > 1) {
    // Heterogeneous senders: one receiver/operation type per contender,
    // each converting its own value type into the race's result type `R`.
    alias ElementReceiver(Sender) = RaceReceiver!(Receiver, Sender.Value, R);
    alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
    alias Ops = staticMap!(ConnectResult, Senders);
  } else {
    // A single array of senders: one operation type, stored in a dynamic array.
    alias ElementReceiver = RaceReceiver!(Receiver, R, R);
    alias Ops = OpType!(ArrayElement!(Senders[0]), ElementReceiver)[];
  }
  Receiver receiver;
  State!R state;
  Ops ops;
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);

  /// Connects every contender to a `RaceReceiver` that shares one `State`.
  /// Params:
  ///   receiver = downstream receiver that gets the race outcome
  ///   senders = the contenders (several Senders, or one array of Senders)
  ///   noDropouts = when true, the first done/error stops the other contenders
  this(Receiver receiver, return Senders senders, bool noDropouts) @trusted scope {
    this.receiver = receiver;
    state = new State!(R)(noDropouts);
    static if (Senders.length > 1) {
      foreach(i, Sender; Senders) {
        ops[i] = senders[i].connect(ElementReceiver!(Sender)(receiver, state, Senders.length));
      }
    } else {
      ops.length = senders[0].length;
      foreach(i; 0..senders[0].length) {
        ops[i] = senders[0][i].connect(ElementReceiver(receiver, state, senders[0].length));
      }
    }
  }

  /// Starts every contender.
  /// The downstream receiver is completed with done right away, without
  /// starting anything, if its stop token is already triggered or if there
  /// are no contenders at all.
  void start() @trusted nothrow scope {
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    static if (Senders.length == 1) {
      // An empty array has no contender that could ever complete the race,
      // so the receiver would otherwise never be called. Done matches what
      // the race reports when no value and no error were produced.
      if (ops.length == 0) {
        receiver.setDone();
        return;
      }
    }
    // Forward downstream stop requests to `state`, the stop source the contenders observe.
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop); // butt ugly cast, but it won't take the second overload
    static if (Senders.length > 1) {
      foreach(i, _; Senders) {
        ops[i].start();
      }
    } else {
      foreach(i; 0..ops.length) {
        ops[i].start();
      }
    }
  }
}
95 
96 import std.meta : allSatisfy, ApplyRight;
97 
/// Sender returned by `race`.
/// It accepts either two or more Senders, or exactly one array of Senders.
struct RaceSender(Senders...)
     if ((Senders.length > 1 && allSatisfy!(ApplyRight!(models, isSender), Senders)) ||
         (Senders.length == 1 && models!(ArrayElement!(Senders[0]), isSender))) {
  // The array form is only valid on its own: `race([a, b], c)` must be
  // rejected here rather than fail later inside RaceOp.
  static assert(models!(typeof(this), isSender));
  alias Value = Result!(Senders);
  Senders senders;
  bool noDropouts; // if true then we fail the moment one contender does, otherwise we keep running until one finishes
  /// Connects all contenders to `receiver`; see `RaceOp`.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = RaceOp!(Receiver, Senders)(receiver, senders, noDropouts);
    return op;
  }
}
111 
/// State shared by all contenders of one race. It is the `StopSource`
/// behind the contenders' stop tokens and collects the race's outcome.
private class State(Value) : StopSource {
  import concurrency.bitfield : SharedBitField;

  /// Registration on the downstream receiver's stop token.
  StopCallback cb;
  /// `Flags` bits together with the completion counter (`Counter`).
  shared SharedBitField!Flags bitfield;
  static if (!is(Value == void)) {
    /// Value of the winning contender.
    Value value;
  }
  /// Exception recorded from an erroring contender, if any.
  Throwable exception;
  /// When true, the first done/error stops the other contenders.
  bool noDropouts;

  /// Params: noDropouts = whether the first done/error should stop the other contenders
  this(bool noDropouts) {
    this.noDropouts = noDropouts;
  }
}
124 
/// Flag bits kept in the low three bits of a race's bitfield.
/// The bits above them hold the completion counter.
private enum Flags : size_t {
  locked = 1 << 0,
  value_produced = 1 << 1,
  doneOrError_produced = 1 << 2
}
130 
/// Completion counter stored above the three flag bits of the bitfield.
/// Each finished contender adds one `tick`; `mask` covers all bits above the flags.
private enum Counter : size_t {
  tick = 1 << 3,
  mask = size_t.max ^ 0x7
}
135 
/// Receiver connected to each contender of a race.
/// All contenders share one `State`. Every completion sets a flag in
/// `state.bitfield` and adds one `Counter.tick`. The completion that brings
/// the counter up to `senderCount` calls `process` to deliver the outcome
/// downstream.
/// `InnerValue` is this contender's value type; `Value` is the race's result type.
private struct RaceReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  Receiver receiver; // downstream receiver that gets the race outcome
  State!(Value) state; // shared by every contender of this race
  size_t senderCount; // number of contenders; isLast compares the counter against it
  /// Contenders see the shared `State` (a `StopSource`) as their stop token,
  /// so `state.stop()` signals a stop request to all of them at once.
  auto getStopToken() {
    return StopToken(state);
  }
  /// Whether some contender has already completed with a value.
  private bool isValueProduced(size_t state) {
    return (state & Flags.value_produced) > 0;
  }
  /// Whether some contender has already completed with done or an error.
  private bool isDoneOrErrorProduced(size_t state) {
    return (state & Flags.doneOrError_produced) > 0;
  }
  /// Whether the counter (the bits above the three flag bits) equals
  /// `senderCount`, i.e. this is the final contender to complete.
  private bool isLast(size_t state) {
    return (state >> 3) == atomicLoad(senderCount);
  }
  static if (!is(InnerValue == void))
    /// The first value wins. It is stored in `state.value` (converted to
    /// `Value` when the types differ) while the bitfield lock is held, and
    /// the other contenders are then asked to stop. Later values are dropped.
    void setValue(InnerValue value) @safe nothrow {
      // oldState/newState are the bitfield before/after this completion
      // (which sets Flags.value_produced and adds Counter.tick).
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        if (!isValueProduced(oldState)) {
          static if (is(InnerValue == Value))
            state.value = value;
          else
            state.value = Value(value);
          // NOTE(review): presumably stop can synchronously complete other
          // contenders, whose completion needs this same lock.
          release(); // must release before calling .stop
          state.stop();
        } else
          release();

        if (last)
          process(newState);
      }
    }
  else
    /// Void value: nothing is written to `state`, and `update` is used
    /// instead of `lock`. The first value still stops the other contenders.
    void setValue() @safe nothrow {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        if (!isValueProduced(oldState)) {
          state.stop();
        }
        if (last)
          process(newState);
      }
    }
  /// Records a done completion. With `noDropouts`, the first done/error
  /// stops the other contenders.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (state.noDropouts && !isDoneOrErrorProduced(oldState)) {
        state.stop();
      }
      if (last)
        process(newState);
    }
  }
  /// Records an error. Only the first done-or-error completion stores its
  /// exception, so an error that arrives after a done is not kept.
  /// With `noDropouts`, that first done/error also stops the other contenders.
  void setError(Throwable exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      if (!isDoneOrErrorProduced(oldState)) {
        state.exception = exception;
        if (state.noDropouts) {
          release(); // release before stop
          state.stop();
        }
      }
      // NOTE(review): after the noDropouts branch this is a second release().
      // Like setValue (which releases explicitly inside its `with`), this
      // relies on release() being safe to call more than once; confirm in
      // concurrency.bitfield.
      release();
      if (last)
        process(newState);
    }
  }
  /// Delivers the race outcome downstream. It is called by the completion
  /// that brings the counter up to `senderCount`.
  /// It first disposes the downstream stop callback, then completes
  /// `receiver` with the first of these that applies:
  ///  1. done, if downstream stop was requested;
  ///  2. the value, if any contender produced one;
  ///  3. the recorded exception;
  ///  4. otherwise done.
  private void process(size_t newState) {
    import concurrency.receiver : setValueOrError;

    state.cb.dispose();
    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isValueProduced(newState)) {
      static if (is(Value == void))
        receiver.setValueOrError();
      else
        receiver.setValueOrError(state.value);
    } else if (state.exception)
      receiver.setError(state.exception);
    else
      receiver.setDone();
  }
  // presumably forwards the remaining receiver extension points to `receiver`
  mixin ForwardExtensionPoints!receiver;
}