1 module concurrency.operations.withstopsource;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concepts;
8 import std.traits;
9 
/// Wraps `sender` so that, once connected, the operation also observes stop
/// requests issued on `stopSource` (see `SSSender` / `SSReceiver`).
template withStopSource(Sender) {
  /// Overload taking an unshared `StopSource`, stored as-is in the result.
  auto withStopSource(Sender sender, StopSource stopSource) {
    alias Result = SSSender!Sender;
    return Result(sender, stopSource);
  }
  /// Overload taking a `shared` StopSource; the `shared` qualifier is cast
  /// away before it is stored, which is why this overload is `@trusted`.
  auto withStopSource(Sender sender, shared StopSource stopSource) @trusted {
    StopSource unshared = cast()stopSource;
    return SSSender!Sender(sender, unshared);
  }
}
18 
/// Receiver adapter used by `SSSender`.
///
/// Every completion signal (value, done, error) first detaches the stop
/// forwarding set up by `getStopToken` and then is forwarded to the wrapped
/// `receiver`. The stop token it hands out belongs to a lazily created
/// `combinedSource`, which is stopped when either the wrapped receiver's stop
/// token or the extra `stopSource` is stopped.
private struct SSReceiver(Receiver, Value) {
  private {
    // Downstream receiver all signals are forwarded to.
    Receiver receiver;
    // Extra source whose stop requests are propagated into combinedSource.
    StopSource stopSource;
    // Created on first getStopToken call; set back to null by resetStopCallback.
    StopSource combinedSource;
    // cbs[0]: hook on receiver's stop token; cbs[1]: hook on stopSource.
    // Both stop combinedSource.
    StopCallback[2] cbs;
  }
  static if (is(Value == void)) {
    void setValue() @safe {
      resetStopCallback();
      receiver.setValueOrError();
    }
  } else {
    void setValue(Value value) @safe {
      resetStopCallback();
      receiver.setValueOrError(value);
    }
  }
  void setDone() @safe nothrow {
    resetStopCallback();
    receiver.setDone();
  }
  // TODO: would be good if we only emitted this function if the Sender could actually call it
  void setError(Exception e) @safe nothrow {
    resetStopCallback();
    receiver.setError(e);
  }
  /// Returns a token for the combined stop source.
  ///
  /// On first use, this creates the combined source and registers callbacks so
  /// that stopping either the downstream receiver's token or `stopSource`
  /// stops it.
  auto getStopToken() nothrow @trusted {
    import core.atomic;
    StopSource current = atomicLoad(this.combinedSource);
    if (current is null) {
      auto local = new StopSource();
      auto sharedStopSource = cast(shared)local;
      StopSource emptyStopSource = null;
      if (cas(&this.combinedSource, emptyStopSource, local)) {
        // We installed `local`, so this call owns `cbs`.
        cbs[0] = receiver.getStopToken().onStop(() shared => cast(void)sharedStopSource.stop());
        cbs[1] = StopToken(stopSource).onStop(() shared => cast(void)sharedStopSource.stop());
        // resetStopCallback may have run while the callbacks were being
        // registered and missed some of them; clean them up here in that case.
        if (atomicLoad(this.combinedSource) is null) {
          cbs[0].dispose();
          cbs[1].dispose();
        }
        current = local;
      } else {
        // Another caller won the race and owns `cbs` (possibly not assigned
        // yet). Do not touch its callbacks; use the source it installed and
        // discard `local`.
        current = atomicLoad(this.combinedSource);
      }
    }
    return StopToken(current);
  }
  mixin ForwardExtensionPoints!receiver;
  /// Atomically clears `combinedSource`. If one was installed, disposes
  /// whichever forwarding callbacks have already been assigned.
  private void resetStopCallback() {
    import core.atomic;
    if (atomicExchange(&this.combinedSource, cast(StopSource)null)) {
      if (cbs[0]) cbs[0].dispose();
      if (cbs[1]) cbs[1].dispose();
    }
  }
}
75 
/// Sender produced by `withStopSource`.
///
/// Behaves like the wrapped `sender`, but connects it through an `SSReceiver`
/// that also carries `stopSource`.
struct SSSender(Sender) if (models!(Sender, isSender)) {
  static assert(models!(typeof(this), isSender));
  alias Value = Sender.Value;
  Sender sender;
  StopSource stopSource;
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // Bind the result to a single named local and return it directly (NRVO).
    auto operation = sender.connect(SSReceiver!(Receiver, Value)(receiver, stopSource));
    return operation;
  }
}