diff --git a/src/client/pubsub.rs b/src/client/pubsub.rs index 4a91574..dd263a1 100644 --- a/src/client/pubsub.rs +++ b/src/client/pubsub.rs @@ -131,7 +131,7 @@ impl PubsubConnectionInner { } } - fn handle_message(&mut self, msg: resp::RespValue) -> Result { + fn handle_message(&mut self, msg: resp::RespValue) -> Result<(), error::Error> { let (message_type, topic, msg) = match msg { resp::RespValue::Array(mut messages) => match ( messages.pop(), @@ -193,38 +193,28 @@ impl PubsubConnectionInner { ))); } }, - b"unsubscribe" => { - match self.subscriptions.entry(topic) { - Entry::Occupied(entry) => { - entry.remove_entry(); - } - Entry::Vacant(vacant) => { - return Err(error::internal(format!( - "Unexpected unsubscribe message: {}", - vacant.key() - ))); - } + b"unsubscribe" => match self.subscriptions.entry(topic) { + Entry::Occupied(entry) => { + entry.remove_entry(); } - if self.subscriptions.is_empty() { - return Ok(false); + Entry::Vacant(vacant) => { + return Err(error::internal(format!( + "Unexpected unsubscribe message: {}", + vacant.key() + ))); } - } - b"punsubscribe" => { - match self.psubscriptions.entry(topic) { - Entry::Occupied(entry) => { - entry.remove_entry(); - } - Entry::Vacant(vacant) => { - return Err(error::internal(format!( - "Unexpected unsubscribe message: {}", - vacant.key() - ))); - } + }, + b"punsubscribe" => match self.psubscriptions.entry(topic) { + Entry::Occupied(entry) => { + entry.remove_entry(); } - if self.psubscriptions.is_empty() { - return Ok(false); + Entry::Vacant(vacant) => { + return Err(error::internal(format!( + "Unexpected unsubscribe message: {}", + vacant.key() + ))); } - } + }, b"message" => match self.subscriptions.get(&topic) { Some(sender) => { if let Err(error) = sender.unbounded_send(Ok(msg)) { @@ -263,39 +253,32 @@ impl PubsubConnectionInner { } } - Ok(true) + Ok(()) } /// Returns true, if there are still valid subscriptions at the end, or false if not, i.e. the whole thing can be dropped. - fn handle_messages(&mut self, cx: &mut Context) -> Result { + fn handle_messages(&mut self, cx: &mut Context) -> Result<(), error::Error> { loop { match self.connection.poll_next_unpin(cx) { - Poll::Pending => return Ok(true), + Poll::Pending => return Ok(()), Poll::Ready(None) => { - if self.subscriptions.is_empty() { - return Ok(false); - } else { - // This can only happen if the connection is closed server-side - for sub in self.subscriptions.values() { - sub.unbounded_send(Err(error::Error::Connection( - ConnectionReason::NotConnected, - ))) - .unwrap(); - } - for psub in self.psubscriptions.values() { - psub.unbounded_send(Err(error::Error::Connection( - ConnectionReason::NotConnected, - ))) - .unwrap(); - } - return Err(error::Error::Connection(ConnectionReason::NotConnected)); + // This can only happen if the connection is closed server-side + for sub in self.subscriptions.values() { + sub.unbounded_send(Err(error::Error::Connection( + ConnectionReason::NotConnected, + ))) + .unwrap(); + } + for psub in self.psubscriptions.values() { + psub.unbounded_send(Err(error::Error::Connection( + ConnectionReason::NotConnected, + ))) + .unwrap(); } + return Err(error::Error::Connection(ConnectionReason::NotConnected)); } Poll::Ready(Some(Ok(message))) => { - let message_result = self.handle_message(message)?; - if !message_result { - return Ok(false); - } + self.handle_message(message)?; } Poll::Ready(Some(Err(e))) => { for sub in self.subscriptions.values() { @@ -326,11 +309,11 @@ impl Future for PubsubConnectionInner { let this_self = self.get_mut(); this_self.handle_new_subs(cx)?; this_self.do_flush(cx)?; - let cont = this_self.handle_messages(cx)?; - if cont { - Poll::Pending - } else { + this_self.handle_messages(cx)?; + if this_self.out_rx.is_done() { Poll::Ready(Ok(())) + } else { + Poll::Pending } } } @@ -456,6 +439,8 @@ impl PubsubConnection { pub struct PubsubStream { topic: String, underlying: PubsubStreamInner, + // Note that, to keep the Future running, PubsubConnectionInner relies on PubsubStream to hold + // a reference to the connection. If that's ever changed remember to adapt the readiness check. con: PubsubConnection, } @@ -540,4 +525,41 @@ mod test { assert_eq!(result[1], "test-message-2".into()); assert_eq!(result[2], "test-message-3".into()); } + + #[tokio::test] + /// Regression test for https://github.com/benashford/redis-async-rs/issues/50 + async fn test_connection_remains_open_after_unsubscription() { + let addr = "127.0.0.1:6379".parse().unwrap(); + let pubsub = super::pubsub_connect(addr) + .await + .expect("Cannot connect to Redis"); + + let topic_messages = pubsub + .subscribe("test-topic") + .await + .expect("Cannot subscribe to topic"); + drop(topic_messages); + + pubsub + .subscribe("test-topic") + .await + .expect("Cannot subscribe to topic"); + } +} + +#[tokio::test] +async fn test_connection_is_closed_after_channel_is_dropped() { + let addr = "127.0.0.1:6379".parse().unwrap(); + let connection = connect_with_auth(&addr, None, None) + .await + .expect("Cannot connect to Redis"); + let (out_tx, out_rx) = mpsc::unbounded(); + let handle = tokio::spawn(async { + match PubsubConnectionInner::new(connection, out_rx).await { + Ok(_) => (), + Err(e) => log::error!("Pub/Sub error: {:?}", e), + } + }); + drop(out_tx); + handle.await.expect("Error waiting on the JoinHandle"); }