1 module concurrency.operations.race;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concurrency.utils : spin_yield, casWeak;
8 import concepts;
9 import std.traits;
10 
/// Races any number of Senders against each other and propagates the value
/// of whichever one produces a value first.
/// If no Sender produces a value and at least one of them fails, the first
/// reported exception is propagated instead.
/// When the Senders' value types differ, the result type is built with
/// mir.algebraic.
RaceSender!(Senders) race(Senders...)(Senders senders) {
  // second field is `noDropouts`; `false` means a failing contender does not
  // end the race for the others
  return typeof(return)(senders, false);
}
17 
/// Computes the value type produced by racing `Senders`:
///  - `void` when no Sender yields a non-void value,
///  - the single distinct non-void value type when there is exactly one
///    (wrapped in `Nullable` if some other Sender yields `void`),
///  - an `Algebraic` over all distinct non-void value types otherwise.
private template Result(Senders...) {
  import concurrency.utils : NoVoid;
  import mir.algebraic : Algebraic, Nullable;
  import std.meta : staticMap, Filter, NoDuplicates;

  alias valueOf(Sender) = Sender.Value;
  alias AllValues = staticMap!(valueOf, Senders);
  // presumably NoVoid keeps only the non-void types; confirm in concurrency.utils
  alias NonVoidValues = Filter!(NoVoid, AllValues);
  alias DistinctValues = NoDuplicates!(NonVoidValues);
  // the filter dropped something, so at least one Sender has a void Value
  enum anyVoid = AllValues.length != NonVoidValues.length;

  static if (DistinctValues.length == 0) {
    alias Result = void;
  } else static if (DistinctValues.length > 1) {
    alias Result = Algebraic!(DistinctValues);
  } else static if (anyVoid) {
    // exactly one distinct value type, but a void Sender may win the race
    alias Result = Nullable!(DistinctValues[0]);
  } else {
    alias Result = DistinctValues[0];
  }
}
39 
/// Operation state of a race: owns the shared `State`, and one connected
/// child operation per contending Sender.
private struct RaceOp(Receiver, Senders...) {
  import std.meta : staticMap;

  alias R = Result!(Senders);
  alias ElementReceiver(Sender) = RaceReceiver!(Receiver, Sender.Value, R);
  alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
  alias Ops = staticMap!(ConnectResult, Senders);

  Receiver receiver;
  State!R state;
  Ops ops;

  // the operation state must not be copied
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);

  /// Connects every Sender to its own `RaceReceiver`; all of them share the
  /// same `state` and are told how many contenders take part.
  this(Receiver receiver, return Senders senders, bool noDropouts) @trusted scope {
    this.receiver = receiver;
    this.state = new State!(R)(noDropouts);
    static foreach (idx; 0 .. Senders.length) {
      ops[idx] = senders[idx].connect(ElementReceiver!(Senders[idx])(receiver, state, Senders.length));
    }
  }

  /// Starts all child operations, unless the downstream receiver has already
  /// been asked to stop, in which case it is completed with `setDone`.
  void start() @trusted nothrow scope {
    import concurrency.stoptoken : StopSource;

    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }

    // the cast is ugly, but without it `onStop` won't take the second overload
    auto stopHandler = cast(void delegate() nothrow @safe shared) &state.stop;
    // forward a stop request from downstream to every contender via `state`
    state.cb = receiver.getStopToken().onStop(stopHandler);

    static foreach (idx; 0 .. Senders.length) {
      ops[idx].start();
    }
  }
}
70 
71 import std.meta : allSatisfy, ApplyRight;
72 
/// Sender that races `Senders` against each other; its `Value` is computed by
/// `Result!(Senders)`.
struct RaceSender(Senders...) if (allSatisfy!(ApplyRight!(models, isSender), Senders)) {
  static assert(models!(typeof(this), isSender));

  alias Value = Result!(Senders);

  /// the contenders
  Senders senders;
  /// when true the race fails the moment one contender does; when false it
  /// keeps running until one of them finishes
  bool noDropouts;

  /// Connects `receiver` and returns the (non-copyable) race operation.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    alias Op = RaceOp!(Receiver, Senders);
    // keep the named local and return it directly to ensure NRVO
    auto op = Op(receiver, senders, noDropouts);
    return op;
  }
}
84 
/// State shared by all receivers of one race. It is itself the `StopSource`
/// handed to the contenders, so calling `stop` on it signals all of them.
private class State(Value) : StopSource {
  import concurrency.bitfield;

  /// registration of the downstream stop callback (set when the race starts)
  StopCallback cb;

  static if (!is(Value == void)) {
    /// value stored by the first contender that produced one
    Value value;
  }

  /// first exception reported by a contender, if any
  Exception exception;

  /// flag bits (`Flags`) plus a completion counter (`Counter`)
  shared SharedBitField!Flags bitfield;

  /// when true, the first done/error of any contender stops the others
  bool noDropouts;

  this(bool noDropouts) {
    this.noDropouts = noDropouts;
  }
}
97 
/// Flag bits kept in the low three bits of the shared bitfield.
private enum Flags : size_t {
  locked = 1 << 0,
  value_produced = 1 << 1,
  doneOrError_produced = 1 << 2
}
103 
/// Layout of the counter stored in the bitfield above the three flag bits.
private enum Counter : size_t {
  /// increment that advances the counter by one
  tick = 1 << 3,
  /// selects every bit above the three flag bits
  mask = ~cast(size_t) 0x7
}
108 
/// Receiver connected to one contender of a race.
///
/// Every completion signal (value, done or error) sets a flag in the shared
/// bitfield and advances its counter by one `Counter.tick`. The receiver that
/// observes the counter reaching `senderCount` forwards the final outcome to
/// the downstream `receiver`.
private struct RaceReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  Receiver receiver;
  State!(Value) state;
  size_t senderCount;

  /// Contenders observe the race's own stop source, not the downstream one.
  auto getStopToken() {
    return StopToken(state);
  }

  private bool isValueProduced(size_t bits) {
    return (bits & Flags.value_produced) != 0;
  }

  private bool isDoneOrErrorProduced(size_t bits) {
    return (bits & Flags.doneOrError_produced) != 0;
  }

  /// True when the counter stored above the flag bits equals `senderCount`.
  private bool isLast(size_t bits) {
    // dividing by the tick discards the flag bits and yields the count
    return (bits / Counter.tick) == senderCount;
  }

  static if (!is(InnerValue == void)) {
    /// Stores the value if this is the first one produced, stops the other
    /// contenders, and completes downstream if this was the last signal.
    void setValue(InnerValue value) @safe nothrow {
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        if (!isValueProduced(oldState)) {
          static if (is(InnerValue == Value))
            state.value = value;
          else
            state.value = Value(value);
          release(); // must release before calling .stop
          state.stop();
        } else
          release();

        process(newState);
      }
    }
  } else {
    /// Void flavour: nothing to store, so the bitfield is updated without
    /// taking the lock.
    void setValue() @safe nothrow {
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        if (!isValueProduced(oldState)) {
          state.stop();
        }
        process(newState);
      }
    }
  }

  /// Records that this contender was cancelled; with `noDropouts` the first
  /// done/error stops the remaining contenders.
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      if (state.noDropouts && !isDoneOrErrorProduced(oldState)) {
        state.stop();
      }
      process(newState);
    }
  }

  /// Records the exception if it is the first done/error signal; with
  /// `noDropouts` that first failure also stops the remaining contenders.
  void setError(Exception exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      // Each path releases the lock exactly once (same shape as setValue),
      // and always before calling .stop.
      if (!isDoneOrErrorProduced(oldState)) {
        state.exception = exception;
        release(); // release before stop
        if (state.noDropouts)
          state.stop();
      } else
        release();
      process(newState);
    }
  }

  /// Forwards the final outcome downstream once the last contender has
  /// reported in; does nothing for every earlier signal.
  private void process(size_t newState) {
    import concurrency.receiver : setValueOrError;

    if (!isLast(newState))
      return;

    // the downstream stop callback is no longer needed
    state.cb.dispose();
    // a downstream stop request takes precedence over any produced value
    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isValueProduced(newState)) {
      static if (is(Value == void))
        receiver.setValueOrError();
      else
        receiver.setValueOrError(state.value);
    } else if (state.exception)
      receiver.setError(state.exception);
    else
      receiver.setDone();
  }

  mixin ForwardExtensionPoints!receiver;
}