Skip to content

Commit c3226ac

Browse files
pzhan9 authored and facebook-github-bot committed
Close channel when NetRx is being stopped (#2010)
Summary: When `NetRx` is stopped, we should let its `NetTx` half know as well, so that it exits too and stops reconnecting. Otherwise, those reconnection attempts would all fail and lead to log spew. Reviewed By: shayne-fletcher. Differential Revision: D87926632.
1 parent 75afa82 commit c3226ac

File tree

3 files changed

+77
-8
lines changed

3 files changed

+77
-8
lines changed

hyperactor/src/channel/net.rs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,10 @@ enum Frame<M> {
106106
#[derive(Debug, Serialize, Deserialize, EnumAsInner)]
107107
enum NetRxResponse {
108108
Ack(u64),
109-
/// This channel is closed with the given reason. NetTx should stop reconnecting.
109+
/// This session is rejected with the given reason. NetTx should stop reconnecting.
110110
Reject(String),
111+
/// This channel is closed.
112+
Closed,
111113
}
112114

113115
fn serialize_response(response: NetRxResponse) -> Result<Bytes, bincode::Error> {
@@ -1612,6 +1614,9 @@ mod tests {
16121614
handle.await.unwrap().unwrap();
16131615
// mpsc is closed too and there should be no unread message left.
16141616
assert!(rx.recv().await.is_none());
1617+
// should send NetRxResponse::Closed before stopping server.
1618+
let bytes = reader.next().await.unwrap().unwrap();
1619+
assert!(deserialize_response(bytes).unwrap().is_closed());
16151620
// No more acks from server.
16161621
assert!(reader.next().await.unwrap().is_none());
16171622
};
@@ -1646,6 +1651,9 @@ mod tests {
16461651
handle.await.unwrap().unwrap();
16471652
// mpsc is closed too and there should be no unread message left.
16481653
assert!(rx.recv().await.is_none());
1654+
// should send NetRxResponse::Closed before stopping server.
1655+
let bytes = reader.next().await.unwrap().unwrap();
1656+
assert!(deserialize_response(bytes).unwrap().is_closed());
16491657
// No more acks from server.
16501658
assert!(reader.next().await.unwrap().is_none());
16511659
}
@@ -2385,4 +2393,41 @@ mod tests {
23852393
let bytes = reader.next().await.unwrap().unwrap();
23862394
assert!(deserialize_response(bytes).unwrap().is_reject());
23872395
}
2396+
2397+
#[async_timed_test(timeout_secs = 60)]
2398+
// TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
2399+
#[cfg_attr(not(fbcode_build), ignore)]
2400+
async fn test_stop_net_tx_after_stopping_net_rx() {
2401+
hyperactor_telemetry::initialize_logging_for_test();
2402+
2403+
let config = config::global::lock();
2404+
let _guard =
2405+
config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(300));
2406+
let (addr, mut rx) = tcp::serve::<u64>("[::1]:0".parse().unwrap()).unwrap();
2407+
let socket_addr = match addr {
2408+
ChannelAddr::Tcp(a) => a,
2409+
_ => panic!("unexpected channel type"),
2410+
};
2411+
let tx = tcp::dial::<u64>(socket_addr);
2412+
// NetTx will not establish a connection until it sends the 1st message.
2413+
// Without a live connection, NetTx cannot receive the Closed message
2414+
// from NetRx. Therefore, we need to send a message to establish the
2415+
// connection.
2416+
tx.send(100).await.unwrap();
2417+
assert_eq!(rx.recv().await.unwrap(), 100);
2418+
// Explicitly stopping rx will close the NetRx server.
2419+
rx.2.stop("testing");
2420+
assert!(rx.recv().await.is_err());
2421+
2422+
// NetTx will only read from the stream when it needs to send a message
2423+
// or wait for an ack. Therefore we need to send a message to trigger that.
2424+
tx.post(101);
2425+
let mut watcher = tx.status().clone();
2426+
// When NetRx exits, it should notify NetTx to exit as well.
2427+
let _ = watcher.wait_for(|val| *val == TxStatus::Closed).await;
2428+
// wait_for could return Err due to race between when watch's sender was
2429+
// dropped and when wait_for was called. So we still need to do an
2430+
// equality check.
2431+
assert_eq!(*watcher.borrow(), TxStatus::Closed);
2432+
}
23882433
}

hyperactor/src/channel/net/client.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,19 @@ where
943943
);
944944
(State::Closing {
945945
deliveries: Deliveries{outbox, unacked},
946-
reason: format!("{log_id}: {error_msg}"),
946+
reason: error_msg,
947+
}, Conn::reconnect_with_default())
948+
}
949+
NetRxResponse::Closed => {
950+
let msg = "server closed the channel".to_string();
951+
tracing::info!(
952+
dest = %link.dest(),
953+
session_id = session_id,
954+
"{}", msg
955+
);
956+
(State::Closing {
957+
deliveries: Deliveries{outbox, unacked},
958+
reason: msg,
947959
}, Conn::reconnect_with_default())
948960
}
949961
}

hyperactor/src/channel/net/server.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
8989
) -> (Next, Result<(), anyhow::Error>) {
9090
#[derive(Debug)]
9191
enum RejectConn {
92-
Yes(String),
92+
/// Reject the connection due to the given error.
93+
EncounterError(String),
94+
/// The server is being closed.
95+
ServerClosing,
96+
/// Do not reject the connection.
9397
No,
9498
}
9599

@@ -170,7 +174,7 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
170174
// Add a tick that wakes up the select! call so that the ack for the last message gets the chance
171175
// to be sent once the ack time interval has been reached.
172176
_ = RealClock.sleep_until(last_ack_time + ack_time_interval), if next.ack < next.seq => {},
173-
_ = cancel_token.cancelled() => break (next, Ok(()), RejectConn::No),
177+
_ = cancel_token.cancelled() => break (next, Ok(()), RejectConn::ServerClosing),
174178
bytes_result = self.reader.next() => {
175179
rcv_raw_frame_count += 1;
176180
// First handle transport-level I/O errors, and EOFs.
@@ -231,7 +235,7 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
231235
break (
232236
next,
233237
Err(anyhow::anyhow!("{log_id}: unexpected init frame")),
234-
RejectConn::Yes("expect Frame::Message; got Frame::Int".to_string()),
238+
RejectConn::EncounterError("expect Frame::Message; got Frame::Int".to_string()),
235239
)
236240
},
237241
// Ignore retransmits.
@@ -260,7 +264,7 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
260264
break (
261265
next,
262266
Err(anyhow::anyhow!(format!("{log_id}: {error_msg}"))),
263-
RejectConn::Yes(error_msg),
267+
RejectConn::EncounterError(error_msg),
264268
)
265269
}
266270
match self.send_with_buffer_metric(session_id, &tx, message).await {
@@ -391,12 +395,20 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
391395
}
392396

393397
if self.write_state.is_idle()
394-
&& let RejectConn::Yes(reason) = reject_conn
398+
&& matches!(
399+
reject_conn,
400+
RejectConn::EncounterError(_) | RejectConn::ServerClosing
401+
)
395402
{
396403
let Ok(writer) = replace(&mut self.write_state, WriteState::Broken).into_idle() else {
397404
panic!("illegal state");
398405
};
399-
if let Ok(data) = serialize_response(NetRxResponse::Reject(reason)) {
406+
let rsp = match reject_conn {
407+
RejectConn::EncounterError(reason) => NetRxResponse::Reject(reason),
408+
RejectConn::ServerClosing => NetRxResponse::Closed,
409+
RejectConn::No => panic!("illegal state"),
410+
};
411+
if let Ok(data) = serialize_response(rsp) {
400412
match FrameWrite::new(
401413
writer,
402414
data,

0 commit comments

Comments
 (0)