1 module concurrency.operations.race;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concurrency.utils : spin_yield, casWeak;
8 import concepts;
9 import std.traits;
10 
/// Runs all `senders` concurrently and completes with the value of whichever
/// contender produces a value first; as soon as one does, the remaining
/// contenders are asked to stop through the race's shared stop token.
///
/// The operation only completes downstream once every contender has finished.
/// If no contender produced a value, the error of the first contender that
/// failed is propagated, unless a contender had already completed with done
/// before that error was recorded, in which case the race completes with done.
/// If downstream stop was requested, the race completes with done.
///
/// The value type is:
///   - `void` when no sender produces a value,
///   - the shared type `T` when every sender produces `T`,
///   - `mir.algebraic.Nullable!T` when `T` is the only value type but some senders are void,
///   - `mir.algebraic.Algebraic` over the distinct types when the value types differ.
///
/// At least one sender is required: with none, nothing would ever complete the
/// operation, not even cancellation.
RaceSender!(Senders) race(Senders...)(Senders senders) {
  static assert(Senders.length > 0, "race requires at least one Sender");
  return RaceSender!(Senders)(senders, false);
}
17 
/// Computes the value type of a race over `Senders`:
///   - `void` when no sender produces a value,
///   - `T` when every sender produces the same `T`,
///   - `Nullable!T` when `T` is the only value type but some senders are void,
///   - `Algebraic!(Ts)` over the distinct non-void value types otherwise.
private template Result(Senders...) {
  import concurrency.utils : NoVoid;
  import mir.algebraic : Algebraic, Nullable;
  import std.meta : staticMap, Filter, NoDuplicates;

  alias valueOf(S) = S.Value;
  alias AllValues = staticMap!(valueOf, Senders);
  alias NonVoidValues = Filter!(NoVoid, AllValues);
  alias Distinct = NoDuplicates!(NonVoidValues);
  // Filtering can only shrink the list, so "shorter" means "something was void".
  enum anyVoid = NonVoidValues.length < AllValues.length;

  static if (Distinct.length == 0) {
    alias Result = void;
  } else static if (Distinct.length > 1) {
    // NOTE(review): void senders are not represented in this type, so a void
    // winner leaves the value default-initialized — confirm this is intended.
    alias Result = Algebraic!(Distinct);
  } else static if (anyVoid) {
    alias Result = Nullable!(Distinct[0]);
  } else {
    alias Result = Distinct[0];
  }
}
39 
/// Operation state built by `RaceSender.connect`: holds the downstream
/// receiver, the `State` shared by all contenders, and one connected
/// operation per sender.
private struct RaceOp(Receiver, Senders...) {
  import std.meta : staticMap;
  alias R = Result!(Senders);
  // Each sender is connected to its own RaceReceiver; all of them report into `state`.
  alias ElementReceiver(Sender) = RaceReceiver!(Receiver, Sender.Value, R);
  alias ConnectResult(Sender) = OpType!(Sender, ElementReceiver!Sender);
  alias Ops = staticMap!(ConnectResult, Senders);
  Receiver receiver;
  State!R state;
  Ops ops;
  // Copying is disabled (presumably so the connected ops are never duplicated
  // or relocated after construction — TODO confirm).
  @disable this(this);
  @disable this(ref return scope typeof(this) rhs);
  /// Allocates the shared State and connects every sender, in order.
  /// Params:
  ///   receiver = downstream receiver; each element receiver also gets a copy
  ///   senders = the contenders
  ///   noDropouts = stored in State; when true the first done/error stops the other contenders
  this(Receiver receiver, return Senders senders, bool noDropouts) @trusted scope {
    this.receiver = receiver;
    state = new State!(R)(noDropouts);
    foreach(i, Sender; Senders) {
      ops[i] = senders[i].connect(ElementReceiver!(Sender)(receiver, state, Senders.length));
    }
  }
  /// Starts the race. If downstream stop was already requested, completes with
  /// done without starting any contender. Otherwise forwards downstream stop
  /// requests to `state.stop` and starts every contender in order.
  void start() @trusted nothrow scope {
    import concurrency.stoptoken : StopSource;
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }
    state.cb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&state.stop); // butt ugly cast, but it won't take the second overload
    foreach(i, _; Senders) {
      ops[i].start();
    }
  }
}
70 
71 import std.meta : allSatisfy, ApplyRight;
72 
/// Sender created by `race`; connecting it yields a `RaceOp`.
struct RaceSender(Senders...)
    if (allSatisfy!(ApplyRight!(models, isSender), Senders))
{
  static assert(models!(typeof(this), isSender));

  /// Value type delivered by the race (see `Result`).
  alias Value = Result!(Senders);

  /// The contenders, started in this order.
  Senders senders;

  /// When true, the first contender to complete with done or error asks the
  /// others to stop. When false, only a produced value stops the others.
  bool noDropouts;

  /// Builds the operation state. The named local is returned directly so the
  /// non-copyable RaceOp can be constructed in place (NRVO).
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    auto raceOp = RaceOp!(Receiver, Senders)(receiver, senders, noDropouts);
    return raceOp;
  }
}
84 
/// Bookkeeping shared by every contender of one race. It is also the
/// StopSource behind the tokens handed to the contenders, so calling `stop`
/// on it asks all of them to stop.
private class State(Value) : StopSource {
  import concurrency.bitfield;

  /// Registration of the downstream stop callback; disposed on completion.
  StopCallback cb;
  /// Lock bit, produced-flags and completion counter, updated atomically.
  shared SharedBitField!Flags bitfield;
  static if (!is(Value == void)) {
    /// Value of the first contender that produced one.
    Value value;
  }
  /// Error of the first contender that failed (if recorded).
  Exception exception;
  /// Whether the first done/error stops the remaining contenders.
  bool noDropouts;

  this(bool noDropouts) { this.noDropouts = noDropouts; }
}
97 
/// Flag bits kept in the low three bits of `State.bitfield`.
private enum Flags : size_t {
  locked = 1 << 0,              // held while a receiver writes value/exception into State
  value_produced = 1 << 1,      // some contender completed with a value
  doneOrError_produced = 1 << 2 // some contender completed with done or error
}
103 
/// Completion counter stored above the three flag bits of `State.bitfield`;
/// every contender completion adds one `tick`.
private enum Counter : size_t {
  tick = size_t(1) << 3,
  mask = ~cast(size_t) 0b111
}
108 
/// Receiver handed to each contender of a race. Every completion (value, done
/// or error) adds one `Counter.tick` to `state.bitfield`. The completion that
/// brings the counter up to `senderCount` calls `process`, which completes the
/// downstream `receiver`.
private struct RaceReceiver(Receiver, InnerValue, Value) {
  import core.atomic : atomicOp, atomicLoad, MemoryOrder;
  Receiver receiver;
  State!(Value) state;
  size_t senderCount;
  /// Contenders observe the race's own stop source, so `state.stop()` reaches all of them.
  auto getStopToken() {
    return StopToken(state);
  }
  private bool isValueProduced(size_t state) {
    return (state & Flags.value_produced) > 0;
  }
  private bool isDoneOrErrorProduced(size_t state) {
    return (state & Flags.doneOrError_produced) > 0;
  }
  /// True once the counter (the bits above the three flags) equals `senderCount`.
  private bool isLast(size_t state) {
    return (state >> 3) == atomicLoad(senderCount);
  }
  static if (!is(InnerValue == void))
    void setValue(InnerValue value) @safe nothrow {
      // Only the first value is stored; it then asks the other contenders to stop.
      with (state.bitfield.lock(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        if (!isValueProduced(oldState)) {
          static if (is(InnerValue == Value))
            state.value = value;
          else
            state.value = Value(value);
          release(); // must release before calling .stop
          state.stop();
        } else
          release();

        if (last)
          process(newState);
      }
    }
  else
    void setValue() @safe nothrow {
      // Nothing to store, so no lock is taken; the first value still stops the others.
      with (state.bitfield.update(Flags.value_produced, Counter.tick)) {
        bool last = isLast(newState);
        if (!isValueProduced(oldState)) {
          state.stop();
        }
        if (last)
          process(newState);
      }
    }
  void setDone() @safe nothrow {
    with (state.bitfield.update(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      // With noDropouts, the first done/error ends the race for everybody.
      if (state.noDropouts && !isDoneOrErrorProduced(oldState)) {
        state.stop();
      }
      if (last)
        process(newState);
    }
  }
  void setError(Exception exception) @safe nothrow {
    with (state.bitfield.lock(Flags.doneOrError_produced, Counter.tick)) {
      bool last = isLast(newState);
      // Only the first done/error is recorded.
      // NOTE(review): an error arriving after a done is dropped, so the race can
      // complete with done even though a contender failed — confirm intended.
      bool firstFailure = !isDoneOrErrorProduced(oldState);
      if (firstFailure)
        state.exception = exception;
      bool stopOthers = firstFailure && state.noDropouts;
      // Release the lock exactly once, before stopping the other contenders and
      // before completing downstream (the other completion paths do the same).
      release();
      if (stopOthers)
        state.stop();
      if (last)
        process(newState);
    }
  }
  /// Completes the downstream receiver once every contender has finished.
  /// Priority: downstream stop requested -> done; a value was produced -> value;
  /// an error was recorded -> error; otherwise -> done.
  private void process(size_t newState) {
    import concurrency.receiver : setValueOrError;

    state.cb.dispose();
    if (receiver.getStopToken().isStopRequested)
      receiver.setDone();
    else if (isValueProduced(newState)) {
      static if (is(Value == void))
        receiver.setValueOrError();
      else
        receiver.setValueOrError(state.value);
    } else if (state.exception)
      receiver.setError(state.exception);
    else
      receiver.setDone();
  }
  mixin ForwardExtensionPoints!receiver;
}