1 module concurrency.scheduler;
2 
3 import concurrency.sender : SenderObjectBase, isSender;
4 import core.time : Duration;
5 import concepts;
6 import mir.algebraic : Nullable, nullable;
7 
/// Compile-time probe for the scheduler concept.
///
/// A default-initialised `T` must provide `schedule()` and
/// `scheduleAfter(Duration)`, and each must return a type accepted by
/// `checkSender`. This function is only ever type-checked, never called.
void checkScheduler(T)() {
  import core.time : msecs;
  import concurrency.sender : checkSender;

  T candidate = T.init;
  // immediate scheduling must yield a sender
  checkSender!(typeof(candidate.schedule()))();
  // delayed scheduling must yield a sender as well
  checkSender!(typeof(candidate.scheduleAfter(10.msecs)))();
}

/// `true` when `T` passes `checkScheduler`.
enum isScheduler(T) = is(typeof(checkScheduler!T));
18 
/// Polymorphic (type-erased) scheduler interface.
///
/// Both operations hand back a type-erased `SenderObjectBase!void`, so
/// schedulers of different concrete types can sit behind one reference.
interface SchedulerObjectBase {
  /// Type-erased counterpart of a scheduler's `schedule()`.
  @safe SenderObjectBase!void schedule();
  /// Type-erased counterpart of a scheduler's `scheduleAfter(d)`.
  @safe SenderObjectBase!void scheduleAfter(Duration d);
}
24 
/// Class adapter that exposes a concrete scheduler `S` through the
/// `SchedulerObjectBase` interface, boxing each produced sender with
/// `toSenderObject`.
class SchedulerObject(S) : SchedulerObjectBase {
  import concurrency.sender : toSenderObject;

  /// The wrapped concrete scheduler.
  S scheduler;

  this(S scheduler) {
    this.scheduler = scheduler;
  }

  /// Forwards to `scheduler.schedule()` and type-erases the result.
  SenderObjectBase!void schedule() @safe {
    return toSenderObject(scheduler.schedule());
  }

  /// Forwards to `scheduler.scheduleAfter(d)` and type-erases the result.
  SenderObjectBase!void scheduleAfter(Duration d) @safe {
    return toSenderObject(scheduler.scheduleAfter(d));
  }
}
38 
/// Wraps `scheduler` in a newly allocated `SchedulerObject!S` and returns it
/// through the polymorphic `SchedulerObjectBase` interface.
SchedulerObjectBase toSchedulerObject(S)(S scheduler) {
  SchedulerObjectBase boxed = new SchedulerObject!S(scheduler);
  return boxed;
}
42 
/// Reason passed to a timer callback.
enum TimerTrigger {
  trigger, /// the timer expired
  cancel   /// the timer was cancelled
}

/// Timer callback type: a `shared`, `@safe` delegate that receives the
/// reason it is being invoked.
alias TimerDelegate = void delegate(TimerTrigger) shared @safe;
49 
/// Handle for a scheduled timer: the callback to invoke and a numeric id.
/// Field order matters: the struct is built positionally as `Timer(dg, id)`.
struct Timer {
  TimerDelegate dg;
  ulong id_;

  /// Returns the timer's id.
  ulong id() @safe nothrow @nogc {
    return this.id_;
  }
}
55 
/// Returns a `SchedulerAdapter` over a `LocalThreadWorker` constructed from
/// `getLocalThreadExecutor`.
auto localThreadScheduler() {
  import concurrency.thread : getLocalThreadExecutor, LocalThreadWorker;
  alias Adapter = SchedulerAdapter!LocalThreadWorker;
  return Adapter(LocalThreadWorker(getLocalThreadExecutor));
}

/// Concrete type of the scheduler returned by `localThreadScheduler`.
alias LocalThreadScheduler = typeof(localThreadScheduler());
62 
/// Turns a `Worker` into a scheduler.
///
/// `schedule()` yields a sender whose operation submits
/// `receiver.setValueOrError` to `worker.schedule`. `scheduleAfter` delegates
/// to `ScheduleAfterSender`. The `shared` overloads strip `shared` and forward.
struct SchedulerAdapter(Worker) {
  import concurrency.receiver : setValueOrError;
  import concurrency.executor : VoidDelegate;
  import core.time : Duration;

  Worker worker;

  /// Sender that completes the receiver from within `worker.schedule`.
  auto schedule() {
    static struct ScheduleOp(Receiver) {
      Worker worker;
      Receiver receiver;
      // Non-copyable: the delegate given to `worker.schedule` refers to this
      // op's `receiver`, so the op must not be duplicated once connected.
      @disable this(ref return scope typeof(this) other);
      @disable this(this);

      void start() @trusted nothrow {
        // A failure to enqueue is reported to the receiver as an error.
        try
          worker.schedule(cast(VoidDelegate) () => receiver.setValueOrError());
        catch (Exception e)
          receiver.setError(e);
      }
    }

    static struct ScheduleSender {
      alias Value = void;
      Worker worker;

      auto connect(Receiver)(return Receiver receiver) @safe scope return {
        // ensure NRVO
        auto state = ScheduleOp!(Receiver)(worker, receiver);
        return state;
      }
    }

    return ScheduleSender(worker);
  }

  /// Sender that completes after `dur`, built on `ScheduleAfterSender`.
  auto scheduleAfter(Duration dur) @safe {
    return ScheduleAfterSender!(Worker)(worker, dur);
  }

  /// `shared` overload: casts away `shared` and forwards.
  auto schedule() shared @trusted {
    return (cast()this).schedule();
  }

  /// ditto
  auto scheduleAfter(Duration dur) shared @trusted {
    return (cast()this).scheduleAfter(dur);
  }
}
103 
/// Operation state created by `ScheduleAfterSender.connect`.
///
/// `start` registers a timer of length `dur` with `worker`. The receiver gets:
/// - `setValueOrError` when the timer fires with `TimerTrigger.trigger`;
/// - `setDone` when it is cancelled;
/// - `setError` when `addTimer` throws.
///
/// Three parties can race, and the `flags` bitfield arbitrates between them:
/// - `start`, which sets `setup` once the timer is registered;
/// - `trigger`, the timer callback, which sets `triggered`;
/// - `stop`, the stop-token callback, which sets `stop`.
///
/// If `trigger` or `stop` runs before `setup` is set, it backs off and `start`
/// finishes the job after setup.
struct ScheduleAfterOp(Worker, Receiver) {
  import std.traits : ReturnType;
  import concurrency.bitfield : SharedBitField;
  import concurrency.stoptoken : StopCallback, onStop;
  import concurrency.receiver : setValueOrError;

  enum Flags {
    locked = 0x0,
    stop = 0x1,
    triggered = 0x2,
    setup = 0x4,
  }
  alias Timer = ReturnType!(Worker.addTimer);
  Worker worker;
  Duration dur;
  Receiver receiver;
  Timer timer;
  StopCallback stopCb;
  shared SharedBitField!Flags flags;
  @disable this(ref return scope typeof(this) rhs);
  @disable this(this);

  void start() @trusted scope nothrow {
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }

    stopCb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&stop);

    try {
      timer = worker.addTimer(cast(void delegate(TimerTrigger) @safe shared)&trigger, dur);
    } catch (Exception e) {
      // No timer was registered, so nothing else will complete the receiver.
      // Unregister the stop callback first so it cannot outlive this operation.
      stopCb.dispose();
      receiver.setError(e);
      return;
    }

    with (flags.add(Flags.setup)) {
      if (has(Flags.triggered)) {
        // The timer fired before setup finished; `trigger` backed off and left
        // completion to us. Never cancel an already-fired timer: a worker whose
        // `cancelTimer` invokes the callback with `TimerTrigger.cancel` (as
        // ManualTimeWorker.cancelTimer does) would otherwise complete the
        // receiver a second time via `setDone`.
        stopCb.dispose();
        receiver.setValueOrError();
      } else if (has(Flags.stop)) {
        // Stop was requested before setup finished and `stop` backed off, so
        // cancel here. Completion is left to the worker's cancel callback.
        try { worker.cancelTimer(timer); } catch (Exception e) {} // TODO: what to do here?
      }
    }
  }

  /// Timer callback. Completes the receiver according to `cause`, unless
  /// `start` has not finished setup yet, in which case `start` completes it.
  private void trigger(TimerTrigger cause) @trusted nothrow {
    with (flags.add(Flags.triggered)) {
      if (!has(Flags.setup))
        return;
      stopCb.dispose();
      final switch (cause) {
      case TimerTrigger.cancel:
        receiver.setDone();
        break;
      case TimerTrigger.trigger:
        receiver.setValueOrError();
        break;
      }
    }
  }

  /// Stop-token callback. Once setup is done and the timer has not fired, it
  /// cancels the timer. Before setup, `start` performs the cancellation.
  /// NOTE(review): if the timer fires between the flag check and
  /// `cancelTimer`, a worker that still invokes the callback with `cancel`
  /// could complete the receiver twice; confirm the worker contract.
  private void stop() @trusted nothrow {
    with (flags.add(Flags.stop)) {
      if (!has(Flags.setup)) {
        return;
      }
      if (!has(Flags.triggered)) {
        try { worker.cancelTimer(timer); } catch (Exception e) {} // TODO: what to do here?
      }
    }
  }
}
175 
/// Sender whose operation (`ScheduleAfterOp`) registers a timer of length
/// `dur` with `worker`. Fields are built positionally as `(worker, dur)`.
struct ScheduleAfterSender(Worker) {
  alias Value = void;
  Worker worker;
  Duration dur;

  /// Builds the operation state. It is returned by value from a named local
  /// so that NRVO applies (the op type is non-copyable).
  auto connect(Receiver)(return Receiver receiver) @safe return scope {
    alias Op = ScheduleAfterOp!(Worker, Receiver);
    // ensure NRVO
    auto state = Op(worker, dur, receiver);
    return state;
  }
}
186 
/// Scheduler driven by a `ManualTimeWorker`, so time only moves when the
/// worker is advanced.
struct ManualTimeScheduler {
  shared ManualTimeWorker worker;

  /// Immediate scheduling is a delayed schedule with a zero delay.
  auto schedule() {
    return scheduleAfter(Duration.zero);
  }

  /// Sender that fires once the worker's clock has advanced by `dur`.
  auto scheduleAfter(Duration dur) {
    alias Sender = ScheduleAfterSender!(shared ManualTimeWorker);
    return Sender(worker, dur);
  }
}
197 
/// Timer worker whose clock moves only when `advance` is called.
///
/// The current time is kept in `time` (in hnsecs), and timers are stored in
/// timing wheels queried with a 1 ms tick. Mutable state is accessed inside
/// `with(lock())`, which holds `mutex`.
class ManualTimeWorker {
  import concurrency.timingwheels : TimingWheels;
  import concurrency.executor : VoidDelegate;
  import core.sync.mutex : Mutex;
  import core.time : msecs, hnsecs;
  import std.array : Appender;
  private {
    TimingWheels!Timer wheels;
    Appender!(Timer[]) expiredTimers;
    Mutex mutex;
    size_t time = 1;
    shared ulong nextTimerId;
  }

  /// Acquires `mutex` through a `SharedGuard` and returns the guard.
  auto lock() @trusted shared {
    import concurrency.utils : SharedGuard;
    return SharedGuard!(ManualTimeWorker).acquire(this, cast()mutex);
  }

  this() @trusted shared {
    mutex = cast(shared)new Mutex();
    (cast()wheels).init(time);
  }

  ManualTimeScheduler getScheduler() @safe shared {
    return ManualTimeScheduler(this);
  }

  /// Schedules `dg` to be called with `TimerTrigger.trigger` once the clock
  /// has advanced by `dur`. The delay is converted to 1 ms wheel ticks,
  /// including the offset between `time` and the wheels' current time.
  /// Returns the handle to pass to `cancelTimer`.
  Timer addTimer(TimerDelegate dg, Duration dur) @trusted shared {
    import core.atomic : atomicOp;
    with(lock()) {
      auto real_now = time;
      auto tw_now = wheels.currStdTime(1.msecs);
      auto delay = (real_now - tw_now).hnsecs;
      auto at = (dur + delay)/1.msecs;
      auto timer = Timer(dg, nextTimerId.atomicOp!("+=")(1));
      wheels.schedule(timer, at);
      return timer;
    }
  }

  /// Removes `timer` from the wheels under the lock, then invokes its
  /// delegate with `TimerTrigger.cancel` after the lock is released.
  /// NOTE(review): the delegate is invoked even if the timer already expired
  /// and was fired by `advance`; callers must tolerate that.
  void cancelTimer(Timer timer) @trusted shared {
    with(lock()) {
      wheels.cancel(timer);
    }
    timer.dg(TimerTrigger.cancel);
  }

  /// Asks the wheels how long until the next timer, relative to `time`
  /// (presumably null when no timer is pending — confirm in TimingWheels).
  Nullable!Duration timeUntilNextEvent() @trusted shared {
    with(lock()) {
      return wheels.timeUntilNextEvent(1.msecs, time);
    }
  }

  /// Moves the clock forward by `dur` and fires every timer that expired,
  /// calling each delegate with `TimerTrigger.trigger` while the lock is held.
  void advance(Duration dur) @trusted shared {
    import std.range : retro;
    import core.time : msecs;
    with(lock()) {
      time += dur.total!"hnsecs";
      int incr = wheels.ticksToCatchUp(1.msecs, time);
      if (incr > 0) {
        // Always empty the buffer, even when a delegate throws (TimerDelegate
        // is not nothrow). Leaving stale entries behind would make the next
        // `advance` fire timers that have already fired. On an exception the
        // rest of this batch is dropped rather than replayed.
        scope(exit) expiredTimers.shrinkTo(0);
        wheels.advance(incr, expiredTimers);
        // NOTE timingwheels keeps the timers in reverse order, so we iterate in reverse
        foreach(t; expiredTimers.data.retro) {
          t.dg(TimerTrigger.trigger);
        }
      }
    }
  }
}
262 
/// Selects a full scheduler.
///
/// Returns `t` unchanged when it passes `isScheduler`. Otherwise returns a
/// `ProxyScheduler` pairing `t` with `p`, which must itself be a full
/// scheduler. If neither qualifies, compilation fails with an explanatory
/// message.
auto withBaseScheduler(T, P)(auto ref T t, auto ref P p) {
  static if (isScheduler!T) {
    return t;
  } else {
    static assert(isScheduler!P, "Neither "~T.stringof~" nor "~P.stringof~" are full schedulers. Chain the sender with a .withScheduler and ensure the Scheduler passes the isScheduler check.");
    return ProxyScheduler!(T, P)(t, p);
  }
}
271 
/// Scheduler assembled from two parts:
/// - `front` supplies `schedule()`;
/// - `back` supplies the delay for `scheduleAfter`.
///
/// `scheduleAfter` composes `front`'s schedule sender with `back`'s delay
/// sender via `via`.
private struct ProxyScheduler(T, P) {
  // The former `std.parallelism : TaskPool` import was never used; dropping it
  // keeps std.parallelism out of this module's dependency graph.
  import core.time : Duration;

  T front;
  P back;

  /// Forwards to `front.schedule()`.
  auto schedule() {
    return front.schedule();
  }

  /// Combines `schedule()` with `back.scheduleAfter(run)` using `via`.
  auto scheduleAfter(Duration run) {
    import concurrency.operations : via;
    return schedule().via(back.scheduleAfter(run));
  }
}
285 
/// Sender that, on connect, asks the receiver for its scheduler and connects
/// the receiver to that scheduler's `scheduleAfter(duration)` sender.
struct ScheduleAfter {
  alias Value = void;

  /// Delay forwarded to the receiver's scheduler.
  Duration duration;

  auto connect(Receiver)(return Receiver receiver) @trusted scope return {
    // ensure NRVO
    auto state = receiver.getScheduler.scheduleAfter(duration).connect(receiver);
    return state;
  }

  static assert (models!(typeof(this), isSender));
}
296 
/// Sender that, on connect, asks the receiver for its scheduler and connects
/// the receiver to that scheduler's `schedule()` sender.
struct Schedule {
  alias Value = void;

  auto connect(Receiver)(return Receiver receiver) @trusted scope return {
    // ensure NRVO
    auto state = receiver.getScheduler.schedule().connect(receiver);
    return state;
  }

  static assert (models!(typeof(this), isSender));
}