1 module concurrency.syncwait;
2 
3 import concurrency.stoptoken;
4 import concurrency.sender;
5 import concurrency.thread;
6 import concepts;
7 import mir.algebraic: reflectErr, Variant, Algebraic, assumeOk;
8 
/// Returns: `true` when the calling thread is the process' main thread,
/// as reported by druntime's `Thread.isMainThread`.
bool isMainThread() @trusted {
  import core.thread : Thread;
  auto current = Thread.getThis();
  return current.isMainThread();
}
13 
/// Receiver used by `syncWaitImpl`: it records how the sender finished in a
/// caller-owned `State` and then stops the local worker so the blocked
/// `worker.start()` call can return.
package struct SyncWaitReceiver2(Value) {
  /// Outcome storage shared between the receiver and the waiting function.
  static struct State {
    LocalThreadWorker worker;
    /// Set by `setDone`.
    bool canceled;
    static if (!is(Value == void)) {
      /// Set by `setValue`; only present for non-void senders.
      Value result;
    }
    /// Set by `setError`.
    Throwable throwable;
    StopSource stopSource;

    this(StopSource stopSource) {
      this.stopSource = stopSource;
      worker = LocalThreadWorker(getLocalThreadExecutor());
    }
  }

  /// Points at the state owned by whoever created this receiver.
  State* state;

  /// Completion signal: the sender was cancelled.
  void setDone() nothrow @safe {
    state.canceled = true;
    state.worker.stop();
  }

  /// Completion signal: the sender failed with `e`.
  void setError(Throwable e) nothrow @safe {
    state.throwable = e;
    state.worker.stop();
  }

  static if (is(Value == void)) {
    /// Completion signal: the sender finished without producing a value.
    void setValue() nothrow @safe {
      state.worker.stop();
    }
  } else {
    /// Completion signal: the sender finished and produced `value`.
    void setValue(Value value) nothrow @safe {
      state.result = value;
      state.worker.stop();
    }
  }

  /// Returns: a stop token bound to the state's stop source.
  auto getStopToken() nothrow @safe @nogc {
    return StopToken(state.stopSource);
  }

  /// Returns: a scheduler adapter wrapping a pointer to the state's worker.
  auto getScheduler() nothrow @safe {
    import concurrency.scheduler : SchedulerAdapter;
    alias WorkerScheduler = SchedulerAdapter!(LocalThreadWorker*);
    return WorkerScheduler(&state.worker);
  }
}
55 
/// Marker value stored in a `Result` when the awaited sender was cancelled
/// (`syncWaitImpl` returns `Result!Value(Cancelled())` in that case).
/// NOTE(review): `@reflectErr` presumably makes mir.algebraic treat this as an
/// error alternative, so `isOk` is false for it - confirm against mir docs.
@reflectErr enum Cancelled { cancelled }
57 
/// Outcome of a synchronous wait: holds exactly one of `Cancelled`, an
/// `Exception`, or a value of type `T` (`T` may be `void`).
struct Result(T) {
  alias V = Variant!(Cancelled, Exception, T);
  V result;

  /// Stores `payload` as the held alternative.
  this(Payload)(Payload payload) {
    result = payload;
  }

  /// Returns: `true` when the held alternative is `Cancelled`.
  bool isCancelled() {
    return result._is!Cancelled;
  }

  /// Returns: `true` when the held alternative is an `Exception`.
  bool isError() {
    return result._is!Exception;
  }

  /// Returns: whatever the underlying variant reports through `isOk`.
  bool isOk() {
    return result.isOk;
  }

  /// Returns: the held value (nothing when `T` is `void`).
  /// Throws: the stored `Exception` when one is held, or a new
  /// `Exception("Cancelled")` when the result is `Cancelled`.
  auto value() {
    import mir.algebraic : match;

    static if (is(T == void)) {
      alias valueHandler = (){};
    } else {
      alias valueHandler = (T v) => v;
    }

    return result.match!(
      valueHandler,
      function T (Cancelled reason) {
        throw new Exception("Cancelled");
      },
      function T (Exception err) {
        throw err;
      }
    );
  }

  /// Returns: the held alternative as `Wanted`, forwarding to the variant's `get`.
  auto get(Wanted)() {
    return result.get!Wanted;
  }

  /// Same as `value()`.
  auto assumeOk() {
    return value();
  }
}
95 
/// Matches `Handlers` over the alternatives of a `Result` returned by syncWait.
template match(Handlers...) {
  // The inner function template has to stay separate from the outer
  // template because of the dual-context limitation.
  auto match(T)(Result!T r) {
    import mir.algebraic : match, optionalMatch;
    return r.result
      .optionalMatch!(held => held)
      .match!(Handlers);
  }
}
104 
/// Installs `stopSource` as this thread's `parentStopSource`.
/// Throws: an enforcement `Exception` when a parent stop source is already set.
void setTopLevelStopSource(shared StopSource stopSource) @trusted {
  import std.exception : enforce;
  const bool unset = parentStopSource is null;
  enforce(unset);
  // strip `shared` so it can be stored in the unshared module variable
  parentStopSource = cast() stopSource;
}
110 
/// Stop source that a `syncWait` call without an explicit stop source chains
/// its child stop source to; `syncWaitImpl` swaps it for the duration of a wait
/// and `setTopLevelStopSource` may set it once up front.
/// Module-level variables without `shared`/`__gshared` are thread-local in D,
/// so each thread has its own value.
package(concurrency) static StopSource parentStopSource;
112 
/// Starts the Sender and blocks until it completes, is cancelled, or fails,
/// using the caller-supplied `stopSource` for cancellation.
auto syncWait(Sender, StopSource)(auto ref Sender sender, StopSource stopSource) {
  // drop type qualifiers (e.g. `shared`) before handing it to the implementation
  auto unqualified = () @trusted { return cast() stopSource; }();
  return syncWaitImpl(sender, unqualified);
}
117 
/// Starts the Sender and blocks until it completes, is cancelled, or fails.
///
/// A fresh child stop source is created for the wait and is stopped whenever
/// the parent is stopped. The parent is `parentStopSource` when one is set,
/// otherwise the `globalStopSource` from `concurrency.signal`.
///
/// Returns: the `Result` produced by `syncWaitImpl`.
/// Throws: whatever `syncWaitImpl` throws; the parent callback is detached
/// in that case as well.
auto syncWait(Sender)(auto scope ref Sender sender) {
  import concurrency.signal : globalStopSource;
  auto childStopSource = new shared StopSource();
  StopToken parentStopToken = parentStopSource ? StopToken(parentStopSource) : StopToken(globalStopSource);

  StopCallback cb = parentStopToken.onStop(() shared { childStopSource.stop(); });
  // Detach from the parent on every exit path. Without the scope guard a
  // throwing syncWaitImpl would leave the callback registered on the parent
  // stop source, pointing at a child whose wait has already ended.
  scope(exit) cb.dispose();
  return syncWaitImpl(sender, (()@trusted=>cast()childStopSource)());
}
129 
/// Connects `sender` to a `SyncWaitReceiver2`, starts the operation and runs
/// the local thread worker until the receiver is signalled.
///
/// While the wait is in progress `parentStopSource` is set to `stopSource`,
/// so nested `syncWait` calls chain to it; the previous value is restored on
/// every exit path, including when an exception propagates.
///
/// Params:
///   sender = the sender to run; must satisfy `isSender`
///   stopSource = stop source exposed to the sender through the receiver
/// Returns: a `Result` holding `Cancelled`, the caught `Exception`, or the value.
/// Throws: the stored `Throwable` itself when it is not an `Exception`.
private Result!(Sender.Value) syncWaitImpl(Sender)(auto scope ref Sender sender, StopSource stopSource) @safe {
  static assert(models!(Sender, isSender));

  alias Value = Sender.Value;
  alias Receiver = SyncWaitReceiver2!(Value);

  /// TODO: not fiber safe
  auto old = parentStopSource;
  parentStopSource = stopSource;
  // Restore with a scope guard: connect/start/worker.start may throw, and a
  // plain assignment afterwards would then be skipped, leaving this thread's
  // parentStopSource pointing at the stop source of a wait that has ended.
  scope(exit) parentStopSource = old;

  auto state = Receiver.State(stopSource);
  // taking the address of the stack-allocated state is not @safe by itself
  scope receiver = (()@trusted => Receiver(&state))();
  auto op = sender.connect(receiver);
  op.start();

  // returns once the receiver has called worker.stop()
  state.worker.start();

  if (state.canceled)
    return Result!Value(Cancelled());

  if (state.throwable !is null) {
    if (auto e = cast(Exception)state.throwable)
      return Result!Value(e);
    // anything that is not an Exception is rethrown rather than wrapped
    throw state.throwable;
  }
  static if (is(Value == void))
    return Result!Value();
  else
    return Result!Value(state.result);
}