1 module concurrency;
2 
3 import concurrency.stoptoken;
4 import concurrency.sender;
5 import concurrency.thread;
6 import concepts;
7 import mir.algebraic: reflectErr, Variant, Algebraic, assumeOk;
8 
/// Reports whether the calling thread is the process's main thread,
/// as determined by druntime's `Thread` object for the current thread.
bool isMainThread() @trusted {
  import core.thread : Thread;
  auto current = Thread.getThis();
  return current.isMainThread();
}
13 
/// Receiver that `syncWaitImpl` connects to the sender.
///
/// Every completion signal records its outcome in the pointed-to `State` and
/// then stops that state's `worker`. `syncWaitImpl` runs `worker.start()` and
/// inspects the state once it returns.
package struct SyncWaitReceiver2(Value) {
  /// Completion record, allocated by the caller of `syncWaitImpl`.
  static struct State {
    LocalThreadWorker worker;
    bool canceled;
    static if (!is(Value == void)) {
      Value result;          // payload handed to setValue
    }
    Throwable throwable;     // payload handed to setError
    StopSource stopSource;

    /// Remembers `stopSource` and builds a worker over this thread's local executor.
    this(StopSource stopSource) {
      this.stopSource = stopSource;
      this.worker = LocalThreadWorker(getLocalThreadExecutor());
    }
  }

  State* state;

  // Successful completion: store the value (if any), then release the worker.
  static if (is(Value == void)) {
    void setValue() nothrow @safe {
      state.worker.stop();
    }
  } else {
    void setValue(Value value) nothrow @safe {
      state.result = value;
      state.worker.stop();
    }
  }

  /// Failed completion: keep the throwable for the caller to examine.
  void setError(Throwable e) nothrow @safe {
    state.throwable = e;
    state.worker.stop();
  }

  /// Cancelled completion: raise the flag that `syncWaitImpl` maps to `Cancelled`.
  void setDone() nothrow @safe {
    state.canceled = true;
    state.worker.stop();
  }

  /// Stop token bound to the stop source stored in the state.
  auto getStopToken() nothrow @safe @nogc {
    return StopToken(state.stopSource);
  }

  /// Scheduler adapter wrapping the state's `LocalThreadWorker`.
  auto getScheduler() nothrow @safe {
    import concurrency.scheduler : SchedulerAdapter;
    alias WorkerScheduler = SchedulerAdapter!(LocalThreadWorker*);
    return WorkerScheduler(&state.worker);
  }
}
55 
/// Deprecated: runs `sender` via `syncWait` using the given `stopSource` and
/// maps the outcome to the legacy `sync_wait` return convention (see
/// `legacySyncWaitResult`).
deprecated("Use syncWait instead")
auto sync_wait(Sender, StopSource)(auto ref Sender sender, StopSource stopSource) {
  auto result = syncWait(sender, (()@trusted=>cast()stopSource)());
  return legacySyncWaitResult(result);
}

/// Deprecated: runs `sender` via `syncWait` and maps the outcome to the legacy
/// `sync_wait` return convention (see `legacySyncWaitResult`).
deprecated("Use syncWait instead")
auto sync_wait(Sender)(auto scope ref Sender sender) {
  auto result = syncWait(sender);
  return legacySyncWaitResult(result);
}

/// Shared conversion used by both `sync_wait` overloads.
///
/// For a `void` sender: returns `true` on completion and `false` on
/// cancellation. For a value sender: returns the value, and throws
/// `new Exception("Cancelled")` on cancellation. In both cases a stored
/// `Exception` is rethrown.
private auto legacySyncWaitResult(Value)(Result!Value result) {
  static if (is(Value == void)) {
    return result.match!((Cancelled c) => false,
                         (Exception e) { throw e; },
                         () => true); // void
  } else {
    return result.match!((Cancelled c) { throw new Exception("Cancelled"); },
                         (Exception e) { throw e; },
                         "a");
  }
}
85 
/// Marker alternative that `syncWaitImpl` stores in a `Result` when the sender
/// signalled done. `@reflectErr` marks it as an error type for mir.algebraic.
@reflectErr
enum Cancelled {
  cancelled,
}
87 
/// Outcome of `syncWait`: holds exactly one of `Cancelled`, an `Exception`,
/// or a value of type `T` (nothing when `T` is `void`).
struct Result(T) {
  alias V = Variant!(Cancelled, Exception, T);
  V result;

  /// Stores `p` (a `Cancelled`, an `Exception`, or a `T`) in `result`.
  this(P)(P p) {
    result = p;
  }

  /// True when the stored alternative is `Cancelled`.
  bool isCancelled() {
    return result._is!Cancelled;
  }

  /// True when the stored alternative is an `Exception`.
  bool isError() {
    return result._is!Exception;
  }

  /// Forwards to mir.algebraic's `isOk` on the underlying variant.
  bool isOk() {
    return result.isOk;
  }

  /// Returns the stored value (nothing when `T` is `void`).
  ///
  /// Throws: a new `Exception("Cancelled")` when cancelled; rethrows the stored
  /// `Exception` when the sender failed.
  auto value() {
    static if (!is(T == void))
      alias valueHandler = (T v) => v;
    else
      alias valueHandler = (){};

    import mir.algebraic : match;
    // The error handlers are declared as returning `T` so all branches agree
    // on a common return type.
    return result.match!(valueHandler,
                         function T (Cancelled c) {
                           throw new Exception("Cancelled");
                         },
                         function T (Exception e) {
                           throw e;
                         });
  }

  /// Returns the stored alternative of type `U` via mir.algebraic's `get`.
  /// `U` is kept distinct from the struct parameter `T` so it does not shadow it.
  auto get(U)() {
    return result.get!U;
  }

  /// Alias for `value()`: returns the value or throws as `value()` does.
  auto assumeOk() {
    return value();
  }
}
125 
/// Pattern-matches the outcome of `syncWait` with `Handlers`. The variant is
/// first passed through mir.algebraic's `optionalMatch` (identity visitor), and
/// the result is then matched with mir.algebraic's `match!(Handlers)`.
template match(Handlers...) {
  // Nested function template, kept separate because of the dual-context limitation.
  auto match(T)(Result!T r) {
    import mir.algebraic : match, optionalMatch;
    return r.result
      .optionalMatch!(alt => alt)
      .match!(Handlers);
  }
}
134 
/// Installs `stopSource` as `parentStopSource`. `syncWait` chains its child
/// stop source to `parentStopSource` when it is set, instead of to
/// `globalStopSource`.
///
/// `parentStopSource` is a module-level variable declared without `shared`,
/// which makes it thread-local in D. The setting therefore applies only to
/// the calling thread.
///
/// Throws: `Exception` if a parent stop source is already set on this thread.
void setTopLevelStopSource(shared StopSource stopSource) @trusted {
  import std.exception : enforce;
  enforce(parentStopSource is null,
          "setTopLevelStopSource: a top-level StopSource is already set on this thread");
  // The stored variable is unshared, so strip `shared` before assigning.
  parentStopSource = cast()stopSource;
}
140 
/// Stop source that a new `syncWait` links its child stop source to.
/// Declared without `shared`, so every thread has its own copy (D's default
/// thread-local storage).
package(concurrency) StopSource parentStopSource;
142 
/// Starts `sender` and blocks until it completes, is cancelled, or fails,
/// using `stopSource` for cancellation. Returns the outcome as a `Result`.
auto syncWait(Sender, StopSource)(auto ref Sender sender, StopSource stopSource) {
  // Remove any type qualifiers (e.g. `shared`) before handing the stop source on.
  auto unqualified = (() @trusted => cast() stopSource)();
  return syncWaitImpl(sender, unqualified);
}
147 
/// Starts `sender` and blocks until it completes, is cancelled, or fails.
///
/// A fresh child stop source is created for the run and stopped whenever the
/// parent is stopped. The parent is `parentStopSource` when set, i.e. inside
/// an enclosing `syncWait` or after `setTopLevelStopSource`; otherwise it is
/// `globalStopSource`.
///
/// Returns: the `Result` produced by `syncWaitImpl`.
auto syncWait(Sender)(auto scope ref Sender sender) {
  import concurrency.signal : globalStopSource;
  auto childStopSource = new shared StopSource();
  StopToken parentStopToken = parentStopSource ? StopToken(parentStopSource) : StopToken(globalStopSource);

  StopCallback cb = parentStopToken.onStop(() shared { childStopSource.stop(); });
  // Detach from the parent on every exit path. syncWaitImpl can throw (it
  // rethrows non-Exception throwables), and the callback must not stay
  // registered on the parent afterwards.
  scope(exit) cb.dispose();

  return syncWaitImpl(sender, (()@trusted=>cast()childStopSource)());
}
159 
/// Connects `sender` to a `SyncWaitReceiver2` and starts it. Then runs the
/// state's local worker via `worker.start()` and converts the recorded
/// completion into a `Result`.
///
/// While this runs, `stopSource` is installed as `parentStopSource`, so nested
/// `syncWait` calls chain their cancellation to it. The previous value is
/// restored on every exit path, including when an exception escapes.
///
/// Returns: `Cancelled` if the sender signalled done, the `Exception` passed
/// to `setError`, or the sender's value.
/// Throws: any non-`Exception` `Throwable` passed to the receiver's `setError`.
private Result!(Sender.Value) syncWaitImpl(Sender)(auto scope ref Sender sender, StopSource stopSource) @safe {
  static assert(models!(Sender, isSender));
  import concurrency.signal;

  alias Value = Sender.Value;
  alias Receiver = SyncWaitReceiver2!(Value);

  // TODO: not fiber safe (parentStopSource is thread-local, not per fiber).
  auto old = parentStopSource;
  parentStopSource = stopSource;
  // Restore even if connect/start/worker.start throws.
  scope(exit) parentStopSource = old;

  auto state = Receiver.State(stopSource);
  // `state` lives on this stack frame and outlives the operation, because we
  // do not return until the worker has been stopped.
  scope receiver = (()@trusted => Receiver(&state))();
  auto op = sender.connect(receiver);
  op.start();

  state.worker.start();

  if (state.canceled)
    return Result!Value(Cancelled());

  if (state.throwable !is null) {
    // Exceptions are returned in the Result; other Throwables (e.g. Errors) propagate.
    if (auto e = cast(Exception)state.throwable)
      return Result!Value(e);
    throw state.throwable;
  }
  static if (is(Value == void))
    return Result!Value();
  else
    return Result!Value(state.result);
}