Skip to content

Commit 4751fa9

Browse files
pzhan9 authored and facebook-github-bot committed
Close channel when NetRx is being stopped (meta-pytorch#2010)
Summary: When `NetRx` is stopped, we should let its `NetTx` half know as well, so that it exits too and stops reconnecting. Otherwise, every reconnection attempt would fail and flood the logs. Reviewed By: shayne-fletcher Differential Revision: D87926632
1 parent fac29b9 commit 4751fa9

File tree

3 files changed

+77
-8
lines changed

3 files changed

+77
-8
lines changed

hyperactor/src/channel/net.rs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,10 @@ enum Frame<M> {
108108
#[derive(Debug, Serialize, Deserialize, EnumAsInner)]
109109
enum NetRxResponse {
110110
Ack(u64),
111-
/// This channel is closed with the given reason. NetTx should stop reconnecting.
111+
/// This session is rejected with the given reason. NetTx should stop reconnecting.
112112
Reject(String),
113+
/// This channel is closed.
114+
Closed,
113115
}
114116

115117
fn serialize_response(response: NetRxResponse) -> Result<Bytes, bincode::Error> {
@@ -1625,6 +1627,9 @@ mod tests {
16251627
handle.await.unwrap().unwrap();
16261628
// mpsc is closed too and there should be no unread message left.
16271629
assert!(rx.recv().await.is_none());
1630+
// should send NetRxResponse::Closed before stopping server.
1631+
let bytes = reader.next().await.unwrap().unwrap();
1632+
assert!(deserialize_response(bytes).unwrap().is_closed());
16281633
// No more acks from server.
16291634
assert!(reader.next().await.unwrap().is_none());
16301635
};
@@ -1659,6 +1664,9 @@ mod tests {
16591664
handle.await.unwrap().unwrap();
16601665
// mpsc is closed too and there should be no unread message left.
16611666
assert!(rx.recv().await.is_none());
1667+
// should send NetRxResponse::Closed before stopping server.
1668+
let bytes = reader.next().await.unwrap().unwrap();
1669+
assert!(deserialize_response(bytes).unwrap().is_closed());
16621670
// No more acks from server.
16631671
assert!(reader.next().await.unwrap().is_none());
16641672
}
@@ -2398,4 +2406,41 @@ mod tests {
23982406
let bytes = reader.next().await.unwrap().unwrap();
23992407
assert!(deserialize_response(bytes).unwrap().is_reject());
24002408
}
2409+
2410+
#[async_timed_test(timeout_secs = 60)]
2411+
// TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
2412+
#[cfg_attr(not(fbcode_build), ignore)]
2413+
async fn test_stop_net_tx_after_stopping_net_rx() {
2414+
hyperactor_telemetry::initialize_logging_for_test();
2415+
2416+
let config = config::global::lock();
2417+
let _guard =
2418+
config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(300));
2419+
let (addr, mut rx) = tcp::serve::<u64>("[::1]:0".parse().unwrap()).unwrap();
2420+
let socket_addr = match addr {
2421+
ChannelAddr::Tcp(a) => a,
2422+
_ => panic!("unexpected channel type"),
2423+
};
2424+
let tx = tcp::dial::<u64>(socket_addr);
2425+
// NetTx will not establish a connection until it sends the 1st message.
2426+
// Without a live connection, NetTx cannot receive the Closed message
2427+
// from NetRx. Therefore, we need to send a message to establish the
2428+
// connection.
2429+
tx.send(100).await.unwrap();
2430+
assert_eq!(rx.recv().await.unwrap(), 100);
2431+
// Stopping rx will close the NetRx server.
2432+
rx.2.stop("testing");
2433+
assert!(rx.recv().await.is_err());
2434+
2435+
// NetTx will only read from the stream when it needs to send a message
2436+
// or wait for an ack. Therefore we need to send a message to trigger that.
2437+
tx.post(101);
2438+
let mut watcher = tx.status().clone();
2439+
// When NetRx exits, it should notify NetTx to exit as well.
2440+
let _ = watcher.wait_for(|val| *val == TxStatus::Closed).await;
2441+
// wait_for could return Err due to race between when watch's sender was
2442+
// dropped and when wait_for was called. So we still need to do an
2443+
// equality check.
2444+
assert_eq!(*watcher.borrow(), TxStatus::Closed);
2445+
}
24012446
}

hyperactor/src/channel/net/client.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,19 @@ where
940940
);
941941
(State::Closing {
942942
deliveries: Deliveries{outbox, unacked},
943-
reason: format!("{log_id}: {error_msg}"),
943+
reason: error_msg,
944+
}, Conn::reconnect_with_default())
945+
}
946+
NetRxResponse::Closed => {
947+
let msg = "server closed the channel".to_string();
948+
tracing::info!(
949+
dest = %link.dest(),
950+
session_id = session_id,
951+
"{}", msg
952+
);
953+
(State::Closing {
954+
deliveries: Deliveries{outbox, unacked},
955+
reason: msg,
944956
}, Conn::reconnect_with_default())
945957
}
946958
}

hyperactor/src/channel/net/server.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
8989
) -> (Next, Result<(), anyhow::Error>) {
9090
#[derive(Debug)]
9191
enum RejectConn {
92-
Yes(String),
92+
/// Reject the connection due to the given error.
93+
EncounterError(String),
94+
/// The server is being closed.
95+
ServerClosing,
96+
/// Do not reject the connection.
9397
No,
9498
}
9599

@@ -170,7 +174,7 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
170174
// Have a tick to abort select! call to make sure the ack for the last message can get the chance
171175
// to be sent as a result of time interval being reached.
172176
_ = RealClock.sleep_until(last_ack_time + ack_time_interval), if next.ack < next.seq => {},
173-
_ = cancel_token.cancelled() => break (next, Ok(()), RejectConn::No),
177+
_ = cancel_token.cancelled() => break (next, Ok(()), RejectConn::ServerClosing),
174178
bytes_result = self.reader.next() => {
175179
rcv_raw_frame_count += 1;
176180
// First handle transport-level I/O errors, and EOFs.
@@ -231,7 +235,7 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
231235
break (
232236
next,
233237
Err(anyhow::anyhow!("{log_id}: unexpected init frame")),
234-
RejectConn::Yes("expect Frame::Message; got Frame::Int".to_string()),
238+
RejectConn::EncounterError("expect Frame::Message; got Frame::Int".to_string()),
235239
)
236240
},
237241
// Ignore retransmits.
@@ -260,7 +264,7 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
260264
break (
261265
next,
262266
Err(anyhow::anyhow!(format!("{log_id}: {error_msg}"))),
263-
RejectConn::Yes(error_msg),
267+
RejectConn::EncounterError(error_msg),
264268
)
265269
}
266270
match self.send_with_buffer_metric(session_id, &tx, message).await {
@@ -391,12 +395,20 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
391395
}
392396

393397
if self.write_state.is_idle()
394-
&& let RejectConn::Yes(reason) = reject_conn
398+
&& matches!(
399+
reject_conn,
400+
RejectConn::EncounterError(_) | RejectConn::ServerClosing
401+
)
395402
{
396403
let Ok(writer) = replace(&mut self.write_state, WriteState::Broken).into_idle() else {
397404
panic!("illegal state");
398405
};
399-
if let Ok(data) = serialize_response(NetRxResponse::Reject(reason)) {
406+
let rsp = match reject_conn {
407+
RejectConn::EncounterError(reason) => NetRxResponse::Reject(reason),
408+
RejectConn::ServerClosing => NetRxResponse::Closed,
409+
RejectConn::No => panic!("illegal state"),
410+
};
411+
if let Ok(data) = serialize_response(rsp) {
400412
match FrameWrite::new(
401413
writer,
402414
data,

0 commit comments

Comments
 (0)