Skip to content

Commit a78e8c6

Browse files
authored
Merge pull request #28 from aumetra/remove-allocations
Remove allocations
2 parents 1015b8d + 0a99a3c commit a78e8c6

File tree

1 file changed

+41
-17
lines changed

1 file changed

+41
-17
lines changed

src/lib.rs

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,27 @@ use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect};
2323
use tokio_rustls::{client::TlsStream, TlsConnector};
2424
use x509_cert::{der::Decode, TbsCertificate};
2525

26+
mod private {
27+
use super::*;
28+
29+
pub struct TlsConnectFuture<S> {
30+
pub inner: tokio_rustls::Connect<S>,
31+
}
32+
33+
impl<S> Future for TlsConnectFuture<S>
34+
where
35+
S: AsyncRead + AsyncWrite + Unpin,
36+
{
37+
type Output = io::Result<RustlsStream<S>>;
38+
39+
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
40+
// SAFETY: If `self` is pinned, so is `inner`.
41+
let fut = unsafe { self.map_unchecked_mut(|this| &mut this.inner) };
42+
fut.poll(cx).map_ok(RustlsStream)
43+
}
44+
}
45+
}
46+
2647
#[derive(Clone)]
2748
pub struct MakeRustlsConnect {
2849
config: Arc<ClientConfig>,
@@ -67,20 +88,23 @@ where
6788
{
6889
type Stream = RustlsStream<S>;
6990
type Error = io::Error;
70-
type Future = Pin<Box<dyn Future<Output = io::Result<RustlsStream<S>>> + Send>>;
91+
type Future = private::TlsConnectFuture<S>;
7192

7293
fn connect(self, stream: S) -> Self::Future {
73-
Box::pin(async move {
74-
self.0
75-
.connector
76-
.connect(self.0.hostname, stream)
77-
.await
78-
.map(|s| RustlsStream(Box::pin(s)))
79-
})
94+
private::TlsConnectFuture {
95+
inner: self.0.connector.connect(self.0.hostname, stream),
96+
}
8097
}
8198
}
8299

83-
pub struct RustlsStream<S>(Pin<Box<TlsStream<S>>>);
100+
pub struct RustlsStream<S>(TlsStream<S>);
101+
102+
impl<S> RustlsStream<S> {
103+
pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream<S>> {
104+
// SAFETY: When `Self` is pinned, so is the inner `TlsStream`.
105+
unsafe { self.map_unchecked_mut(|this| &mut this.0) }
106+
}
107+
}
84108

85109
impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
86110
where
@@ -123,11 +147,11 @@ where
123147
S: AsyncRead + AsyncWrite + Unpin,
124148
{
125149
fn poll_read(
126-
mut self: Pin<&mut Self>,
150+
self: Pin<&mut Self>,
127151
cx: &mut Context,
128152
buf: &mut ReadBuf<'_>,
129153
) -> Poll<tokio::io::Result<()>> {
130-
self.0.as_mut().poll_read(cx, buf)
154+
self.project_stream().poll_read(cx, buf)
131155
}
132156
}
133157

@@ -136,19 +160,19 @@ where
136160
S: AsyncRead + AsyncWrite + Unpin,
137161
{
138162
fn poll_write(
139-
mut self: Pin<&mut Self>,
163+
self: Pin<&mut Self>,
140164
cx: &mut Context,
141165
buf: &[u8],
142166
) -> Poll<tokio::io::Result<usize>> {
143-
self.0.as_mut().poll_write(cx, buf)
167+
self.project_stream().poll_write(cx, buf)
144168
}
145169

146-
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
147-
self.0.as_mut().poll_flush(cx)
170+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
171+
self.project_stream().poll_flush(cx)
148172
}
149173

150-
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
151-
self.0.as_mut().poll_shutdown(cx)
174+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
175+
self.project_stream().poll_shutdown(cx)
152176
}
153177
}
154178

0 commit comments

Comments
 (0)