Skip to content

Commit 1b98d5a

Browse files
authored
task: add tokio_util::task::JoinQueue (#7590)
1 parent 6d1ae62 commit 1b98d5a

File tree

3 files changed

+575
-0
lines changed

3 files changed

+575
-0
lines changed

tokio-util/src/task/join_queue.rs

Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
use super::AbortOnDropHandle;
2+
use std::{
3+
collections::VecDeque,
4+
future::Future,
5+
pin::Pin,
6+
task::{Context, Poll},
7+
};
8+
use tokio::{
9+
runtime::Handle,
10+
task::{AbortHandle, Id, JoinError, JoinHandle},
11+
};
12+
13+
/// A FIFO queue for of tasks spawned on a Tokio runtime.
14+
///
15+
/// A [`JoinQueue`] can be used to await the completion of the tasks in FIFO
16+
/// order. That is, if tasks are spawned in the order A, B, C, then
17+
/// awaiting the next completed task will always return A first, then B,
18+
/// then C, regardless of the order in which the tasks actually complete.
19+
///
20+
/// All of the tasks must have the same return type `T`.
21+
///
22+
/// When the [`JoinQueue`] is dropped, all tasks in the [`JoinQueue`] are
23+
/// immediately aborted.
24+
#[derive(Debug)]
25+
pub struct JoinQueue<T>(VecDeque<AbortOnDropHandle<T>>);
26+
27+
impl<T> JoinQueue<T> {
28+
/// Create a new empty [`JoinQueue`].
29+
pub const fn new() -> Self {
30+
Self(VecDeque::new())
31+
}
32+
33+
/// Creates an empty [`JoinQueue`] with space for at least `capacity` tasks.
34+
pub fn with_capacity(capacity: usize) -> Self {
35+
Self(VecDeque::with_capacity(capacity))
36+
}
37+
38+
/// Returns the number of tasks currently in the [`JoinQueue`].
39+
///
40+
/// This includes both tasks that are currently running and tasks that have
41+
/// completed but not yet been removed from the queue because outputting of
42+
/// them waits for FIFO order.
43+
pub fn len(&self) -> usize {
44+
self.0.len()
45+
}
46+
47+
/// Returns whether the [`JoinQueue`] is empty.
48+
pub fn is_empty(&self) -> bool {
49+
self.0.is_empty()
50+
}
51+
52+
/// Spawn the provided task on the [`JoinQueue`], returning an [`AbortHandle`]
53+
/// that can be used to remotely cancel the task.
54+
///
55+
/// The provided future will start running in the background immediately
56+
/// when this method is called, even if you don't await anything on this
57+
/// [`JoinQueue`].
58+
///
59+
/// # Panics
60+
///
61+
/// This method panics if called outside of a Tokio runtime.
62+
///
63+
/// [`AbortHandle`]: tokio::task::AbortHandle
64+
#[track_caller]
65+
pub fn spawn<F>(&mut self, task: F) -> AbortHandle
66+
where
67+
F: Future<Output = T> + Send + 'static,
68+
T: Send + 'static,
69+
{
70+
self.push_back(tokio::spawn(task))
71+
}
72+
73+
/// Spawn the provided task on the provided runtime and store it in this
74+
/// [`JoinQueue`] returning an [`AbortHandle`] that can be used to remotely
75+
/// cancel the task.
76+
///
77+
/// The provided future will start running in the background immediately
78+
/// when this method is called, even if you don't await anything on this
79+
/// [`JoinQueue`].
80+
///
81+
/// [`AbortHandle`]: tokio::task::AbortHandle
82+
#[track_caller]
83+
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
84+
where
85+
F: Future<Output = T> + Send + 'static,
86+
T: Send + 'static,
87+
{
88+
self.push_back(handle.spawn(task))
89+
}
90+
91+
/// Spawn the provided task on the current [`LocalSet`] and store it in this
92+
/// [`JoinQueue`], returning an [`AbortHandle`] that can be used to remotely
93+
/// cancel the task.
94+
///
95+
/// The provided future will start running in the background immediately
96+
/// when this method is called, even if you don't await anything on this
97+
/// [`JoinQueue`].
98+
///
99+
/// # Panics
100+
///
101+
/// This method panics if it is called outside of a `LocalSet`.
102+
///
103+
/// [`LocalSet`]: tokio::task::LocalSet
104+
/// [`AbortHandle`]: tokio::task::AbortHandle
105+
#[track_caller]
106+
pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
107+
where
108+
F: Future<Output = T> + 'static,
109+
T: 'static,
110+
{
111+
self.push_back(tokio::task::spawn_local(task))
112+
}
113+
114+
/// Spawn the blocking code on the blocking threadpool and store
115+
/// it in this [`JoinQueue`], returning an [`AbortHandle`] that can be
116+
/// used to remotely cancel the task.
117+
///
118+
/// # Panics
119+
///
120+
/// This method panics if called outside of a Tokio runtime.
121+
///
122+
/// [`AbortHandle`]: tokio::task::AbortHandle
123+
#[track_caller]
124+
pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
125+
where
126+
F: FnOnce() -> T + Send + 'static,
127+
T: Send + 'static,
128+
{
129+
self.push_back(tokio::task::spawn_blocking(f))
130+
}
131+
132+
/// Spawn the blocking code on the blocking threadpool of the
133+
/// provided runtime and store it in this [`JoinQueue`], returning an
134+
/// [`AbortHandle`] that can be used to remotely cancel the task.
135+
///
136+
/// [`AbortHandle`]: tokio::task::AbortHandle
137+
#[track_caller]
138+
pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
139+
where
140+
F: FnOnce() -> T + Send + 'static,
141+
T: Send + 'static,
142+
{
143+
self.push_back(handle.spawn_blocking(f))
144+
}
145+
146+
fn push_back(&mut self, jh: JoinHandle<T>) -> AbortHandle {
147+
let jh = AbortOnDropHandle::new(jh);
148+
let abort_handle = jh.abort_handle();
149+
self.0.push_back(jh);
150+
abort_handle
151+
}
152+
153+
/// Waits until the next task in FIFO order completes and returns its output.
154+
///
155+
/// Returns `None` if the queue is empty.
156+
///
157+
/// # Cancel Safety
158+
///
159+
/// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
160+
/// statement and some other branch completes first, it is guaranteed that no tasks were
161+
/// removed from this [`JoinQueue`].
162+
pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
163+
std::future::poll_fn(|cx| self.poll_join_next(cx)).await
164+
}
165+
166+
/// Waits until the next task in FIFO order completes and returns its output,
167+
/// along with the [task ID] of the completed task.
168+
///
169+
/// Returns `None` if the queue is empty.
170+
///
171+
/// When this method returns an error, then the id of the task that failed can be accessed
172+
/// using the [`JoinError::id`] method.
173+
///
174+
/// # Cancel Safety
175+
///
176+
/// This method is cancel safe. If `join_next_with_id` is used as the event in a `tokio::select!`
177+
/// statement and some other branch completes first, it is guaranteed that no tasks were
178+
/// removed from this [`JoinQueue`].
179+
///
180+
/// [task ID]: tokio::task::Id
181+
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
182+
pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
183+
std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
184+
}
185+
186+
/// Aborts all tasks and waits for them to finish shutting down.
187+
///
188+
/// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
189+
/// a loop until it returns `None`.
190+
///
191+
/// This method ignores any panics in the tasks shutting down. When this call returns, the
192+
/// [`JoinQueue`] will be empty.
193+
///
194+
/// [`abort_all`]: fn@Self::abort_all
195+
/// [`join_next`]: fn@Self::join_next
196+
pub async fn shutdown(&mut self) {
197+
self.abort_all();
198+
while self.join_next().await.is_some() {}
199+
}
200+
201+
/// Awaits the completion of all tasks in this [`JoinQueue`], returning a vector of their results.
202+
///
203+
/// The results will be stored in the order they were spawned, not the order they completed.
204+
/// This is a convenience method that is equivalent to calling [`join_next`] in
205+
/// a loop. If any tasks on the [`JoinQueue`] fail with an [`JoinError`], then this call
206+
/// to `join_all` will panic and all remaining tasks on the [`JoinQueue`] are
207+
/// cancelled. To handle errors in any other way, manually call [`join_next`]
208+
/// in a loop.
209+
///
210+
/// # Cancel Safety
211+
///
212+
/// This method is not cancel safe as it calls `join_next` in a loop. If you need
213+
/// cancel safety, manually call `join_next` in a loop with `Vec` accumulator.
214+
///
215+
/// [`join_next`]: fn@Self::join_next
216+
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
217+
pub async fn join_all(mut self) -> Vec<T> {
218+
let mut output = Vec::with_capacity(self.len());
219+
220+
while let Some(res) = self.join_next().await {
221+
match res {
222+
Ok(t) => output.push(t),
223+
Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
224+
Err(err) => panic!("{err}"),
225+
}
226+
}
227+
output
228+
}
229+
230+
/// Aborts all tasks on this [`JoinQueue`].
231+
///
232+
/// This does not remove the tasks from the [`JoinQueue`]. To wait for the tasks to complete
233+
/// cancellation, you should call `join_next` in a loop until the [`JoinQueue`] is empty.
234+
pub fn abort_all(&mut self) {
235+
self.0.iter().for_each(|jh| jh.abort());
236+
}
237+
238+
/// Removes all tasks from this [`JoinQueue`] without aborting them.
239+
///
240+
/// The tasks removed by this call will continue to run in the background even if the [`JoinQueue`]
241+
/// is dropped.
242+
pub fn detach_all(&mut self) {
243+
self.0.drain(..).for_each(|jh| drop(jh.detach()));
244+
}
245+
246+
/// Polls for the next task in [`JoinQueue`] to complete.
247+
///
248+
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
249+
///
250+
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
251+
/// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
252+
/// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
253+
/// scheduled to receive a wakeup.
254+
///
255+
/// # Returns
256+
///
257+
/// This function returns:
258+
///
259+
/// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
260+
/// available right now.
261+
/// * `Poll::Ready(Some(Ok(value)))` if the next task in this [`JoinQueue`] has completed.
262+
/// The `value` is the return value that task.
263+
/// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
264+
/// aborted. The `err` is the `JoinError` from the panicked/aborted task.
265+
/// * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
266+
pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> {
267+
let jh = match self.0.front_mut() {
268+
None => return Poll::Ready(None),
269+
Some(jh) => jh,
270+
};
271+
if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
272+
// Use `detach` to avoid calling `abort` on a task that has already completed.
273+
// Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
274+
// we only need to drop the `JoinHandle` for cleanup.
275+
drop(self.0.pop_front().unwrap().detach());
276+
Poll::Ready(Some(res))
277+
} else {
278+
Poll::Pending
279+
}
280+
}
281+
282+
/// Polls for the next task in [`JoinQueue`] to complete.
283+
///
284+
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
285+
///
286+
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
287+
/// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
288+
/// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
289+
/// scheduled to receive a wakeup.
290+
///
291+
/// # Returns
292+
///
293+
/// This function returns:
294+
///
295+
/// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
296+
/// available right now.
297+
/// * `Poll::Ready(Some(Ok((id, value))))` if the next task in this [`JoinQueue`] has completed.
298+
/// The `value` is the return value that task, and `id` is its [task ID].
299+
/// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
300+
/// aborted. The `err` is the `JoinError` from the panicked/aborted task.
301+
/// * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
302+
///
303+
/// [task ID]: tokio::task::Id
304+
pub fn poll_join_next_with_id(
305+
&mut self,
306+
cx: &mut Context<'_>,
307+
) -> Poll<Option<Result<(Id, T), JoinError>>> {
308+
let jh = match self.0.front_mut() {
309+
None => return Poll::Ready(None),
310+
Some(jh) => jh,
311+
};
312+
if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
313+
// Use `detach` to avoid calling `abort` on a task that has already completed.
314+
// Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
315+
// we only need to drop the `JoinHandle` for cleanup.
316+
let jh = self.0.pop_front().unwrap().detach();
317+
let id = jh.id();
318+
drop(jh);
319+
// If the task succeeded, add the task ID to the output. Otherwise, the
320+
// `JoinError` will already have the task's ID.
321+
Poll::Ready(Some(res.map(|output| (id, output))))
322+
} else {
323+
Poll::Pending
324+
}
325+
}
326+
}
327+
328+
impl<T> Default for JoinQueue<T> {
329+
fn default() -> Self {
330+
Self::new()
331+
}
332+
}
333+
334+
/// Collect an iterator of futures into a [`JoinQueue`].
335+
///
336+
/// This is equivalent to calling [`JoinQueue::spawn`] on each element of the iterator.
337+
impl<T, F> std::iter::FromIterator<F> for JoinQueue<T>
338+
where
339+
F: Future<Output = T> + Send + 'static,
340+
T: Send + 'static,
341+
{
342+
fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
343+
let mut set = Self::new();
344+
iter.into_iter().for_each(|task| {
345+
set.spawn(task);
346+
});
347+
set
348+
}
349+
}

tokio-util/src/task/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ cfg_rt! {
1313

1414
mod abort_on_drop;
1515
pub use abort_on_drop::AbortOnDropHandle;
16+
17+
mod join_queue;
18+
pub use join_queue::JoinQueue;
1619
}
1720

1821
#[cfg(feature = "join-map")]

0 commit comments

Comments
 (0)