1 module concurrency.operations.then;
2 
3 import concurrency;
4 import concurrency.receiver;
5 import concurrency.sender;
6 import concurrency.stoptoken;
7 import concepts;
8 import std.traits;
9 
/// Builds a sender that runs `fun` on the value produced by `sender`.
///
/// The returned `ThenSender` delegates the actual wrapping to `ThenReceiver`:
/// an upstream value is passed through `fun` and the result (or nothing, if
/// `fun` returns `void`) is forwarded downstream.
///
/// Params:
///   sender = the upstream sender whose value is transformed.
///   fun    = the transformation; its type must carry the `shared` attribute.
/// Returns: a `ThenSender!(Sender, Fun)` holding copies of `sender` and `fun`.
auto then(Sender, Fun)(Sender sender, Fun fun) {
  // Reject non-shared callables at compile time.
  enum funIsShared = hasFunctionAttributes!(Fun, "shared");
  static assert(funIsShared, "Function must be shared");

  alias Result = ThenSender!(Sender, Fun);
  return Result(sender, fun);
}
15 
/// Receiver adapter that sits between an upstream sender and `receiver`.
///
/// A value signal from upstream is first passed to `fun`. What `fun` returns,
/// or an empty value signal when it returns `void`, is delivered to the
/// wrapped `receiver`. Done and error signals pass through untouched.
///
/// `Value` is the upstream sender's value type (`void` means no payload).
/// `setValue` is not `nothrow`: anything thrown by `fun` or by the downstream
/// `setValue` propagates to the caller.
private struct ThenReceiver(Receiver, Value, Fun) {
  // Field order matters: ThenSender constructs this as R(receiver, fun).
  Receiver receiver;
  Fun fun;

  static if (!is(Value == void)) {
    /// Transform `value` with `fun`, then forward downstream.
    void setValue(Value value) @safe {
      static if (!is(ReturnType!Fun == void)) {
        receiver.setValue(fun(value));
      } else {
        // `fun` yields nothing, so signal completion without a payload.
        fun(value);
        receiver.setValue();
      }
    }
  } else {
    /// Upstream produced no payload: invoke `fun` with no arguments.
    void setValue() @safe {
      static if (!is(ReturnType!Fun == void)) {
        receiver.setValue(fun());
      } else {
        fun();
        receiver.setValue();
      }
    }
  }

  /// Pass an upstream error straight through to the wrapped receiver.
  void setError(Exception e) @safe nothrow {
    receiver.setError(e);
  }

  /// Pass upstream cancellation/completion-without-value straight through.
  void setDone() @safe nothrow {
    receiver.setDone();
  }

  // Expose the wrapped receiver's extension points on this adapter
  // (exact set defined by ForwardExtensionPoints elsewhere in the project).
  mixin ForwardExtensionPoints!receiver;
}
44 
/// Sender produced by `then`: wraps `sender` so its value is mapped by `fun`.
///
/// `Value` (the type this sender advertises) is the return type of `Fun`.
/// Connecting wraps the downstream receiver in a `ThenReceiver`, then connects
/// the upstream `sender` to that wrapper. The operation returned is whatever
/// the upstream `sender.connect` returns.
struct ThenSender(Sender, Fun) if (models!(Sender, isSender)) {
  import std.traits : ReturnType;

  alias Value = ReturnType!Fun;

  Sender sender;
  Fun fun;

  // This type must itself satisfy the sender concept.
  static assert(models!(typeof(this), isSender));

  /// Connect `receiver` by interposing a ThenReceiver that applies `fun`.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    alias Wrapped = ThenReceiver!(Receiver, Sender.Value, Fun);
    // Bind to a named local before returning, to ensure NRVO.
    auto op = sender.connect(Wrapped(receiver, fun));
    return op;
  }
}