1 module concurrency.scheduler;
2 
3 import concurrency.sender : SenderObjectBase;
4 import core.time : Duration;
5 import concepts;
6 import mir.algebraic : Nullable, nullable;
7 
/// Compile-time probe for the scheduler concept.
///
/// Instantiating this template only succeeds when `T` provides `schedule()`
/// and `scheduleAfter(Duration)` and the types of both results are accepted
/// by `checkSender`.
void checkScheduler(T)() {
  import core.time : msecs;
  import concurrency.sender : checkSender;

  T candidate = T.init;

  // The result of immediate scheduling has to be a sender.
  alias ImmediateSender = typeof(candidate.schedule());
  checkSender!ImmediateSender();

  // The result of delayed scheduling has to be a sender too; 10ms is just a
  // sample argument used to form the expression.
  alias DelayedSender = typeof(candidate.scheduleAfter(10.msecs));
  checkSender!DelayedSender();
}
/// `true` when `checkScheduler!T` can be instantiated for `T`.
enum bool isScheduler(T) = is(typeof(checkScheduler!T));
18 
/// Polymorphic scheduler: a runtime interface whose operations return
/// `SenderObjectBase!void` instead of a concrete sender type.
interface SchedulerObjectBase {
@safe:
  /// Returns a sender for scheduling work without a delay argument.
  SenderObjectBase!void schedule();
  /// Returns a sender for scheduling work with the delay `d`.
  SenderObjectBase!void scheduleAfter(Duration d);
}
24 
/// Adapts a concrete scheduler of type `S` to the `SchedulerObjectBase`
/// interface by converting each sender it produces with `toSenderObject`.
class SchedulerObject(S) : SchedulerObjectBase {
  import concurrency.sender : toSenderObject;

  /// The wrapped scheduler.
  S scheduler;

  /// Stores `scheduler` as the wrapped instance.
  this(S scheduler) {
    this.scheduler = scheduler;
  }

  /// Forwards to the wrapped scheduler's `schedule()`.
  SenderObjectBase!void schedule() @safe {
    return toSenderObject(scheduler.schedule());
  }

  /// Forwards to the wrapped scheduler's `scheduleAfter(d)`.
  SenderObjectBase!void scheduleAfter(Duration d) @safe {
    return toSenderObject(scheduler.scheduleAfter(d));
  }
}
38 
/// Wraps `scheduler` in a newly allocated `SchedulerObject!S` and returns it
/// through the `SchedulerObjectBase` interface.
SchedulerObjectBase toSchedulerObject(S)(S scheduler) {
  alias Wrapper = SchedulerObject!S;
  return new Wrapper(scheduler);
}
42 
/// Reason passed to a timer's delegate when it is invoked.
enum TimerTrigger {
  /// Passed by `ManualTimeWorker.advance` for timers that expired.
  trigger,
  /// Passed by `ManualTimeWorker.cancelTimer` for a cancelled timer.
  cancel,
}
47 
/// Callback type stored in a `Timer`; it receives the `TimerTrigger` reason.
alias TimerDelegate = void delegate(TimerTrigger) @safe shared;
49 
/// A timer entry: the delegate to call plus a numeric identifier.
struct Timer {
  /// Delegate invoked with the reason (`TimerTrigger`) the timer ended.
  TimerDelegate dg;
  /// Identifier value; `ManualTimeWorker.addTimer` fills it from an atomic counter.
  ulong id_;
  /// Returns `id_`.
  ///
  /// Declared `const` (and `@safe pure nothrow @nogc`) so the getter can also
  /// be used on `const`/`immutable` timers and from attribute-restricted code;
  /// it only reads a field.
  ulong id() const @safe pure nothrow @nogc { return id_; }
}
55 
/// Builds a `SchedulerAdapter` around a `LocalThreadWorker` constructed from
/// `getLocalThreadExecutor`.
auto localThreadScheduler() {
  import concurrency.thread : getLocalThreadExecutor, LocalThreadWorker;

  alias Adapter = SchedulerAdapter!LocalThreadWorker;
  return Adapter(LocalThreadWorker(getLocalThreadExecutor));
}
60 
/// Turns a `Worker` into a scheduler.
///
/// `schedule()` relies on `worker.schedule(VoidDelegate)`; `scheduleAfter`
/// hands the worker to `ScheduleAfterSender`.
struct SchedulerAdapter(Worker) {
  import core.time : Duration;
  import concurrency.executor : VoidDelegate;
  import concurrency.receiver : setValueOrError;

  /// The worker all scheduling requests are delegated to.
  Worker worker;

  /// Returns a sender that, once connected and started, asks the worker to
  /// run a delegate that completes the receiver via `setValueOrError`.
  auto schedule() {
    // Operation state: owns a copy of the worker and the receiver.
    static struct ScheduleOp(Receiver) {
      Worker worker;
      Receiver receiver;

      void start() @trusted nothrow {
        try {
          worker.schedule(cast(VoidDelegate)()=>receiver.setValueOrError());
        } catch (Exception failure) {
          // Submitting the work failed: report the exception to the receiver.
          receiver.setError(failure);
        }
      }
    }

    static struct ScheduleSender {
      alias Value = void;
      Worker worker;

      auto connect(Receiver)(return Receiver receiver) @safe scope return {
        alias Op = ScheduleOp!Receiver;
        // ensure NRVO: build the operation in a named local and return that
        auto operation = Op(worker, receiver);
        return operation;
      }
    }

    return ScheduleSender(worker);
  }

  /// Returns a `ScheduleAfterSender` for this worker and the delay `dur`.
  auto scheduleAfter(Duration dur) {
    alias DelayedSender = ScheduleAfterSender!Worker;
    return DelayedSender(worker, dur);
  }
}
93 
/// Operation state behind `ScheduleAfterSender`.
///
/// `start` registers a stop callback and a timer on the worker. The receiver
/// is completed from `trigger`, which the worker calls with the reason the
/// timer ended. The `terminated` flag decides whether the timer expiry or the
/// stop request gets to act.
struct ScheduleAfterOp(Worker, Receiver) {
  import std.traits : ReturnType;
  import concurrency.receiver : setValueOrError;
  import concurrency.stoptoken : StopCallback, onStop;
  import concurrency.bitfield : SharedBitField;

  enum Flags {
    locked = 0x1,
    terminated = 0x2
  }

  /// Whatever type the worker's `addTimer` returns.
  alias Timer = ReturnType!(Worker.addTimer);

  Worker worker;
  Duration dur;
  Receiver receiver;
  Timer timer;
  StopCallback stopCb;
  shared SharedBitField!Flags flags;

  /// Completes with `setDone` if stop was already requested; otherwise
  /// registers the stop callback and adds the timer, all while holding
  /// `flags.lock()`.
  void start() @trusted nothrow {
    alias StopHandler = void delegate() nothrow @safe shared;
    alias TimerHandler = void delegate(TimerTrigger) @safe shared;

    with(flags.lock()) {
      if (receiver.getStopToken().isStopRequested) {
        receiver.setDone();
        return;
      }

      stopCb = receiver.getStopToken().onStop(cast(StopHandler)&stop);

      try {
        timer = worker.addTimer(cast(TimerHandler)&trigger, dur);
      } catch (Exception failure) {
        // NOTE(review): the stop callback registered above is not disposed on
        // this path — confirm whether a later stop request can still reach
        // `stop` after the receiver was given the error.
        receiver.setError(failure);
      }
    }
  }

  /// Timer callback: disposes the stop callback, then completes the receiver
  /// according to `cause`.
  private void trigger(TimerTrigger cause) @trusted nothrow {
    stopCb.dispose();
    final switch (cause) {
    case TimerTrigger.trigger:
      // Only deliver the value if nobody set `terminated` before us.
      with(flags.update(Flags.terminated)) {
        const alreadyTerminated = (oldState & Flags.terminated) != 0;
        if (!alreadyTerminated)
          receiver.setValueOrError();
      }
      break;
    case TimerTrigger.cancel:
      receiver.setDone();
      break;
    }
  }

  /// Stop callback: if this call is the first to set `terminated`, asks the
  /// worker to cancel the timer.
  private void stop() @trusted nothrow {
    with(flags.update(Flags.terminated)) {
      const alreadyTerminated = (oldState & Flags.terminated) != 0;
      if (!alreadyTerminated) {
        try {
          worker.cancelTimer(timer);
        } catch (Exception e) {
          // TODO: what to do here?
          // NOTE(review): the exception is dropped; verify that the receiver
          // is still completed by some other path when cancellation fails.
        }
      }
    }
  }
}
147 
/// Sender that completes after `dur`, using timers provided by `worker`.
struct ScheduleAfterSender(Worker) {
  alias Value = void;

  /// Worker handed to the operation state.
  Worker worker;
  /// Delay handed to the operation state.
  Duration dur;

  /// Creates the `ScheduleAfterOp` for `receiver`.
  auto connect(Receiver)(return Receiver receiver) @safe return scope {
    alias Op = ScheduleAfterOp!(Worker, Receiver);
    // ensure NRVO: build the operation in a named local and return that
    auto operation = Op(worker, dur, receiver);
    return operation;
  }
}
158 
/// Scheduler backed by a `ManualTimeWorker`.
struct ManualTimeScheduler {
  /// The worker that owns the timers.
  shared ManualTimeWorker worker;

  /// Same as `scheduleAfter` with a zero delay.
  auto schedule() {
    return scheduleAfter(Duration.zero);
  }

  /// Returns a `ScheduleAfterSender` bound to `worker` and `dur`.
  auto scheduleAfter(Duration dur) {
    alias DelayedSender = ScheduleAfterSender!(shared ManualTimeWorker);
    return DelayedSender(worker, dur);
  }
}
169 
/// Timer worker whose clock only moves when `advance` is called.
///
/// Timers are kept in a `TimingWheels` instance; all access to it goes
/// through `lock()`, which acquires `mutex`.
class ManualTimeWorker {
  import concurrency.timingwheels : TimingWheels;
  import concurrency.executor : VoidDelegate;
  import core.sync.mutex : Mutex;
  import core.time : msecs, hnsecs;
  import std.array : Appender;
  private {
    TimingWheels!Timer wheels;
    // Scratch buffer that `advance` fills with expired timers and then empties.
    Appender!(Timer[]) expiredTimers;
    Mutex mutex;
    // Virtual clock in hnsecs (it is advanced by `dur.total!"hnsecs"`).
    // Kept as `ulong` rather than `size_t` so that it does not wrap on targets
    // where `size_t` is 32 bits wide.
    ulong time = 1;
    shared ulong nextTimerId;
  }

  /// Acquires `mutex` and returns a guard giving access to this object.
  auto lock() @trusted shared {
    import concurrency.utils : SharedGuard;
    return SharedGuard!(ManualTimeWorker).acquire(this, cast()mutex);
  }

  /// Creates the mutex and initialises the wheels with the starting time.
  this() @trusted shared {
    mutex = cast(shared)new Mutex();
    (cast()wheels).init(time);
  }

  /// Returns a `ManualTimeScheduler` that uses this worker.
  ManualTimeScheduler getScheduler() @safe shared {
    return ManualTimeScheduler(this);
  }

  /// Registers `dg` to be called after `dur` of virtual time and returns the
  /// `Timer` (delegate plus freshly generated id) that was scheduled.
  Timer addTimer(TimerDelegate dg, Duration dur) @trusted shared {
    import core.atomic : atomicOp;
    with(lock()) {
      auto real_now = time;
      auto tw_now = wheels.currStdTime(1.msecs);
      // Offset between our clock and the time reported by the wheels; it is
      // added to `dur` before converting to 1ms ticks.
      auto delay = (real_now - tw_now).hnsecs;
      auto at = (dur + delay)/1.msecs;
      auto timer = Timer(dg, nextTimerId.atomicOp!("+=")(1));
      wheels.schedule(timer, at);
      return timer;
    }
  }

  /// Removes `timer` from the wheels, then calls its delegate with
  /// `TimerTrigger.cancel` (after the lock has been released).
  void cancelTimer(Timer timer) @trusted shared {
    with(lock()) {
      wheels.cancel(timer);
    }
    timer.dg(TimerTrigger.cancel);
  }

  /// Forwards to `wheels.timeUntilNextEvent` with a 1ms tick and the current
  /// virtual time.
  Nullable!Duration timeUntilNextEvent() @trusted shared {
    with(lock()) {
      return wheels.timeUntilNextEvent(1.msecs, time);
    }
  }

  /// Moves the virtual clock forward by `dur` and calls the delegate of every
  /// timer that expired with `TimerTrigger.trigger`. The delegates run while
  /// the lock is held.
  void advance(Duration dur) @trusted shared {
    import std.range : retro;
    import core.time : msecs;
    with(lock()) {
      time += dur.total!"hnsecs";
      int incr = wheels.ticksToCatchUp(1.msecs, time);
      if (incr > 0) {
        wheels.advance(incr, expiredTimers);
        // NOTE timingwheels keeps the timers in reverse order, so we iterate in reverse
        foreach(t; expiredTimers.data.retro) {
          t.dg(TimerTrigger.trigger);
        }
        expiredTimers.shrinkTo(0);
      }
    }
  }
}
234 
/// Returns `t` as is; `p` is accepted but not used.
/// Only participates in overload resolution when both `T` and `P` satisfy
/// `isScheduler`.
T withBaseScheduler(T, P)(auto ref T t, auto ref P p)
if (isScheduler!T && isScheduler!P)
{
  return t;
}