1 module concurrency.scheduler;
2 
3 import concurrency.sender : SenderObjectBase, isSender;
4 import core.time : Duration;
5 import concepts;
6 import mir.algebraic : Nullable, nullable;
7 
/// Compile-time probe for the scheduler concept. Instantiating it for `T`
/// only compiles when a default-initialised `T` provides `schedule()` and
/// `scheduleAfter(Duration)`, and both of their return types pass
/// `checkSender`.
void checkScheduler(T)() {
  import core.time : msecs;
  import concurrency.sender : checkSender;
  T probe = T.init;
  checkSender!(typeof(probe.schedule()))();
  checkSender!(typeof(probe.scheduleAfter(10.msecs)))();
}

/// True exactly when `checkScheduler!T` compiles.
enum isScheduler(T) = is(typeof(checkScheduler!T));
18 
/// Polymorphic (type-erased) scheduler interface. Both methods hand back
/// their sender as a `SenderObjectBase!void`, so schedulers of different
/// concrete types can be used through this single runtime type.
interface SchedulerObjectBase {
  /// Type-erased counterpart of a scheduler's `schedule()`.
  SenderObjectBase!void schedule() @safe;
  /// Type-erased counterpart of a scheduler's `scheduleAfter(Duration)`.
  SenderObjectBase!void scheduleAfter(Duration d) @safe;
}
24 
/// Adapter exposing a concrete scheduler `S` through the type-erased
/// `SchedulerObjectBase` interface. Every sender it produces is wrapped
/// with `toSenderObject`.
class SchedulerObject(S) : SchedulerObjectBase {
  import concurrency.sender : toSenderObject;

  /// The concrete scheduler every call is forwarded to.
  S scheduler;

  /// Stores `scheduler` for later forwarding.
  this(S scheduler) {
    this.scheduler = scheduler;
  }

  /// Wraps the sender returned by `scheduler.schedule()`.
  SenderObjectBase!void schedule() @safe {
    return toSenderObject(this.scheduler.schedule());
  }

  /// Wraps the sender returned by `scheduler.scheduleAfter(d)`.
  SenderObjectBase!void scheduleAfter(Duration d) @safe {
    return toSenderObject(this.scheduler.scheduleAfter(d));
  }
}
38 
/// Wraps `scheduler` in a newly allocated `SchedulerObject!S` and returns it
/// through the `SchedulerObjectBase` interface.
SchedulerObjectBase toSchedulerObject(S)(S scheduler) {
  SchedulerObjectBase erased = new SchedulerObject!S(scheduler);
  return erased;
}
42 
/// Empty marker type with no members. It has no `schedule`/`scheduleAfter`,
/// so it does not satisfy `isScheduler`.
/// NOTE(review): presumably used as a "no scheduler" placeholder; confirm at call sites.
struct NullScheduler {}
44 
/// Reason a timer callback is being invoked.
enum TimerTrigger {
  /// The timer expired (`ManualTimeWorker.advance` passes this).
  trigger,
  /// The timer was cancelled (`ManualTimeWorker.cancelTimer` passes this).
  cancel
}

/// Callback stored in a `Timer`; receives the reason for the invocation.
alias TimerDelegate = void delegate(TimerTrigger) shared @safe;
51 
/// Handle for a registered timer: the callback to invoke plus a numeric id.
/// `ManualTimeWorker.addTimer` draws ids from an atomic counter.
struct Timer {
  /// Callback invoked (by ManualTimeWorker) on expiry or cancellation.
  TimerDelegate dg;
  /// Raw id storage; read it through `id()`.
  ulong id_;
  /// Returns the timer's id.
  ulong id() @safe nothrow @nogc { return id_; }
}
57 
/// Builds a `SchedulerAdapter` around a `LocalThreadWorker` bound to the
/// executor returned by `getLocalThreadExecutor`.
auto localThreadScheduler() {
  import concurrency.thread : LocalThreadWorker, getLocalThreadExecutor;
  auto executor = getLocalThreadExecutor;
  return SchedulerAdapter!LocalThreadWorker(LocalThreadWorker(executor));
}

/// Concrete type produced by `localThreadScheduler`.
alias LocalThreadScheduler = typeof(localThreadScheduler());
64 
/// Turns a `Worker` into a scheduler. `schedule()` returns a sender whose
/// operation submits the receiver's completion to `worker.schedule`.
/// `scheduleAfter` returns a `ScheduleAfterSender` built from the worker.
struct SchedulerAdapter(Worker) {
  import concurrency.receiver : setValueOrError;
  import concurrency.executor : VoidDelegate;
  import core.time : Duration;
  Worker worker;
  /// Returns a sender that, when its operation starts, submits a delegate to
  /// `worker.schedule`; that delegate completes the receiver via
  /// `setValueOrError`.
  auto schedule() {
    /// Operation state for `ScheduleSender`; copying is disabled.
    static struct ScheduleOp(Receiver) {
      Worker worker;
      Receiver receiver;
      @disable this(ref return scope typeof(this) rhs);
      @disable this(this);
      /// Submits the completion delegate; if `worker.schedule` throws, the
      /// exception is forwarded to `receiver.setError`.
      void start() @trusted nothrow {
        try {
          // The delegate reaches `receiver` through this op, so the op must
          // stay alive until the worker has run it.
          worker.schedule(cast(VoidDelegate)()=>receiver.setValueOrError());
        } catch (Exception e) {
          receiver.setError(e);
        }
      }
    }
    /// Sender carrying a copy of the worker; `connect` builds a `ScheduleOp`.
    static struct ScheduleSender {
      alias Value = void;
      Worker worker;
      auto connect(Receiver)(return Receiver receiver) @safe scope return {
        // ensure NRVO
        auto op = ScheduleOp!(Receiver)(worker, receiver);
        return op;
      }
    }
    return ScheduleSender(worker);
  }
  /// `shared` overload: casts away `shared` (hence `@trusted`) and forwards.
  auto schedule() shared @trusted {
    return (cast()this).schedule();
  }
  /// Returns a `ScheduleAfterSender` for `dur` using a copy of the worker.
  auto scheduleAfter(Duration dur) @safe {
    return ScheduleAfterSender!(Worker)(worker, dur);
  }
  /// `shared` overload: casts away `shared` (hence `@trusted`) and forwards.
  auto scheduleAfter(Duration dur) shared @trusted {
    return (cast()this).scheduleAfter(dur);
  }
}
105 
/// Operation state for `ScheduleAfterSender`. It registers a timer with
/// `worker.addTimer` and completes `receiver` with `setValueOrError` when the
/// timer triggers, or with `setDone` when it reports a cancellation.
///
/// `start`, `trigger` (timer callback) and `stop` (stop-token callback) may
/// run concurrently. `flags` records which of them has happened so that
/// exactly one path completes the receiver:
///   - `setup`     : `start` has finished registering the timer,
///   - `triggered` : the timer callback has run,
///   - `stop`      : a stop was requested.
/// Whoever completes the receiver also disposes `stopCb`, so the stop
/// callback never outlives the operation.
struct ScheduleAfterOp(Worker, Receiver) {
  import std.traits : ReturnType;
  import concurrency.bitfield : SharedBitField;
  import concurrency.stoptoken : StopCallback, onStop;
  import concurrency.receiver : setValueOrError;

  enum Flags {
    locked = 0x0,
    stop = 0x1,
    triggered = 0x2,
    setup = 0x4,
  }
  alias Timer = ReturnType!(Worker.addTimer);
  Worker worker;
  Duration dur;
  Receiver receiver;
  Timer timer;
  StopCallback stopCb;
  shared SharedBitField!Flags flags;
  @disable this(ref return scope typeof(this) rhs);
  @disable this(this);

  /// Starts the operation.
  ///
  /// Completes immediately with `setDone` if a stop was already requested,
  /// and with `setError` if `addTimer` throws. Otherwise it marks setup as
  /// finished and handles a trigger or stop that raced ahead of it.
  void start() @trusted scope nothrow {
    if (receiver.getStopToken().isStopRequested) {
      receiver.setDone();
      return;
    }

    stopCb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&stop);

    try {
      timer = worker.addTimer(cast(void delegate(TimerTrigger) @safe shared)&trigger, dur);
    } catch (Exception e) {
      // No timer exists, so nothing else will complete the receiver. The stop
      // callback must be unregistered before this operation finishes.
      stopCb.dispose();
      receiver.setError(e);
      return;
    }

    with (flags.add(Flags.setup)) {
      if (has(Flags.triggered)) {
        // The timer fired before setup finished, so `trigger` returned early
        // and left completion to us. The timer is already gone and must not
        // be cancelled; cancelling could report a second (cancel) completion.
        stopCb.dispose();
        receiver.setValueOrError();
      } else if (has(Flags.stop)) {
        // A stop arrived before setup finished and `stop` returned early, so
        // cancel here. The cancellation is expected to come back through
        // `trigger` (ManualTimeWorker.cancelTimer invokes the callback with
        // TimerTrigger.cancel).
        try { worker.cancelTimer(timer); } catch (Exception e) {} // TODO: what to do here?
      }
    }
  }

  /// Timer callback. Before setup it only records the trigger and leaves
  /// completion to `start`. After setup it unregisters the stop callback and
  /// completes with `setDone` (cancel) or `setValueOrError` (trigger).
  private void trigger(TimerTrigger cause) @trusted nothrow {
    with (flags.add(Flags.triggered)) {
      if (!has(Flags.setup))
        return;
      stopCb.dispose();
      final switch (cause) {
      case TimerTrigger.cancel:
        receiver.setDone();
        break;
      case TimerTrigger.trigger:
        receiver.setValueOrError();
        break;
      }
    }
  }

  /// Stop-token callback. Before setup it only records the stop, and `start`
  /// performs the cancellation. After setup it cancels the timer unless the
  /// timer has already triggered.
  private void stop() @trusted nothrow {
    with (flags.add(Flags.stop)) {
      if (!has(Flags.setup)) {
        return;
      }
      if (!has(Flags.triggered)) {
        try { worker.cancelTimer(timer); } catch (Exception e) {} // TODO: what to do here?
      }
    }
  }
}
177 
/// Sender returned by `scheduleAfter`. Connecting it yields a
/// `ScheduleAfterOp` that holds the worker, the delay and the receiver.
struct ScheduleAfterSender(Worker) {
  alias Value = void;
  Worker worker;
  Duration dur;
  auto connect(Receiver)(return Receiver receiver) @safe return scope {
    // The operation has copying disabled: build it in a named local and
    // return that local so NRVO constructs it in the caller's storage.
    auto operation = ScheduleAfterOp!(Worker, Receiver)(this.worker, this.dur, receiver);
    return operation;
  }
}
188 
/// Scheduler backed by a `shared ManualTimeWorker`; all scheduling goes
/// through `ScheduleAfterSender`.
struct ManualTimeScheduler {
  shared ManualTimeWorker worker;
  /// Immediate scheduling, expressed as `scheduleAfter` with a zero delay.
  auto schedule() {
    return scheduleAfter(Duration.zero);
  }
  /// Returns a `ScheduleAfterSender` for `dur` on the manual-time worker.
  auto scheduleAfter(Duration dur) {
    alias Sender = ScheduleAfterSender!(shared ManualTimeWorker);
    return Sender(worker, dur);
  }
}
199 
/// Timer worker whose clock (`time`, in hnsecs) only moves when `advance` is
/// called. `advance` then fires every timer that became due. The constructor
/// and all methods are `shared`; fields are accessed inside `with (lock())`.
class ManualTimeWorker {
  import concurrency.timingwheels : TimingWheels;
  import concurrency.executor : VoidDelegate;
  import core.sync.mutex : Mutex;
  import core.sync.condition : Condition;
  import core.time : msecs, hnsecs;
  import std.array : Appender;
  private {
    // Pending timers; queried with a 1ms tick throughout this class.
    TimingWheels!Timer wheels;
    // Scratch buffer that `advance` fills with timers that just expired.
    Appender!(Timer[]) expiredTimers;
    // Its mutex is taken by `lock`; `wait` blocks on it, `addTimer` notifies.
    Condition condition;
    // Manual clock in hnsecs; only `advance` moves it.
    size_t time = 1;
    // Source of timer ids, incremented atomically in `addTimer`.
    shared ulong nextTimerId;
  }
  /// Acquires `condition.mutex` and returns a `SharedGuard` holding it.
  /// NOTE(review): inside `with (lock())` the fields presumably resolve
  /// through the guard without `shared`; confirm in concurrency.utils.
  auto lock() @trusted shared {
    import concurrency.utils : SharedGuard;
    return SharedGuard!(ManualTimeWorker).acquire(this, cast()condition.mutex);
  }
  /// Creates the condition/mutex pair and initialises the wheels at `time`.
  this() @trusted shared {
    condition = cast(shared)new Condition(new Mutex());
    (cast()wheels).init(time);
  }
  /// Returns a `ManualTimeScheduler` bound to this worker.
  ManualTimeScheduler getScheduler() @safe shared {
    return ManualTimeScheduler(this);
  }
  /// Registers `dg` to run once `dur` of manual time has passed and returns
  /// the new `Timer`, which gets a fresh id. The deadline handed to the wheels
  /// is `dur` plus the lag between `time` and the wheels' own current time.
  /// It is expressed in whole 1ms ticks; integer division truncates.
  /// Wakes every thread blocked in `wait`.
  Timer addTimer(TimerDelegate dg, Duration dur) @trusted shared {
    import core.atomic : atomicOp;
    with(lock()) {
      auto real_now = time;
      auto tw_now = wheels.currStdTime(1.msecs);
      auto delay = (real_now - tw_now).hnsecs;
      auto at = (dur + delay)/1.msecs;
      auto timer = Timer(dg, nextTimerId.atomicOp!("+=")(1));
      wheels.schedule(timer, at);
      condition.notifyAll();
      return timer;
    }
  }
  /// Blocks until the condition is notified (`addTimer` notifies it).
  /// NOTE(review): no predicate is checked, so a spurious wake-up also
  /// returns; callers presumably re-check in a loop.
  void wait() @trusted shared {
    with(lock()) {
      condition.wait();
    }
  }
  /// Removes `timer` from the wheels. After releasing the lock, it invokes
  /// the timer's callback with `TimerTrigger.cancel`.
  /// NOTE(review): the callback runs unconditionally, even if the timer had
  /// already fired.
  void cancelTimer(Timer timer) @trusted shared {
    with(lock()) {
      wheels.cancel(timer);
    }
    timer.dg(TimerTrigger.cancel);
  }
  /// Asks the wheels for the time until the next pending timer, at the
  /// current manual time.
  /// NOTE(review): presumably null when no timer is pending.
  Nullable!Duration timeUntilNextEvent() @trusted shared {
    with(lock()) {
      return wheels.timeUntilNextEvent(1.msecs, time);
    }
  }
  /// Moves the manual clock forward by `dur`. While still holding the lock,
  /// it invokes every timer that became due with `TimerTrigger.trigger`.
  void advance(Duration dur) @trusted shared {
    import std.range : retro;
    import core.time : msecs;
    with(lock()) {
      time += dur.total!"hnsecs";
      int incr = wheels.ticksToCatchUp(1.msecs, time);
      if (incr > 0) {
        // Always empty the scratch buffer, even if a callback (TimerDelegate is
        // not nothrow) throws. Otherwise the next `advance` would fire these
        // stale timers a second time.
        scope(exit) expiredTimers.shrinkTo(0);
        wheels.advance(incr, expiredTimers);
        // NOTE timingwheels keeps the timers in reverse order, so we iterate in reverse
        foreach(t; expiredTimers.data.retro) {
          t.dg(TimerTrigger.trigger);
        }
      }
    }
  }
}
271 
/// Picks a full scheduler. Returns `t` itself when it already satisfies
/// `isScheduler`. Otherwise, if `p` does, returns a `ProxyScheduler` that
/// takes `schedule()` from `t` and delays via `p`. Fails to compile when
/// neither qualifies.
auto withBaseScheduler(T, P)(auto ref T t, auto ref P p) {
  enum frontIsScheduler = isScheduler!T;
  enum backIsScheduler = isScheduler!P;
  static if (frontIsScheduler) {
    return t;
  } else static if (backIsScheduler) {
    return ProxyScheduler!(T, P)(t, p);
  } else {
    static assert(false, "Neither "~T.stringof~" nor "~P.stringof~" are full schedulers. Chain the sender with a .withScheduler and ensure the Scheduler passes the isScheduler check.");
  }
}
280 
/// Scheduler assembled by `withBaseScheduler` when only `back` is a full
/// scheduler. `schedule()` comes from `front`. `scheduleAfter` combines
/// `front.schedule()` with `back.scheduleAfter(run)` using `via`.
private struct ProxyScheduler(T, P) {
  // NOTE: the former `import std.parallelism : TaskPool;` was never used and
  // has been dropped.
  import core.time : Duration;
  T front;
  P back;
  /// Forwards to `front.schedule()`.
  auto schedule() {
    return front.schedule();
  }
  /// Returns `schedule().via(back.scheduleAfter(run))`.
  /// NOTE(review): `via` presumably starts the front sender once the delayed
  /// back sender completes; confirm in concurrency.operations.
  auto scheduleAfter(Duration run) {
    import concurrency.operations : via;
    return schedule().via(back.scheduleAfter(run));
  }
}
294 
/// Sender that resolves its scheduler lazily. On `connect` it asks the
/// receiver for its scheduler (`receiver.getScheduler`) and connects the
/// receiver to that scheduler's `scheduleAfter(duration)` sender.
struct ScheduleAfter {
  static assert (models!(typeof(this), isSender));
  alias Value = void;
  Duration duration;
  auto connect(Receiver)(return Receiver receiver) @trusted scope return {
    // Return through a named local so the operation benefits from NRVO.
    auto operation = receiver.getScheduler.scheduleAfter(this.duration).connect(receiver);
    return operation;
  }
}
305 
/// Sender that resolves its scheduler lazily. On `connect` it asks the
/// receiver for its scheduler (`receiver.getScheduler`) and connects the
/// receiver to that scheduler's `schedule()` sender.
struct Schedule {
  static assert (models!(typeof(this), isSender));
  alias Value = void;
  auto connect(Receiver)(return Receiver receiver) @trusted scope return {
    // Return through a named local so the operation benefits from NRVO.
    auto operation = receiver.getScheduler.schedule().connect(receiver);
    return operation;
  }
}