1 module concurrency.operations.retrywhen;
2 
3 import concurrency;
4 import concurrency.operations.via;
5 import concurrency.receiver;
6 import concurrency.sender;
7 import concurrency.stoptoken;
8 import concepts;
9 import std.traits;
10 
/// True when `T` has a `failure` member that can be called with an `Exception`
/// and whose result models `isSender`.
template isRetryWhenLogic(T) {
  enum isRetryWhenLogic = models!(typeof(T.init.failure(Exception.init)), isSender);
}
12 
/// Wraps `sender` so that it is retried under the control of `logic`.
///
/// Whenever the source completes with an `Exception`, `logic.failure(ex)` is called.
/// The sender it returns (the "trigger") is then started:
/// - if the trigger produces a value, the source is connected and started again;
/// - if the trigger completes with done or an error, or `failure` itself throws,
///   that outcome is forwarded to the downstream receiver.
/// Non-`Exception` throwables from the source are forwarded without retrying.
auto retryWhen(Sender, Logic)(Sender sender, Logic logic)
    if (isRetryWhenLogic!Logic)
{
  alias Result = RetryWhenSender!(Sender, Logic);
  return Result(sender, logic);
}
16 
/// Receiver connected to the trigger sender obtained from `Logic.failure`.
/// A value from the trigger means "retry"; done and error end the operation.
private struct TriggerReceiver(Sender, Receiver, Logic) {
  alias Value = void;
  private RetryWhenOp!(Sender, Receiver, Logic)* op;

  /// Retry: reconnect the source sender into the shared op state and start it again.
  void setValue() @safe {
    alias Retry = SourceReceiver!(Sender, Receiver, Logic);
    op.sourceOp = op.sender.connect(Retry(op));
    op.sourceOp.start();
  }

  /// The trigger was cancelled; propagate cancellation downstream.
  void setDone() @safe nothrow {
    op.receiver.setDone();
  }

  /// The trigger failed; propagate its error downstream.
  void setError(Throwable t) @safe nothrow {
    op.receiver.setError(t);
  }

  // Downstream receiver (returned by value) used to forward extension points.
  private auto receiver() {
    return op.receiver;
  }

  mixin ForwardExtensionPoints!(receiver);
}
35 
/// Receiver connected to the (retried) source sender.
/// Values and done are forwarded downstream. An `Exception` is handed to
/// `Logic.failure`, whose resulting sender is connected and started as the trigger.
private struct SourceReceiver(Sender, Receiver, Logic) {
  alias Value = Sender.Value;
  private RetryWhenOp!(Sender, Receiver, Logic)* op;

  static if (!is(Value == void)) {
    /// Forward the source's value downstream (errors from the downstream setValue
    /// are routed through `setValueOrError`).
    void setValue(Value value) @safe {
      op.receiver.setValueOrError(value);
    }
  } else {
    /// Forward the source's void completion downstream.
    void setValue() @safe {
      op.receiver.setValueOrError();
    }
  }

  /// Cancellation of the source is forwarded downstream; no retry is attempted.
  void setDone() @safe nothrow {
    op.receiver.setDone();
  }

  /// Source failed. Throwables that are not `Exception`s are forwarded as-is.
  /// For an `Exception`, ask the logic for a trigger sender and start it. If
  /// `failure` or `connect` throws, that throwable is forwarded downstream instead.
  void setError(Throwable t) @trusted nothrow {
    auto exception = cast(Exception) t;
    if (exception is null) {
      op.receiver.setError(t);
      return;
    }

    alias Trigger = TriggerReceiver!(Sender, Receiver, Logic);
    try {
      // NOTE(review): if the trigger completes synchronously inside start(), the
      // source is reconnected re-entrantly while this frame is still active — confirm
      // that overwriting op.sourceOp/op.triggerOp in that case is safe.
      op.triggerOp = op.logic.failure(exception).connect(Trigger(op));
      op.triggerOp.start();
    } catch (Throwable thrown) {
      op.receiver.setError(thrown);
    }
  }

  // Downstream receiver (returned by value) used to forward extension points.
  private auto receiver() {
    return op.receiver;
  }

  mixin ForwardExtensionPoints!(receiver);
}
68 
/// Operation state for `retryWhen`. Holds the source sender, the downstream
/// receiver and the retry logic, plus the connected operation states of the
/// source and of the trigger sender produced by `Logic.failure`.
private struct RetryWhenOp(Sender, Receiver, Logic) {
  alias SourceOp = OpType!(Sender, SourceReceiver!(Sender, Receiver, Logic));
  // The trigger sender type is taken from the same call shape that `isRetryWhenLogic`
  // checks and that SourceReceiver.setError performs (`failure` with an `Exception`).
  // Unlike `ReturnType!(Logic.failure)`, this also works for templated `failure`
  // members and picks the overload that is actually called.
  alias TriggerOp = OpType!(typeof(Logic.init.failure(Exception.init)),
                            TriggerReceiver!(Sender, Receiver, Logic));
  Sender sender;
  Receiver receiver;
  Logic logic;
  // TODO: this could probably be a Variant to save some space
  SourceOp sourceOp;
  TriggerOp triggerOp;
  // Receivers hold a pointer to this op (see the constructor), so it must never be
  // copied or moved after construction.
  @disable this(ref return scope typeof(this) rhs);
  @disable this(this);
  /// Stores the inputs and eagerly connects the source, handing it `&this`.
  this(return Sender sender, Receiver receiver, Logic logic) @trusted scope {
    this.sender = sender;
    this.receiver = receiver;
    this.logic = logic;
    sourceOp = this.sender.connect(SourceReceiver!(Sender, Receiver, Logic)(&this));
  }
  /// Starts the first attempt of the source operation.
  void start() @trusted nothrow scope {
    sourceOp.start();
  }
}
91 
/// Sender returned by `retryWhen`. It produces the same `Value` as the wrapped
/// sender, retrying it according to `logic` when it fails with an `Exception`.
struct RetryWhenSender(Sender, Logic) if (isRetryWhenLogic!Logic) {
  alias Value = Sender.Value;
  Sender sender;
  Logic logic;

  /// Connects to `receiver`, producing the operation state.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    alias Op = RetryWhenOp!(Sender, Receiver, Logic);
    // Build into a named local and return it, so NRVO constructs the op in the
    // caller's storage. The op's constructor hands out `&this`, so it must not move.
    auto op = Op(sender, receiver, logic);
    return op;
  }

  static assert(models!(typeof(this), isSender));
}