1 module concurrency.operations.withstopsource;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concepts;
8 import std.traits;
9 
/// Wraps `sender` so that the operation's stop token is additionally tied to
/// `stopSource` (see `SSSender` / `SSReceiver` in this module).
template withStopSource(Sender) {
  /// Overload for a thread-local `StopSource`; it is stored in the `SSSender` as-is.
  auto withStopSource(Sender sender, StopSource stopSource) {
    return SSSender!Sender(sender, stopSource);
  }
  /// Overload for a `shared StopSource`; `shared` is cast away before it is
  /// stored, which is why this overload is `@trusted`.
  auto withStopSource(Sender sender, shared StopSource stopSource) @trusted {
    auto unshared = cast() stopSource;
    return SSSender!Sender(sender, unshared);
  }
}
18 
/// Receiver adaptor created by `SSSender.connect`.
///
/// Every completion signal is forwarded to the wrapped `receiver` after the
/// stop callbacks registered by `getStopToken` have been detached. The stop
/// token it hands out belongs to a lazily created StopSource that is stopped
/// when either the downstream receiver's token or `stopSource` is stopped.
private struct SSReceiver(Receiver, Value) {
  private {
    Receiver receiver;          // downstream receiver; receives all completions
    StopSource stopSource;      // extra stop source supplied by SSSender.connect
    StopSource combinedSource;  // created lazily by getStopToken; null before that and after a reset
    StopCallback[2] cbs;        // [0]: registered on receiver's token, [1]: on stopSource
  }
  static if (is(Value == void)) {
    /// Detaches the stop callbacks, then forwards a value-less completion.
    void setValue() @safe {
      resetStopCallback();
      receiver.setValueOrError();
    }
  } else {
    /// Detaches the stop callbacks, then forwards `value`.
    void setValue(Value value) @safe {
      resetStopCallback();
      receiver.setValueOrError(value);
    }
  }
  /// Detaches the stop callbacks, then forwards cancellation.
  void setDone() @safe nothrow {
    resetStopCallback();
    receiver.setDone();
  }
  // TODO: would be good if we only emitted this function if the Sender actually could call it
  /// Detaches the stop callbacks, then forwards the error `e`.
  void setError(Throwable e) @safe nothrow {
    resetStopCallback();
    receiver.setError(e);
  }
  /// Returns a token for `combinedSource`, creating it on first use.
  ///
  /// The first caller whose `cas` installs the new source registers callbacks on
  /// both the receiver's stop token and `stopSource`, each of which stops the
  /// combined source.
  auto getStopToken() nothrow @trusted {
    import core.atomic;
    if (atomicLoad(this.combinedSource) is null) {
      auto local = new StopSource();
      if (cas(&this.combinedSource, cast(StopSource)null, local)) {
        auto stop = cast(void delegate() shared nothrow @safe)&local.stop;
        cbs[0] = receiver.getStopToken().onStop(stop);
        cbs[1] = StopToken(stopSource).onStop(stop);
        // resetStopCallback may have run between the cas and the registrations
        // above (possibly before cbs were assigned); dispose them here so no
        // callback outlives the reset.
        if (atomicLoad(this.combinedSource) is null) {
          cbs[0].dispose();
          cbs[1].dispose();
        }
      }
      // If the cas failed, another caller installed combinedSource and owns cbs.
      // Those callbacks must NOT be disposed here: that would break stop
      // propagation for the installed source, and cbs may not even be assigned
      // yet (null). `local` is simply abandoned.
    }
    return StopToken(combinedSource);
  }
  mixin ForwardExtensionPoints!receiver;
  /// Atomically clears `combinedSource`; if one had been installed, disposes
  /// whichever callbacks have been assigned so far.
  private void resetStopCallback() {
    import core.atomic;
    if (atomicExchange(&this.combinedSource, cast(StopSource)null)) {
      if (cbs[0]) cbs[0].dispose();
      if (cbs[1]) cbs[1].dispose();
    }
  }
}
74 
/// Sender returned by `withStopSource`: it wraps `sender` and, on `connect`,
/// wraps the receiver in an `SSReceiver` that also observes `stopSource`.
struct SSSender(Sender) if (models!(Sender, isSender)) {
  static assert(models!(typeof(this), isSender));
  /// Same value type as the wrapped sender.
  alias Value = Sender.Value;
  Sender sender;          // the wrapped sender
  StopSource stopSource;  // forwarded to every SSReceiver created by connect
  /// Connects the wrapped sender to `receiver` wrapped in an `SSReceiver`.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    alias Wrapped = SSReceiver!(Receiver, Value);
    // Bind the operation state to a named local and return it, so NRVO applies.
    auto op = sender.connect(Wrapped(receiver, stopSource));
    return op;
  }
}