module concurrency.operations.toshared;

import concurrency;
import concurrency.receiver;
import concurrency.sender;
import concurrency.stoptoken;
import concepts;
import std.traits;
import mir.algebraic : Algebraic, Nullable, match;

/// Wraps a Sender in a SharedSender. A SharedSender allows many receivers to connect to the same underlying Sender, forwarding the same termination call to each receiver.
/// The underlying Sender is connected and started only once. It can be explicitly `reset` so that it connects and starts the underlying Sender again the next time it is started. Calling `reset` while the underlying Sender hasn't completed is a no-op.
/// When the last receiver triggers its stop token while the underlying Sender is still running, the underlying Sender is cancelled, and that receiver has one of its termination functions called once the underlying Sender has completed. (This ensures structured concurrency; otherwise tasks could be left running without anyone awaiting them.)
/// If a receiver is connected (and started) after the underlying Sender has already completed, that receiver has one of its termination functions called immediately.
/// This operation is useful when you have multiple tasks that all depend on one shared task. It allows you to write the shared task as a regular Sender and simply apply `.toShared`.
/// Params:
///   sender = the Sender whose run is shared by the connected receivers
///   scheduler = value returned by `getScheduler` on the receiver the underlying Sender is connected to
/// Returns: a newly allocated `SharedSender` wrapping `sender`
auto toShared(Sender, Scheduler)(Sender sender, Scheduler scheduler) {
  return new SharedSender!(Sender, Scheduler)(sender, scheduler);
}

/// Same as `toShared(sender, scheduler)`, except that the receiver the
/// underlying Sender is connected to reports an empty placeholder scheduler.
auto toShared(Sender)(Sender sender) {
  static struct NullScheduler {}
  return new SharedSender!(Sender, NullScheduler)(sender, NullScheduler());
}

/// Sender that shares one run of the wrapped `Sender` among all receivers
/// started while that run is in progress, forwarding its termination (value,
/// error or done) to each of them. Instances are created by `toShared`.
///
/// Bookkeeping lives in `counter`: `Flags.locked` is the bit taken by
/// `counter.lock`, `Flags.completed` marks a finished run, and the bits above
/// them count the receivers currently waiting (one `Flags.tick` each).
class SharedSender(Sender, Scheduler) if (models!(Sender, isSender)) {
  import concurrency.slist;
  import concurrency.bitfield;
  static assert(models!(typeof(this), isSender));
  alias Value = Sender.Value;
  // A `void` value is represented by an empty struct so it can be stored in `InternalValue`.
  static if (is(Value == void)) {
    static struct ValueRep{}
  } else
    alias ValueRep = Value;
  /// Marker stored when the underlying Sender completed through `setDone`.
  static struct Done{}
  /// Termination of the underlying Sender: an error, a value, or done.
  alias InternalValue = Algebraic!(Exception, ValueRep, Done);
  /// Callback through which each connected operation receives the result.
  alias DG = void delegate(InternalValue) nothrow @safe shared;

  /// Operation state returned by `connect`; one per connected receiver.
  static struct SharedSenderOp(Receiver) {
    SharedSender parent;
    Receiver receiver;
    StopCallback cb;
    /// Registers with the parent (which may start the underlying Sender, or
    /// deliver an already stored result right away), then subscribes to the
    /// receiver's stop token.
    void start() nothrow @trusted scope {
      parent.add(&(cast(shared)this).onValue);
      // NOTE(review): when `add` delivered the result synchronously, `onValue`
      // ran while `cb` was still null, so the callback registered here is never
      // disposed; a later stop request would still reach `onStop`.
      cb = receiver.getStopToken.onStop(&(cast(shared)this).onStop);
    }
    /// Called when the receiver's stop token is triggered.
    void onStop() nothrow @trusted shared {
      with(unshared) {
        // If this is the last receiver, `remove` stops the underlying Sender
        // and returns false; its completion then reaches us through `onValue`.
        // This ensures the underlying Sender is always awaited for completion.
        if (parent.remove(&(cast(shared)this).onValue))
          receiver.setDone();
      }
    }
    /// Forwards the shared result to this operation's receiver.
    void onValue(InternalValue value) nothrow @safe shared {
      with(unshared) {
        // Detach from the receiver's stop token before terminating the receiver,
        // so a later stop request cannot reach `onStop` and terminate it a second
        // time. `cb` is still null when this runs synchronously from `start`.
        if (cb !is null)
          cb.dispose();
        value.match!((ValueRep v){
            try {
              static if (is(Value == void))
                receiver.setValue();
              else
                receiver.setValue(v);
            } catch (Exception e) {
              receiver.setError(e);
            }
          }, (Exception e){
            receiver.setError(e);
          }, (Done d){
            receiver.setDone();
          });
      }
    }
    private auto ref unshared() @trusted nothrow shared {
      return cast()this;
    }
  }

  /// State of a single run: the stop source observed by the underlying Sender,
  /// the delegates of the waiting receivers, the stored result and the
  /// underlying operation state.
  static class SharedSenderState : StopSource {
    alias Op = OpType!(Sender, SharedSenderReceiver);
    SharedSender parent;
    shared SList!DG dgs;
    Nullable!InternalValue value;
    Op op;
    this(SharedSender parent) {
      this.dgs = new shared SList!DG;
      this.parent = parent;
    }
  }

  /// Receiver the underlying Sender is connected to. It stores the termination
  /// in the run's state and publishes it to the waiting receivers.
  static struct SharedSenderReceiver {
    SharedSenderState state;
    Scheduler scheduler;
    static if (is(Sender.Value == void))
      void setValue() @safe {
        state.value = InternalValue(ValueRep());
        process();
      }
    else
      void setValue(ValueRep v) @safe {
        state.value = InternalValue(v);
        process();
      }
    void setDone() @safe nothrow {
      state.value = InternalValue(Done());
      process();
    }
    void setError(Exception e) @safe nothrow {
      state.value = InternalValue(e);
      process();
    }
    /// Marks the parent completed, clears the receiver count and calls every
    /// registered delegate. If the run's stop source was triggered, delegates
    /// receive `Done` instead of the stored result (the stored result is kept).
    private void process() @trusted {
      with(state.parent.counter.lock(Flags.completed)) {
        release(oldState & (~0x3)); // release early and remove all ticks
        InternalValue v = state.value.get;
        if (state.isStopRequested)
          v = Done();
        foreach(dg; state.dgs[])
          dg(v);
      }
    }
    /// The underlying Sender observes the run's stop source, which `remove`
    /// triggers when the last receiver stops.
    StopToken getStopToken() @safe nothrow {
      return StopToken(state);
    }
    Scheduler getScheduler() @safe nothrow scope {
      return scheduler;
    }
  }

  private {
    Sender sender;
    Scheduler scheduler;
    SharedSenderState state;
    enum Flags {
      locked = 0x1,
      completed = 0x2,
      tick = 0x4
    }
    SharedBitField!Flags counter;
    /// Registers `dg` for the result of the current run:
    /// - if the last run has completed, `dg` is called right away with the stored result;
    /// - if no receiver is waiting yet, a new run is created and the underlying
    ///   Sender is connected and started (it may complete synchronously and call `dg`);
    /// - otherwise `dg` joins the run already in progress.
    void add(DG dg) @safe nothrow {
      with(counter.lock(0, Flags.tick)) {
        if (was(Flags.completed)) {
          InternalValue value = state.value.get;
          release(Flags.tick); // release early; ticks are not kept once completed
          dg(value);
        } else {
          if ((oldState >> 2) == 0) {
            auto localState = new SharedSenderState(this);
            this.state = localState;
            release(); // release early: this run cannot complete before `op.start()` below
            localState.dgs.pushBack(dg);
            localState.op = sender.connect(SharedSenderReceiver(localState, scheduler));
            localState.op.start();
          } else {
            auto localState = state;
            // Push while still holding the lock: `process` must take it to mark
            // completion before it snapshots `dgs`. Releasing first would let
            // the run complete in between, and `dg` would never be called.
            localState.dgs.pushBack(dg);
            release();
          }
        }
      }
    }
    /// Unregisters `dg` after its receiver requested a stop.
    /// Returns: `false` when `dg` belonged to the last waiting receiver. The
    /// current run has then been asked to stop, and `dg` stays registered so it
    /// still receives the run's completion. Otherwise `true` (including when
    /// the run had already completed), and the caller terminates its receiver
    /// itself.
    bool remove(DG dg) @safe nothrow {
      with (counter.lock(0, 0, Flags.tick)) {
        if (was(Flags.completed)) {
          release(0-Flags.tick); // release early and undo this lock's tick decrement; ticks are cleared on completion
          return true;
        }
        if ((newState >> 2) == 0) {
          auto localStopSource = state;
          release(); // release early
          localStopSource.stop();
          return false;
        } else {
          auto localState = state;
          // Detach while still holding the lock, so a completion that takes the
          // lock after us cannot call `dg` in addition to the caller's `setDone`.
          localState.dgs.remove(dg);
          release();
          return true;
        }
      }
    }
  }
  /// Returns: whether the completed flag is set, i.e. a run finished and
  /// `reset` has not been called since.
  bool isCompleted() @trusted {
    import core.atomic : MemoryOrder;
    return (counter.load!(MemoryOrder.acq) & Flags.completed) > 0;
  }
  /// Clears the completed flag so that the next started receiver connects and
  /// starts the underlying Sender again. No-op while no run has completed.
  void reset() @trusted {
    with (counter.lock()) {
      if (was(Flags.completed))
        release(Flags.completed);
    }
  }
  this(Sender sender, Scheduler scheduler) {
    this.sender = sender;
    this.scheduler = scheduler;
  }
  /// Creates the operation state for `receiver`; nothing runs until it is started.
  auto connect(Receiver)(return Receiver receiver) @safe scope return {
    // ensure NRVO
    auto op = SharedSenderOp!Receiver(this, receiver);
    return op;
  }
}