1 module concurrency.scheduler;
2 
3 import concurrency.sender : SenderObjectBase;
4 import core.time : Duration;
5 import concepts;
6 import mir.algebraic : Nullable, nullable;
7 
/// Compile-time probe for the scheduler concept.
///
/// Instantiates `T` from `T.init` and requires that both `schedule()` and
/// `scheduleAfter(Duration)` return types accepted by `checkSender`.
/// Only meant to be used inside `is(typeof(...))`; see `isScheduler`.
void checkScheduler(T)() {
  import concurrency.sender : checkSender;
  import core.time : msecs;
  T sched = T.init;
  // Both sender-producing members must yield valid senders.
  checkSender!(typeof(sched.schedule()))();
  checkSender!(typeof(sched.scheduleAfter(10.msecs)))();
}
/// True when `checkScheduler!T` compiles, i.e. `T` models a scheduler.
enum isScheduler(T) = is(typeof(checkScheduler!T));
18 
/// Polymorphic (type-erased) scheduler interface, so different scheduler
/// types can be stored and passed around behind a single type.
/// `SchedulerObject!S` implements it for a concrete scheduler `S`.
interface SchedulerObjectBase {
  /// Returns a type-erased sender for work scheduled on this scheduler.
  @safe SenderObjectBase!void schedule();
  /// Returns a type-erased sender for work scheduled after the delay `d`.
  @safe SenderObjectBase!void scheduleAfter(Duration d);
}
24 
/// Adapter that exposes a concrete scheduler `S` through the
/// `SchedulerObjectBase` interface. Each sender produced by the wrapped
/// scheduler is converted with `toSenderObject` into a `SenderObjectBase!void`.
class SchedulerObject(S) : SchedulerObjectBase {
  import concurrency.sender : toSenderObject;
  /// The wrapped scheduler.
  S scheduler;
  /// Stores `scheduler` for later use by `schedule`/`scheduleAfter`.
  this(S scheduler) {
    this.scheduler = scheduler;
  }
  /// Forwards to `scheduler.schedule()` and type-erases the result.
  SenderObjectBase!void schedule() @safe {
    return toSenderObject(scheduler.schedule());
  }
  /// Forwards to `scheduler.scheduleAfter(d)` and type-erases the result.
  SenderObjectBase!void scheduleAfter(Duration d) @safe {
    return toSenderObject(scheduler.scheduleAfter(d));
  }
}
38 
/// Wraps `scheduler` in a newly allocated `SchedulerObject!S` and returns it
/// through the `SchedulerObjectBase` interface.
SchedulerObjectBase toSchedulerObject(S)(S scheduler) {
  SchedulerObjectBase erased = new SchedulerObject!(S)(scheduler);
  return erased;
}
42 
/// Builds a `SchedulerAdapter` around a `LocalThreadWorker` constructed from
/// the calling thread's `Tid` (`thisTid()`).
auto localThreadScheduler() {
  import concurrency.thread : LocalThreadWorker;
  import std.concurrency : thisTid;
  alias Adapter = SchedulerAdapter!LocalThreadWorker;
  return Adapter(LocalThreadWorker(thisTid()));
}
48 
/// Turns a `Worker` into a scheduler.
///
/// `schedule()` posts the receiver's completion to `worker.schedule`, and
/// `scheduleAfter` delegates to `ScheduleAfterSender!Worker`, which relies on
/// the worker's `addTimer`/`cancelTimer`.
struct SchedulerAdapter(Worker) {
  import concurrency.receiver : setValueOrError;
  import concurrency.executor : VoidDelegate;
  import core.time : Duration;
  Worker worker;

  /// Returns a sender whose operation completes the receiver from a delegate
  /// handed to `worker.schedule`.
  auto schedule() {
    static struct ScheduleOp(Receiver) {
      Worker worker;
      Receiver receiver;
      void start() @trusted nothrow {
        // The receiver is completed later by the worker. If handing the
        // delegate over throws, the error is reported to the receiver instead.
        try
          worker.schedule(cast(VoidDelegate) () => receiver.setValueOrError());
        catch (Exception e)
          receiver.setError(e);
      }
    }
    static struct ScheduleSender {
      alias Value = void;
      Worker worker;
      auto connect(Receiver)(return Receiver receiver) @safe scope return {
        alias Op = ScheduleOp!(Receiver);
        // Return a named local so the operation state benefits from NRVO.
        auto op = Op(worker, receiver);
        return op;
      }
    }
    return ScheduleSender(worker);
  }

  /// Returns a sender that completes after `dur`, via `ScheduleAfterSender`.
  auto scheduleAfter(Duration dur) {
    alias Sender = ScheduleAfterSender!(Worker);
    return Sender(worker, dur);
  }
}
81 
/// Operation state produced by `ScheduleAfterSender.connect`.
///
/// `start` registers a stop callback on the receiver's stop token and arms a
/// timer on `worker` for `dur`. Three paths can complete the receiver:
///   - the timer firing (`setValueOrError`),
///   - a stop request (`setDone`),
///   - `addTimer` throwing (`setError`).
/// Each path sets `Flags.terminated` through `flags.update` and only
/// completes the receiver if it was the one that set it. This keeps the
/// receiver from being completed more than once.
struct ScheduleAfterOp(Worker, Receiver) {
  import std.traits : ReturnType;
  import concurrency.bitfield : SharedBitField;
  import concurrency.stoptoken : StopCallback, onStop;
  import concurrency.receiver : setValueOrError;

  enum Flags {
    locked = 0x1,     // taken by `start` via `flags.lock()` during setup
    terminated = 0x2  // set by whichever path completes the receiver first
  }
  alias Timer = ReturnType!(Worker.addTimer);
  Worker worker;
  Duration dur;
  Receiver receiver;
  Timer timer;
  StopCallback stopCb;
  shared SharedBitField!Flags flags;

  /// Starts the operation. If stop was already requested, completes with
  /// `setDone` immediately. Otherwise registers `stop` with the stop token
  /// and arms the timer.
  void start() @trusted nothrow {
    with(flags.lock()) {
      if (receiver.getStopToken().isStopRequested) {
        receiver.setDone();
        return;
      }
      stopCb = receiver.getStopToken().onStop(cast(void delegate() nothrow @safe shared)&stop);
      try {
        timer = worker.addTimer(() shared nothrow {
            // Timer fired. Unregister the stop callback, then complete
            // unless a stop request already terminated the operation.
            stopCb.dispose();
            with(flags.update(Flags.terminated)) {
              if ((oldState & Flags.terminated) == 0)
                receiver.setValueOrError();
            }
          }, dur);
      } catch (Exception e) {
        // No timer was armed. If the stop callback stayed registered, a later
        // stop request would cancel a timer that was never created and call
        // `setDone` after `setError`. Unregister it and claim termination
        // first, the same way the timer path does.
        stopCb.dispose();
        with(flags.update(Flags.terminated)) {
          if ((oldState & Flags.terminated) == 0)
            receiver.setError(e);
        }
      }
    }
  }

  /// Stop-token callback. If the operation has not terminated yet, it cancels
  /// the timer (best effort) and completes the receiver with `setDone`.
  private void stop() @trusted nothrow {
    with(flags.update(Flags.terminated)) {
      if ((oldState & Flags.terminated) == 0) {
        // NOTE(review): `stopCb` is registered before `timer` is assigned in
        // `start`. Presumably this callback can therefore run while `timer`
        // still holds `Timer.init`. In that case the real timer is not
        // cancelled; when it fires, its callback sees `terminated` and does
        // nothing.
        try { worker.cancelTimer(timer); } catch (Exception e) {} // TODO: what to do here?
        receiver.setDone();
      }
    }
  }
}
128 
/// Sender that completes without a value once a timer of length `dur`,
/// armed on `worker`, fires. See `ScheduleAfterOp`.
struct ScheduleAfterSender(Worker) {
  alias Value = void;
  Worker worker;
  Duration dur;
  auto connect(Receiver)(return Receiver receiver) @safe return scope {
    alias Op = ScheduleAfterOp!(Worker, Receiver);
    // Returning a named local lets the compiler apply NRVO to the op state.
    auto op = Op(worker, dur, receiver);
    return op;
  }
}
139 
/// Scheduler whose senders arm timers on a `ManualTimeWorker`, so they only
/// fire when that worker's clock is advanced.
struct ManualTimeScheduler {
  shared ManualTimeWorker worker;
  /// Equivalent to `scheduleAfter` with a zero delay.
  auto schedule() {
    return scheduleAfter(Duration.zero);
  }
  /// Returns a sender whose timer is registered on `worker` for `dur`.
  auto scheduleAfter(Duration dur) {
    alias Sender = ScheduleAfterSender!(shared ManualTimeWorker);
    return Sender(worker, dur);
  }
}
150 
/// Worker driven by a manual clock. Timers are only expired by `advance`,
/// which runs their callbacks synchronously on the calling thread.
///
/// The clock `time` counts hnsecs and starts at 1. Timers are stored in a
/// `TimingWheels` queried and advanced with a 1 ms tick. All state is accessed
/// under `mutex` through `lock()`.
class ManualTimeWorker {
  import concurrency.timingwheels : TimingWheels;
  import concurrency.executor : VoidDelegate;
  import core.sync.mutex : Mutex;
  import core.time : msecs, hnsecs;
  private {
    TimingWheels!Timer wheels;
    Mutex mutex;
    // Current manual time in hnsecs.
    // NOTE(review): size_t is 32 bits on 32-bit targets and would wrap after
    // roughly 429 s worth of hnsecs; consider ulong.
    size_t time = 1;
    // Source of timer ids; the first id handed out is 1.
    shared ulong nextTimerId;
  }

  /// Handle returned by `addTimer` and accepted by `cancelTimer`.
  static struct Timer {
    VoidDelegate dg;
    ulong id_;
    ulong id() { return id_; }
  }

  /// Acquires `mutex` and returns the guard. Presumably the guard exposes the
  /// unshared worker, which `with(lock())` blocks rely on below.
  auto lock() @trusted shared {
    import concurrency.utils : SharedGuard;
    return SharedGuard!(ManualTimeWorker).acquire(this, cast() mutex);
  }

  /// Creates the mutex and starts the timing wheel at the initial clock value.
  this() @trusted shared {
    mutex = cast(shared) new Mutex();
    (cast() wheels).init(time);
  }

  /// Returns a `ManualTimeScheduler` bound to this worker.
  ManualTimeScheduler getScheduler() @safe shared {
    return ManualTimeScheduler(this);
  }

  /// Registers `dg` to run once the clock has moved `dur` past its current
  /// value, and returns the timer handle.
  Timer addTimer(VoidDelegate dg, Duration dur) @trusted shared {
    import core.atomic : atomicOp;
    with (lock()) {
      // The wheel's notion of "now" can trail `time`. Add that lag so the
      // deadline is measured from the manual clock, then convert to 1 ms ticks.
      auto lag = (time - wheels.currStdTime(1.msecs)).hnsecs;
      auto ticks = (dur + lag) / 1.msecs;
      auto timer = Timer(dg, nextTimerId.atomicOp!("+=")(1));
      wheels.schedule(timer, ticks);
      return timer;
    }
  }

  /// Removes `timer` from the wheel.
  void cancelTimer(Timer timer) @trusted shared {
    with (lock())
      wheels.cancel(timer);
  }

  /// Time remaining until the wheel's next pending timer, as reported by
  /// `TimingWheels.timeUntilNextEvent`. Presumably null when nothing is pending.
  Nullable!Duration timeUntilNextEvent() @trusted shared {
    with (lock())
      return wheels.timeUntilNextEvent(1.msecs, time);
  }

  /// Moves the clock forward by `dur`. Catches the wheel up and invokes the
  /// callback of every timer it expires while the lock is still held.
  void advance(Duration dur) @trusted shared {
    import core.time : msecs;
    with (lock()) {
      time += dur.total!"hnsecs";
      int ticks = wheels.ticksToCatchUp(1.msecs, time);
      if (ticks <= 0)
        return;
      auto expired = wheels.advance(ticks);
      foreach (entry; expired.timers)
        entry.dg();
    }
  }
}
214 
/// Accepts two schedulers and returns (a copy of) the first one unchanged.
/// `p` is not used.
T withBaseScheduler(T, P)(auto ref T t, auto ref P p)
  if (isScheduler!T && isScheduler!P)
{
  return t;
}