module concurrency.operations.toshared;

import concurrency;
import concurrency.receiver;
import concurrency.sender;
import concurrency.stoptoken;
import concepts;
import std.traits;
import mir.algebraic : Algebraic, Nullable, match;

/// Wraps a Sender in a SharedSender. A SharedSender allows many receivers to connect to the same underlying Sender, forwarding the same termination call to each receiver.
/// The underlying Sender is connected and started only once. It can be explicitly `reset` so that it connects and starts the underlying Sender again the next time it is started. Calling `reset` when the underlying Sender hasn't completed is a no-op.
/// When the last receiver triggers its stop token while the underlying Sender is still running, the underlying Sender is cancelled, and one of that receiver's termination functions is called only after the underlying Sender has completed. (This ensures structured concurrency; otherwise tasks could be left running without anyone awaiting them.)
/// If a receiver is connected and started after the underlying Sender has already completed, that receiver has one of its termination functions called immediately.
/// This operation is useful when you have multiple tasks that all depend on one shared task. It allows you to write the shared task as a regular Sender and simply apply `.toShared`.
// Two-argument form: `scheduler` is what the receiver connected to the
// underlying Sender returns from `getScheduler`.
auto toShared(Sender, Scheduler)(Sender sender, Scheduler scheduler) {
    return new SharedSender!(Sender, Scheduler)(sender, scheduler);
}

/// Same as the two-argument `toShared`, but without a scheduler: the receiver
/// connected to the underlying Sender returns a `NullScheduler` from
/// `getScheduler`.
auto toShared(Sender)(Sender sender) {
    return new SharedSender!(Sender, NullScheduler)(sender, NullScheduler());
}

/// Placeholder scheduler type used by the single-argument `toShared`.
private struct NullScheduler {}
/// Marker for the `setDone` termination in the shared result.
private struct Done{}
/// Stand-in value for Senders whose `Value` is `void` (an Algebraic cannot
/// hold `void`). For other Senders, `SharedSender` shadows this with an alias
/// to `Sender.Value`.
private struct ValueRep{}

/// Sender that shares one run of the underlying `Sender` among all receivers
/// connected to it. The underlying Sender's single termination (value, error
/// or done) is cached and forwarded to every registered receiver.
/// Instances are created by `toShared`.
class SharedSender(Sender, Scheduler) if (models!(Sender, isSender)) {
    import std.traits : ReturnType;
    import concurrency.slist;
    import concurrency.bitfield;
    static assert(models!(typeof(this), isSender));
    alias Value = Sender.Value;
    // For void Senders the module-level `ValueRep` placeholder is used instead.
    static if (!is(Value == void))
        alias ValueRep = Value;
    /// The result of the underlying Sender, replayed to every receiver.
    alias InternalValue = Algebraic!(Exception, ValueRep, Done);
    /// Callback each started receiver registers to get the shared result.
    alias DG = void delegate(InternalValue) nothrow @safe shared;

    /// Operation state returned by `connect` for one downstream receiver.
    static struct SharedSenderOp(Receiver) {
        SharedSender parent;
        Receiver receiver;
        StopCallback cb;
        /// Registers `onValue` with the parent, then subscribes `onStop` to
        /// the receiver's stop token.
        ///
        /// `add` may start the underlying Sender, or call `onValue` right away
        /// if a result is already cached.
        void start() nothrow @trusted scope {
            parent.add(&(cast(shared)this).onValue);
            cb = receiver.getStopToken.onStop(&(cast(shared)this).onStop);
        }
        /// Called when the downstream receiver's stop token is triggered.
        void onStop() nothrow @trusted shared {
            with(unshared) {
                // If this is the last one connected, `remove` returns false
                // and stops the underlying Sender. Our callback stays
                // registered, so this receiver gets its `setDone` through
                // `onValue` once the underlying Sender has completed. This
                // ensures we always await the underlying Sender.
                // NOTE(review): when the result is already cached, `remove`
                // returns true and `setDone` is called here even though
                // `onValue` has already run for this receiver — confirm that
                // double termination cannot happen in practice.
                if (parent.remove(&(cast(shared)this).onValue))
                    receiver.setDone();
            }
        }
        /// Forwards the shared result to this op's receiver.
        ///
        /// If `setValue` throws, the exception is delivered through
        /// `setError` instead.
        void onValue(InternalValue value) nothrow @safe shared {
            with(unshared) {
                value.match!((ValueRep v){
                        try {
                            static if (is(Value == void))
                                receiver.setValue();
                            else
                                receiver.setValue(v);
                        } catch (Exception e) {
                            // TODO: dispose needs to be called in all cases,
                            // except this onValue can sometimes be called
                            // immediately (from within `start`), leaving no
                            // room to set cb before dispose...
                            cb.dispose();
                            receiver.setError(e);
                        }
                    }, (Exception e){
                        receiver.setError(e);
                    }, (Done d){
                        receiver.setDone();
                    });
            }
        }
        private auto ref unshared() @trusted nothrow shared {
            return cast()this;
        }
    }

    /// State for one run of the underlying Sender.
    ///
    /// Holds the stop source handed to the underlying Sender (through
    /// `SharedSenderReceiver.getStopToken`), the registered callbacks, the
    /// cached result and the underlying operation state.
    static class SharedSenderState : StopSource {
        import std.traits : ReturnType;
        alias Op = OpType!(Sender, SharedSenderReceiver);
        SharedSender parent;
        shared SList!DG dgs;
        Nullable!InternalValue value;
        Op op;
        this(SharedSender parent) {
            this.dgs = new shared SList!DG;
            this.parent = parent;
        }
    }

    /// Receiver connected to the underlying Sender.
    ///
    /// Stores the termination in the state, then lets the parent dispatch it
    /// to every registered callback.
    static struct SharedSenderReceiver {
        SharedSenderState state;
        Scheduler scheduler;
        static if (is(Sender.Value == void))
            void setValue() @safe {
                state.value = InternalValue(ValueRep());
                process();
            }
        else
            void setValue(ValueRep v) @safe {
                state.value = InternalValue(v);
                process();
            }
        void setDone() @safe nothrow {
            state.value = InternalValue(Done());
            process();
        }
        void setError(Exception e) @safe nothrow {
            state.value = InternalValue(e);
            process();
        }
        private void process() @trusted {
            state.parent.process();
        }
        StopToken getStopToken() @safe nothrow {
            return StopToken(state);
        }
        Scheduler getScheduler() @safe nothrow scope {
            return scheduler;
        }
    }

    private {
        Sender sender;
        Scheduler scheduler;
        SharedSenderState state;
        /// Layout of `counter`:
        /// - bit 0: lock;
        /// - bit 1: result cached;
        /// - bits 2 and up: number of registered receivers, in units of `tick`.
        /// (Inferred from `oldState >> 2` and from "remove all ticks" in
        /// `process`; exact `SharedBitField.lock` argument semantics: TODO confirm.)
        enum Flags {
            locked = 0x1,
            completed = 0x2,
            tick = 0x4
        }
        SharedBitField!Flags counter;

        /// Registers `dg` to receive the shared result.
        ///
        /// - Result already cached: `dg` is called immediately with it.
        /// - No other receivers registered: a fresh state is created and the
        ///   underlying Sender is connected and started.
        /// - Otherwise: `dg` is appended to the current state's callbacks.
        ///
        /// If connecting the underlying Sender throws, the exception becomes
        /// the shared result and is dispatched right away.
        void add(DG dg) @safe nothrow {
            with(counter.lock(0, Flags.tick)) {
                if (was(Flags.completed)) {
                    InternalValue value = state.value.get;
                    release(Flags.tick); // release early
                    dg(value);
                } else {
                    if ((oldState >> 2) == 0) {
                        auto localState = new SharedSenderState(this);
                        this.state = localState;
                        release(); // release early
                        localState.dgs.pushBack(dg);
                        try {
                            localState.op = sender.connect(SharedSenderReceiver(localState, scheduler));
                        } catch (Exception e) {
                            localState.value = InternalValue(e);
                            process();
                            // `localState.op` was never assigned; starting it
                            // would run an uninitialized operation.
                            return;
                        }
                        localState.op.start();
                    } else {
                        auto localState = state;
                        release(); // release early
                        localState.dgs.pushBack(dg);
                    }
                }
            }
        }

        /// Unregisters `dg`.
        ///
        /// Returns: false if `dg` was the last registered callback. In that
        /// case the underlying Sender is stopped and `dg` stays registered,
        /// so it still receives the result.
        /// true otherwise, including when the result is already cached.
        bool remove(DG dg) @safe nothrow {
            with (counter.lock(0, 0, Flags.tick)) {
                if (was(Flags.completed)) {
                    release(0-Flags.tick); // release early
                    return true;
                }
                if ((newState >> 2) == 0) {
                    auto localStopSource = state;
                    release(); // release early
                    localStopSource.stop();
                    return false;
                } else {
                    auto localReceiver = state;
                    release(); // release early
                    localReceiver.dgs.remove(dg);
                    return true;
                }
            }
        }
    }

    /// Marks the result as cached and invokes every registered callback with
    /// it. If a stop was requested on the state, `Done` is delivered instead
    /// of the stored result.
    ///
    /// NOTE(review): reads `this.state`, not the state of the receiver that
    /// completed. If `add` installed a new state after the last receiver
    /// stopped but before this completion, this would read the new state —
    /// confirm that cannot happen.
    private void process() {
        with(counter.lock(Flags.completed)) {
            release(oldState & (~0x3)); // release early and remove all ticks
            InternalValue v = state.value.get;
            if (state.isStopRequested)
                v = Done();
            foreach(dg; state.dgs[])
                dg(v);
        }
    }

    /// Returns: whether the `completed` flag is set, i.e. a result is cached.
    bool isCompleted() @trusted {
        import core.atomic : MemoryOrder;
        return (counter.load!(MemoryOrder.acq) & Flags.completed) > 0;
    }

    /// Clears the `completed` flag so the underlying Sender is connected and
    /// started again by a later receiver. No-op if it is not set.
    void reset() @trusted {
        with (counter.lock()) {
            if (was(Flags.completed))
                release(Flags.completed);
        }
    }

    this(Sender sender, Scheduler scheduler) {
        this.sender = sender;
        this.scheduler = scheduler;
    }

    /// Returns: a `SharedSenderOp` binding `receiver` to this shared sender.
    /// Nothing is registered until the op is started.
    auto connect(Receiver)(return Receiver receiver) @safe scope return {
        // ensure NRVO
        auto op = SharedSenderOp!Receiver(this, receiver);
        return op;
    }
}