1 module concurrency;
2 
3 import concurrency.stoptoken;
4 import concurrency.sender;
5 import concurrency.thread;
6 import concepts;
7 
/// Reports whether the thread calling this function is the program's main thread.
///
/// Returns: the `isMainThread` flag of the `core.thread.Thread` object for the
/// calling thread.
bool isMainThread() @trusted {
  import core.thread : Thread;
  auto current = Thread.getThis();
  return current.isMainThread();
}
12 
/// Receiver used by `syncWaitImpl`.
///
/// Every completion callback records its outcome in the caller-owned `State`
/// and then calls `stop()` on the state's worker.
package struct SyncWaitReceiver2(Value) {
  /// Storage shared between the receiver and the code that waits on it.
  static struct State {
    /// Worker built from the local thread executor; stopped on completion.
    LocalThreadWorker worker;
    /// Set to true by `setDone`.
    bool canceled;
    static if (!is(Value == void)) {
      /// Value delivered through `setValue` (absent for void senders).
      Value result;
    }
    /// Throwable delivered through `setError`, null otherwise.
    Throwable throwable;
    /// Stop source handed out (wrapped in a StopToken) by `getStopToken`.
    StopSource stopSource;

    this(StopSource stopSource) {
      this.worker = LocalThreadWorker(getLocalThreadExecutor());
      this.stopSource = stopSource;
    }
  }

  /// Points at the State owned by the waiting function.
  State* state;

  /// Completion signal for cancellation: flags the state as canceled.
  void setDone() nothrow @safe {
    state.canceled = true;
    state.worker.stop();
  }

  /// Completion signal for failure: stores the throwable for the waiter.
  void setError(Throwable e) nothrow @safe {
    state.throwable = e;
    state.worker.stop();
  }

  static if (!is(Value == void)) {
    /// Completion signal for success: stores the produced value.
    void setValue(Value value) nothrow @safe {
      state.result = value;
      state.worker.stop();
    }
  } else {
    /// Completion signal for success of a void sender; nothing to store.
    void setValue() nothrow @safe {
      state.worker.stop();
    }
  }

  /// Returns a StopToken built from the state's stop source.
  auto getStopToken() nothrow @safe @nogc {
    return StopToken(state.stopSource);
  }

  /// Returns a scheduler adapter that wraps a pointer to the state's worker.
  auto getScheduler() nothrow @safe {
    import concurrency.scheduler : SchedulerAdapter;
    alias Adapter = SchedulerAdapter!(LocalThreadWorker*);
    return Adapter(&state.worker);
  }
}
54 
/// Deprecated front-end to `syncWait` that takes an explicit stop source.
///
/// For void senders it returns `false` on cancellation and `true` on normal
/// completion. For valued senders it returns the value and throws a new
/// `Exception("Cancelled")` on cancellation. An Exception reported by the
/// sender is rethrown in both cases.
deprecated("Use syncWait instead")
auto sync_wait(Sender, StopSource)(auto ref Sender sender, StopSource stopSource) {
  alias Value = Sender.Value;
  // Strip qualifiers (e.g. shared) from the stop source before forwarding it.
  auto plainSource = (() @trusted => cast() stopSource)();
  auto result = syncWait(sender, plainSource);
  static if (!is(Value == void)) {
    return result.match!(
      (Cancelled _) { throw new Exception("Cancelled"); },
      (Exception ex) { throw ex; },
      (ref t) => t
    );
  } else {
    return result.match!(
      (Cancelled _) => false,
      (Exception ex) { throw ex; },
      (typeof(null)) => true
    );
  }
}
69 
/// Deprecated front-end to the single-argument `syncWait`.
///
/// For void senders it returns `false` on cancellation and `true` on normal
/// completion. For valued senders it returns the value and throws a new
/// `Exception("Cancelled")` on cancellation. An Exception reported by the
/// sender is rethrown in both cases.
deprecated("Use syncWait instead")
auto sync_wait(Sender)(auto scope ref Sender sender) {
  alias Value = Sender.Value;
  auto result = syncWait(sender);
  static if (!is(Value == void)) {
    return result.match!(
      (Cancelled _) { throw new Exception("Cancelled"); },
      (Exception ex) { throw ex; },
      (ref t) => t
    );
  } else {
    return result.match!(
      (Cancelled _) => false,
      (Exception ex) { throw ex; },
      (typeof(null)) => true
    );
  }
}
84 
/// Marker type stored in a `Result` when the awaited sender was cancelled.
struct Cancelled {
}

/// Shared exception instance thrown by `Result.value` and `Result.assumeOk`
/// when the result holds `Cancelled`.
static immutable Exception cancelledException = new Exception("Cancelled");
87 
/// Outcome of `syncWait`: either success (carrying a `T` unless `T` is void),
/// a `Cancelled` marker, or an `Exception`.
struct Result(T) {
  import mir.algebraic : Algebraic, Nullable;

  /// Wrapper around the success value; has no field when `T` is void.
  static struct Value(T) {
    static if (!is(T == void)) {
      T value;
    }
  }

  static if (is(T == void)) {
    /// Null means success; otherwise holds Cancelled or an Exception.
    Nullable!(Cancelled, Exception) result;
  } else {
    /// Holds the wrapped value, Cancelled, or an Exception.
    Algebraic!(Value!T, Cancelled, Exception) result;

    /// Builds a successful result carrying `v`.
    this(T v) {
      result = Value!T(v);
    }

    /// Returns the success value.
    /// Throws: `cancelledException` when cancelled, or the stored Exception
    /// when the result is an error.
    T value() {
      if (isCancelled) {
        throw cancelledException;
      }
      if (isError) {
        throw error;
      }
      return result.get!(Value!T).value;
    }
  }

  /// Builds a cancelled result.
  this(Cancelled c) {
    result = c;
  }

  /// Builds an error result holding `e`.
  this(Exception e) {
    result = e;
  }

  /// True when the result holds `Cancelled`.
  bool isCancelled() {
    return result._is!Cancelled;
  }

  /// True when the result holds an `Exception`.
  bool isError() {
    return result._is!Exception;
  }

  /// True when the result is a success: null for void `T`, a wrapped value otherwise.
  bool isOk() {
    static if (is(T == void)) {
      return result.isNull;
    } else {
      return result._is!(Value!T);
    }
  }

  /// Extracts the stored Exception via `result.get!(Exception)`.
  Exception error() {
    return result.get!(Exception);
  }

  /// Does nothing on success; throws the stored Exception on error and
  /// `cancelledException` on cancellation.
  void assumeOk() {
    import mir.algebraic : match;
    static if (is(T == void)) {
      result.match!(
        (typeof(null)) {},
        (Exception e) { throw e; },
        (Cancelled c) { throw cancelledException; }
      );
    } else {
      result.match!(
        (Exception e) { throw e; },
        (Cancelled c) { throw cancelledException; },
        (ref t) {}
      );
    }
  }
}
141 
/// Dispatches the outcome of `syncWait` (a `Result!T`) to the given handlers.
///
/// For non-void `T`, the internal value wrapper is first unwrapped so that
/// handlers receive the plain `T`; `Cancelled` and `Exception` pass through.
template match(Handlers...) {
  // Kept as an inner function template because of the dual-context limitation.
  auto match(T)(Result!T r) {
    import mir.algebraic : match;
    static if (!is(T == void)) {
      alias Wrapped = Result!(T).Value!(T);
      return r.result
        .match!((Wrapped v) => v.value, (ref t) => t)
        .match!(Handlers);
    } else {
      return r.result.match!(Handlers);
    }
  }
}
153 
/// Installs `stopSource` as this module's `parentStopSource`.
///
/// Params:
///   stopSource = shared stop source; stored with the `shared` qualifier cast away.
/// Throws: the exception raised by `enforce` when `parentStopSource` is already set.
void setTopLevelStopSource(shared StopSource stopSource) @trusted {
  import std.exception : enforce;

  // Only one top-level stop source may be installed.
  enforce(parentStopSource is null);
  parentStopSource = cast() stopSource;
}
159 
/// Stop source that `syncWait` chains new child stop sources to; null until set.
/// Declared without `shared`/`__gshared`, so it is a thread-local variable.
package(concurrency) StopSource parentStopSource;
161 
/// Starts `sender` and waits until it completes, is cancelled, or fails,
/// using the caller-supplied `stopSource` for cancellation.
///
/// Returns: the `Result` produced by `syncWaitImpl`.
auto syncWait(Sender, StopSource)(auto ref Sender sender, StopSource stopSource) {
  // Strip qualifiers (e.g. shared) from the stop source before forwarding it.
  auto plainSource = (() @trusted => cast() stopSource)();
  return syncWaitImpl(sender, plainSource);
}
166 
/// Starts `sender` and waits until it completes, is cancelled, or fails.
///
/// A fresh child stop source is created for the wait. It is stopped whenever
/// the parent is stopped, where the parent is `parentStopSource` when that is
/// set and `globalStopSource` otherwise.
///
/// Returns: the `Result` produced by `syncWaitImpl`.
auto syncWait(Sender)(auto scope ref Sender sender) {
  import concurrency.signal : globalStopSource;
  auto childStopSource = new shared StopSource();
  StopToken parentStopToken = parentStopSource ? StopToken(parentStopSource) : StopToken(globalStopSource);

  StopCallback cb = parentStopToken.onStop(() shared { childStopSource.stop(); });
  // Detach from the parent on every exit path. syncWaitImpl can throw (it
  // rethrows Errors), and without this guard the callback would stay
  // registered on the parent stop source.
  scope(exit) cb.dispose();

  return syncWaitImpl(sender, (()@trusted=>cast()childStopSource)());
}
178 
/// Connects `sender` to a `SyncWaitReceiver2`, starts the operation, runs the
/// local worker, and converts the recorded outcome into a `Result`.
///
/// While the operation runs, `parentStopSource` is set to `stopSource`, so
/// nested `syncWait` calls on this thread chain to it. The previous value is
/// restored on every exit path.
///
/// Params:
///   sender     = the sender to run; must model `isSender`.
///   stopSource = stop source exposed to the sender through the receiver.
/// Returns: `Cancelled` if the receiver's `setDone` was called, the Exception
///   passed to `setError`, or the value passed to `setValue`.
/// Throws: any `Error` passed to `setError` is rethrown rather than returned.
private Result!(Sender.Value) syncWaitImpl(Sender)(auto scope ref Sender sender, StopSource stopSource) @safe {
  static assert(models!(Sender, isSender));
  import concurrency.signal;

  alias Value = Sender.Value;
  alias Receiver = SyncWaitReceiver2!(Value);

  /// TODO: not fiber safe
  auto old = parentStopSource;
  parentStopSource = stopSource;
  // Restore the previous parent even when connect/start/the worker throws;
  // otherwise later syncWait calls on this thread would chain to a stale
  // stop source.
  scope(exit) parentStopSource = old;

  auto state = Receiver.State(stopSource);
  // Taking the address of a stack local is not @safe; the receiver is only
  // used inside this function, while `state` is alive.
  Receiver receiver = (()@trusted => Receiver(&state))();
  auto op = sender.connect(receiver);
  op.start();

  // NOTE(review): presumably runs the worker on this thread until one of the
  // receiver callbacks calls worker.stop() — confirm against LocalThreadWorker.
  state.worker.start();

  if (state.canceled)
    return Result!Value(Cancelled());

  // Errors are not recoverable results; propagate them to the caller.
  if (auto t = cast(Error)state.throwable)
    throw t;

  if (state.throwable !is null) {
    if (auto e = cast(Exception)state.throwable)
      return Result!Value(e);
    assert(false, "Can only rethrow Exceptions");
  }
  static if (is(Value == void))
    return Result!Value();
  else
    return Result!Value(state.result);
}