|
1 | 1 | #![forbid(unsafe_op_in_unsafe_fn)] |
2 | 2 |
|
| 3 | +use crate::mem::DropGuard; |
3 | 4 | use crate::pin::Pin; |
4 | | -use crate::ptr; |
5 | | -use crate::sync::atomic::Ordering::Relaxed; |
6 | | -use crate::sync::atomic::{Atomic, AtomicUsize}; |
7 | 5 | use crate::sys::pal::sync as pal; |
8 | 6 | use crate::sys::sync::{Mutex, OnceBox}; |
9 | 7 | use crate::time::{Duration, Instant}; |
10 | 8 |
|
| 9 | +struct StateGuard<'a> { |
| 10 | + mutex: Pin<&'a pal::Mutex>, |
| 11 | +} |
| 12 | + |
| 13 | +impl<'a> Drop for StateGuard<'a> { |
| 14 | + fn drop(&mut self) { |
| 15 | + unsafe { self.mutex.unlock() }; |
| 16 | + } |
| 17 | +} |
| 18 | + |
| 19 | +struct State { |
| 20 | + mutex: pal::Mutex, |
| 21 | + condvar: pal::Condvar, |
| 22 | +} |
| 23 | + |
| 24 | +impl State { |
| 25 | + fn condvar(self: Pin<&Self>) -> Pin<&pal::Condvar> { |
| 26 | + unsafe { self.map_unchecked(|this| &this.condvar) } |
| 27 | + } |
| 28 | + |
| 29 | + fn condvar_mut(self: Pin<&mut Self>) -> Pin<&mut pal::Condvar> { |
| 30 | + unsafe { self.map_unchecked_mut(|this| &mut this.condvar) } |
| 31 | + } |
| 32 | + |
| 33 | + /// Locks the `mutex` field and returns a [`StateGuard`] that unlocks the |
| 34 | + /// mutex when it is dropped. |
| 35 | + /// |
| 36 | + /// # Safety |
| 37 | + /// |
| 38 | + /// * The `mutex` field must not be locked by this thread. |
| 39 | + /// * Dismissing the guard leads to undefined behaviour when this `State` |
| 40 | + /// is dropped, as it is undefined behaviour to destroy a locked mutex. |
| 41 | + unsafe fn lock(self: Pin<&Self>) -> StateGuard<'_> { |
| 42 | + let mutex = unsafe { self.map_unchecked(|this| &this.mutex) }; |
| 43 | + unsafe { mutex.lock() }; |
| 44 | + StateGuard { mutex } |
| 45 | + } |
| 46 | +} |
| 47 | + |
11 | 48 | pub struct Condvar { |
12 | | - cvar: OnceBox<pal::Condvar>, |
13 | | - mutex: Atomic<usize>, |
| 49 | + state: OnceBox<State>, |
14 | 50 | } |
15 | 51 |
|
16 | 52 | impl Condvar { |
17 | 53 | pub const fn new() -> Condvar { |
18 | | - Condvar { cvar: OnceBox::new(), mutex: AtomicUsize::new(0) } |
| 54 | + Condvar { state: OnceBox::new() } |
19 | 55 | } |
20 | 56 |
|
21 | 57 | #[inline] |
22 | | - fn get(&self) -> Pin<&pal::Condvar> { |
23 | | - self.cvar.get_or_init(|| { |
24 | | - let mut cvar = Box::pin(pal::Condvar::new()); |
| 58 | + fn state(&self) -> Pin<&State> { |
| 59 | + self.state.get_or_init(|| { |
| 60 | + let mut state = |
| 61 | + Box::pin(State { mutex: pal::Mutex::new(), condvar: pal::Condvar::new() }); |
| 62 | + |
25 | 63 | // SAFETY: we only call `init` once per `pal::Condvar`, namely here. |
26 | | - unsafe { cvar.as_mut().init() }; |
27 | | - cvar |
| 64 | + unsafe { state.as_mut().condvar_mut().init() }; |
| 65 | + state |
28 | 66 | }) |
29 | 67 | } |
30 | 68 |
|
31 | | - #[inline] |
32 | | - fn verify(&self, mutex: Pin<&pal::Mutex>) { |
33 | | - let addr = ptr::from_ref::<pal::Mutex>(&mutex).addr(); |
34 | | - // Relaxed is okay here because we never read through `self.mutex`, and only use it to |
35 | | - // compare addresses. |
36 | | - match self.mutex.compare_exchange(0, addr, Relaxed, Relaxed) { |
37 | | - Ok(_) => {} // Stored the address |
38 | | - Err(n) if n == addr => {} // Lost a race to store the same address |
39 | | - _ => panic!("attempted to use a condition variable with two mutexes"), |
40 | | - } |
41 | | - } |
42 | | - |
43 | | - #[inline] |
44 | 69 | pub fn notify_one(&self) { |
| 70 | + let state = self.state(); |
| 71 | + // Notifications might be sent right after a mutex used with `wait` or |
| 72 | + // `wait_timeout` is unlocked. Waiting until the state mutex is |
| 73 | + // available ensures that the thread unlocking the mutex is enqueued |
| 74 | + // on the inner condition variable, as the mutex is only unlocked |
| 75 | + // with the state mutex held. |
| 76 | + // |
| 77 | + // Releasing the state mutex before issuing the notification stops |
| 78 | + // the awakened threads from having to wait on this thread unlocking |
| 79 | + // the mutex. |
| 80 | + // |
| 81 | + // SAFETY: |
| 82 | + // The functions in this module are never called recursively, so the |
| 83 | + // state mutex cannot be currently locked by this thread. |
| 84 | + drop(unsafe { state.lock() }); |
45 | 85 | // SAFETY: we called `init` above. |
46 | | - unsafe { self.get().notify_one() } |
| 86 | + unsafe { state.condvar().notify_one() } |
47 | 87 | } |
48 | 88 |
|
49 | | - #[inline] |
50 | 89 | pub fn notify_all(&self) { |
| 90 | + let state = self.state(); |
| 91 | + // Notifications might be sent right after a mutex used with `wait` or |
| 92 | + // `wait_timeout` is unlocked. Waiting until the state mutex is |
| 93 | + // available ensures that the thread unlocking the mutex is enqueued |
| 94 | + // on the inner condition variable, as the mutex is only unlocked |
| 95 | + // with the state mutex held. |
| 96 | + // |
| 97 | + // Releasing the state mutex before issuing the notification stops |
| 98 | + // the awakened threads from having to wait on this thread unlocking |
| 99 | + // the mutex. |
| 100 | + // |
| 101 | + // SAFETY: |
| 102 | + // The functions in this module are never called recursively, so the |
| 103 | + // state mutex cannot be currently locked by this thread. |
| 104 | + drop(unsafe { state.lock() }); |
51 | 105 | // SAFETY: we called `init` above. |
52 | | - unsafe { self.get().notify_all() } |
| 106 | + unsafe { state.condvar().notify_all() } |
53 | 107 | } |
54 | 108 |
|
55 | | - #[inline] |
56 | 109 | pub unsafe fn wait(&self, mutex: &Mutex) { |
57 | | - // SAFETY: the caller guarantees that the lock is owned, thus the mutex |
58 | | - // must have been initialized already. |
59 | | - let mutex = unsafe { mutex.pal.get_unchecked() }; |
60 | | - self.verify(mutex); |
61 | | - // SAFETY: we called `init` above, we verified that this condition |
62 | | - // variable is only used with `mutex` and the caller guarantees that |
63 | | - // `mutex` is locked by the current thread. |
64 | | - unsafe { self.get().wait(mutex) } |
| 110 | + let state = self.state(); |
| 111 | + |
| 112 | + // Ensure that the mutex is locked when this function returns or panics. |
| 113 | + // The relocking must occur after the state lock is unlocked to prevent |
| 114 | + // deadlocks, hence we scope the relock guard before the state lock guard. |
| 115 | + let relock; |
| 116 | + |
| 117 | + // Lock the state mutex before unlocking `mutex` to ensure that |
| 118 | + // notifications occurring before this thread is enqueued on the |
| 119 | + // condvar are not missed. |
| 120 | + // |
| 121 | + // SAFETY: |
| 122 | + // The functions in this module are never called recursively, so the |
| 123 | + // state mutex cannot be currently locked by this thread. |
| 124 | + let guard = unsafe { state.lock() }; |
| 125 | + |
| 126 | + // SAFETY: |
| 127 | + // The caller must guarantee that `mutex` is currently locked by this |
| 128 | + // thread. |
| 129 | + unsafe { mutex.unlock() }; |
| 130 | + relock = DropGuard::new(mutex, |mutex| mutex.lock()); |
| 131 | + |
| 132 | + // SAFETY: |
| 133 | + // * `init` was called above |
| 134 | + // * the condition variable is only ever used with the state mutex |
| 135 | + // * the state mutex was locked above |
| 136 | + unsafe { state.condvar().wait(guard.mutex) }; |
65 | 137 | } |
66 | 138 |
|
67 | 139 | pub unsafe fn wait_timeout(&self, mutex: &Mutex, dur: Duration) -> bool { |
68 | | - // SAFETY: the caller guarantees that the lock is owned, thus the mutex |
69 | | - // must have been initialized already. |
70 | | - let mutex = unsafe { mutex.pal.get_unchecked() }; |
71 | | - self.verify(mutex); |
| 140 | + let state = self.state(); |
| 141 | + |
| 142 | + // Ensure that the mutex is locked when this function returns or panics. |
| 143 | + // The relocking must occur after the state lock is unlocked to prevent |
| 144 | + // deadlocks, hence we scope the relock guard before the state lock guard. |
| 145 | + let relock; |
| 146 | + |
| 147 | + // Lock the state mutex before unlocking `mutex` to ensure that |
| 148 | + // notifications occurring before this thread is enqueued on the |
| 149 | + // condvar are not missed. |
| 150 | + // |
| 151 | + // SAFETY: |
| 152 | + // The functions in this module are never called recursively, so the |
| 153 | + // state mutex cannot be currently locked by this thread. |
| 154 | + let guard = unsafe { state.lock() }; |
| 155 | + |
| 156 | + // SAFETY: |
| 157 | + // The caller must guarantee that `mutex` is currently locked by this |
| 158 | + // thread. |
| 159 | + unsafe { mutex.unlock() }; |
| 160 | + relock = DropGuard::new(mutex, |mutex| mutex.lock()); |
72 | 161 |
|
73 | 162 | if pal::Condvar::PRECISE_TIMEOUT { |
74 | | - // SAFETY: we called `init` above, we verified that this condition |
75 | | - // variable is only used with `mutex` and the caller guarantees that |
76 | | - // `mutex` is locked by the current thread. |
77 | | - unsafe { self.get().wait_timeout(mutex, dur) } |
| 163 | + // SAFETY: |
| 164 | + // * `init` was called above |
| 165 | + // * the condition variable is only ever used with the state mutex |
| 166 | + // * the state mutex was locked above |
| 167 | + unsafe { state.condvar().wait_timeout(guard.mutex, dur) } |
78 | 168 | } else { |
79 | 169 | // Timeout reports are not reliable, so do the check ourselves. |
80 | 170 | let now = Instant::now(); |
81 | | - // SAFETY: we called `init` above, we verified that this condition |
82 | | - // variable is only used with `mutex` and the caller guarantees that |
83 | | - // `mutex` is locked by the current thread. |
84 | | - let woken = unsafe { self.get().wait_timeout(mutex, dur) }; |
| 171 | + // SAFETY: |
| 172 | + // * `init` was called above |
| 173 | + // * the condition variable is only ever used with the state mutex |
| 174 | + // * the state mutex was locked above |
| 175 | + let woken = unsafe { state.condvar().wait_timeout(guard.mutex, dur) }; |
85 | 176 | woken || now.elapsed() < dur |
86 | 177 | } |
87 | 178 | } |
|
0 commit comments