Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tokio-util/src/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ macro_rules! cfg_rt {
}
}

macro_rules! cfg_not_rt {
($($item:item)*) => {
$(
#[cfg(not(feature = "rt"))]
$item
)*
}
}

macro_rules! cfg_time {
($($item:item)*) => {
$(
Expand Down
1 change: 1 addition & 0 deletions tokio-util/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod copy_to_bytes;
mod inspect;
mod read_buf;
mod reader_stream;
pub mod simplex;
mod sink_writer;
mod stream_reader;

Expand Down
322 changes: 322 additions & 0 deletions tokio-util/src/io/simplex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
//! Unidirectional byte-oriented channel.

use crate::util::poll_proceed_and_make_progress;

use bytes::Buf;
use bytes::BytesMut;
use futures_core::ready;
use std::io::Error as IoError;
use std::io::ErrorKind as IoErrorKind;
use std::io::IoSlice;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

type IoResult<T> = Result<T, IoError>;

const CLOSED_ERROR_MSG: &str = "simplex has been closed";

#[derive(Debug)]
struct Inner {
/// `poll_write` will return [`Poll::Pending`] if the backpressure boundary is reached
backpressure_boundary: usize,

/// either [`Sender`] or [`Receiver`] is closed
is_closed: bool,

/// Waker used to wake the [`Receiver`]
receiver_waker: Option<Waker>,

/// Waker used to wake the [`Sender`]
sender_waker: Option<Waker>,

/// Buffer used to read and write data
buf: BytesMut,
}

impl Inner {
fn with_capacity(capacity: usize) -> Self {
Self {
backpressure_boundary: capacity,
is_closed: false,
receiver_waker: None,
sender_waker: None,
buf: BytesMut::with_capacity(capacity),
}
}

fn register_receiver_waker(&mut self, waker: &Waker) {
match self.receiver_waker.as_mut() {
Some(old) if old.will_wake(waker) => {}
Some(old) => old.clone_from(waker),
None => self.receiver_waker = Some(waker.clone()),
}
}

fn register_sender_waker(&mut self, waker: &Waker) {
match self.sender_waker.as_mut() {
Some(old) if old.will_wake(waker) => {}
Some(old) => old.clone_from(waker),
None => self.sender_waker = Some(waker.clone()),
}
}

fn take_receiver_waker(&mut self) -> Option<Waker> {
self.receiver_waker.take()
}

fn take_sender_waker(&mut self) -> Option<Waker> {
self.sender_waker.take()
}

fn is_closed(&self) -> bool {
self.is_closed
}

fn close_receiver(&mut self) -> Option<Waker> {
self.is_closed = true;
self.take_sender_waker()
}

fn close_sender(&mut self) -> Option<Waker> {
self.is_closed = true;
self.take_receiver_waker()
}
}

/// Receiver of the simplex channel.
///
/// You can still read the remaining data from the buffer
/// even if the write half has been dropped.
/// See [`Sender::poll_shutdown`] and [`Sender::drop`] for more details.
#[derive(Debug)]
pub struct Receiver {
inner: Arc<Mutex<Inner>>,
}

impl Drop for Receiver {
/// This also wakes up the [`Sender`].
fn drop(&mut self) {
let maybe_waker = {
let mut inner = self.inner.lock().unwrap();
inner.close_receiver()
};

if let Some(waker) = maybe_waker {
waker.wake();
}
}
}

impl AsyncRead for Receiver {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<IoResult<()>> {
let mut inner = self.inner.lock().unwrap();

let to_read = buf.remaining().min(inner.buf.remaining());
if to_read == 0 {
if inner.is_closed() || buf.remaining() == 0 {
return Poll::Ready(Ok(()));
}

inner.register_receiver_waker(cx.waker());
let maybe_waker = inner.take_sender_waker();
drop(inner); // unlock before waking up
if let Some(waker) = maybe_waker {
waker.wake();
}
return Poll::Pending;
}

ready!(poll_proceed_and_make_progress(cx));

buf.put_slice(&inner.buf[..to_read]);
inner.buf.advance(to_read);
let waker = inner.take_sender_waker();
drop(inner); // unlock before waking up
if let Some(waker) = waker {
waker.wake();
}
Poll::Ready(Ok(()))
}
}

/// Sender of the simplex channel.
///
/// ## Shutdown
///
/// See [`Sender::poll_shutdown`].
#[derive(Debug)]
pub struct Sender {
inner: Arc<Mutex<Inner>>,
}

impl Drop for Sender {
/// This also wakes up the [`Receiver`].
fn drop(&mut self) {
let maybe_waker = {
let mut inner = self.inner.lock().unwrap();
inner.close_sender()
};

if let Some(waker) = maybe_waker {
waker.wake();
}
}
}

impl AsyncWrite for Sender {
/// # Errors
///
/// This method will return [`IoErrorKind::BrokenPipe`]
/// if the channel has been closed.
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
let mut inner = self.inner.lock().unwrap();

if inner.is_closed() {
return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
}

let free = inner
.backpressure_boundary
.checked_sub(inner.buf.len())
.expect("backpressure boundary overflow");
let to_write = buf.len().min(free);
if to_write == 0 {
if buf.is_empty() {
return Poll::Ready(Ok(0));
}

inner.register_sender_waker(cx.waker());
let waker = inner.take_receiver_waker();
drop(inner); // unlock before waking up
if let Some(waker) = waker {
waker.wake();
}
return Poll::Pending;
}

// this is to avoid starving other tasks
ready!(poll_proceed_and_make_progress(cx));

inner.buf.extend_from_slice(&buf[..to_write]);
let waker = inner.take_receiver_waker();
drop(inner); // unlock before waking up
if let Some(waker) = waker {
waker.wake();
}
Poll::Ready(Ok(to_write))
}

/// # Errors
///
/// This method will return [`IoErrorKind::BrokenPipe`]
/// if the channel has been closed.
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
let inner = self.inner.lock().unwrap();
if inner.is_closed() {
Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)))
} else {
Poll::Ready(Ok(()))
}
}

/// After returns [`Poll::Ready`], all the following call to
/// [`Sender::poll_write`] and [`Sender::poll_flush`]
/// will return error.
///
/// The [`Receiver`] can still be used to read remaining data
/// until all bytes have been consumed.
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
let maybe_waker = {
let mut inner = self.inner.lock().unwrap();
inner.close_sender()
};

if let Some(waker) = maybe_waker {
waker.wake();
}

Poll::Ready(Ok(()))
}

fn is_write_vectored(&self) -> bool {
true
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, IoError>> {
let mut inner = self.inner.lock().unwrap();
if inner.is_closed() {
return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
}

let free = inner
.backpressure_boundary
.checked_sub(inner.buf.len())
.expect("backpressure boundary overflow");
if free == 0 {
inner.register_sender_waker(cx.waker());
let maybe_waker = inner.take_receiver_waker();
drop(inner); // unlock before waking up
if let Some(waker) = maybe_waker {
waker.wake();
}
return Poll::Pending;
}

ready!(poll_proceed_and_make_progress(cx));

let mut rem = free;
for buf in bufs {
if rem == 0 {
break;
}

let to_write = buf.len().min(rem);
if to_write == 0 {
assert_ne!(rem, 0);
assert_eq!(buf.len(), 0);
continue;
}

inner.buf.extend_from_slice(&buf[..to_write]);
rem -= to_write;
}

let waker = inner.take_receiver_waker();
drop(inner); // unlock before waking up
if let Some(waker) = waker {
waker.wake();
}

Poll::Ready(Ok(free - rem))
}
}

/// Create a simplex channel.
///
/// The `capacity` parameter specifies the maximum number of bytes that can be
/// stored in the channel without making the [`Sender::poll_write`]
/// return [`Poll::Pending`].
///
/// # Panics
///
/// This function will panic if `capacity` is zero.
pub fn new(capacity: usize) -> (Sender, Receiver) {
assert_ne!(capacity, 0, "capacity must be greater than zero");

let inner = Arc::new(Mutex::new(Inner::with_capacity(capacity)));
let tx = Sender {
inner: Arc::clone(&inner),
};
let rx = Receiver { inner };
(tx, rx)
}
21 changes: 21 additions & 0 deletions tokio-util/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,24 @@ pub(crate) use maybe_dangling::MaybeDangling;
#[cfg(any(feature = "io", feature = "codec"))]
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
pub use poll_buf::{poll_read_buf, poll_write_buf};

cfg_rt! {
use std::task::{Context, Poll};
use tokio::task::coop::poll_proceed;
use futures_core::ready;

#[cfg_attr(not(feature = "io"), allow(unused))]
pub(crate) fn poll_proceed_and_make_progress(cx: &mut Context<'_>) -> Poll<()> {
ready!(poll_proceed(cx)).made_progress();
Poll::Ready(())
}
}

cfg_not_rt! {
use std::task::{Context, Poll};

#[cfg_attr(not(feature = "io"), allow(unused))]
pub(crate) fn poll_proceed_and_make_progress(_cx: &mut Context<'_>) -> Poll<()> {
Poll::Ready(())
}
}
Loading
Loading