1 module concurrency.bitfield;
2 
3 import core.atomic : atomicOp, atomicLoad, MemoryOrder;
4 
/// Handle returned by `SharedBitField.lock`.
///
/// Holds a pointer to the field it locked plus the state observed before
/// (`oldState`) and written by (`newState`) the locking update. `release`,
/// or failing that the destructor, subtracts `Flags.locked` from the field
/// again.
struct Guard(Flags) {
  private SharedBitField!(Flags)* obj;
  size_t oldState, newState;

  // A copy would share `obj`, and destroying both copies would subtract
  // `Flags.locked` twice. Copying is therefore disabled; moves (e.g.
  // returning a freshly built Guard) remain allowed.
  @disable this(this);

  ~this() {
    release();
  }

  /// Atomically subtracts `sub | Flags.locked` from the guarded field and
  /// detaches the guard, so later calls (including the destructor's) do
  /// nothing.
  void release(size_t sub = 0) {
    if (obj !is null) {
      // TODO: want to use atomicFetchSub but (proper) support is only recent
      // obj.store.atomicFetchSub!(MemoryOrder.rel)(sub | Flags.locked);
      obj.store.atomicOp!"-="(sub | Flags.locked);
      obj = null;
    }
  }

  /// Returns true when every bit of `flags` was set in the state observed
  /// before the lock was taken.
  bool was(Flags flags) {
    return (oldState & flags) == flags;
  }
}
23 
/// A `size_t` bit field shared between threads and updated atomically.
///
/// `Flags` must provide a `locked` member. When `Flags.locked` is greater
/// than zero, `lock` is available. It sets that bit, waiting while it is
/// already set, and returns a `Guard` that removes it again.
shared struct SharedBitField(Flags) {
  static assert(__traits(compiles, Flags.locked), "Must have a 'locked' flag");
  private shared size_t store;

  static if (Flags.locked > 0) {
    /// Waits until the `locked` bit is clear, then atomically writes
    /// `((state + add - sub) | or | Flags.locked)`.
    ///
    /// Returns: a `Guard` holding the states before and after the update.
    Guard!Flags lock(size_t or = 0, size_t add = 0, size_t sub = 0) return scope @safe @nogc nothrow {
      return Guard!Flags(&this, update(Flags.locked | or, add, sub).expand);
    }
  }

  /// Atomically replaces the state with `(state + add - sub) | or`.
  ///
  /// While the `locked` bit is set it spins (via `spin_yield`). It retries
  /// whenever the compare-and-swap does not succeed.
  ///
  /// Returns: a named tuple `(oldState, newState)` with the replaced value
  /// and the value written.
  auto update(size_t or, size_t add = 0, size_t sub = 0) nothrow {
    import concurrency.utils : spin_yield, casWeak;
    import std.typecons : tuple;
    size_t oldState, newState;
    do {
      // Wait for any current lock holder to clear the `locked` bit.
      oldState = store.atomicLoad!(MemoryOrder.acq);
      while ((oldState & Flags.locked) > 0) {
        spin_yield();
        oldState = store.atomicLoad!(MemoryOrder.acq);
      }
      newState = (oldState + add - sub) | or;
      // NOTE(review): casWeak is presumably a weak CAS (may fail spuriously
      // as well as on contention); in both cases we reload and try again.
    } while (!casWeak!(MemoryOrder.acq, MemoryOrder.acq)(&store, oldState, newState));
    return tuple!("oldState", "newState")(oldState, newState);
  }

  /// Atomically adds `add` to the state, ignoring the `locked` bit.
  ///
  /// Returns: the resulting state.
  auto add(size_t add) nothrow {
    return Result!Flags(store.atomicOp!"+="(add));
  }

  /// Atomically subtracts `sub` from the state, ignoring the `locked` bit.
  ///
  /// Returns: the resulting state.
  auto sub(size_t sub) nothrow {
    return Result!Flags(store.atomicOp!"-="(sub));
  }

  /// Atomically reads the state with memory order `ms`. The order defaults
  /// to sequentially consistent, so `load()` works without an explicit
  /// order.
  size_t load(MemoryOrder ms = MemoryOrder.seq)() {
    return store.atomicLoad!ms;
  }
}
57 
/// A bit-field value captured by an atomic operation. It converts
/// implicitly to its raw `size_t` through `alias this`.
struct Result(Flags) {
  size_t state;
  alias state this;

  /// True when no bit of `flags` is missing from `state`.
  bool has(Flags flags) {
    const missing = flags & ~state;
    return missing == 0;
  }
}