Skip to content

Commit 6833ad5

Browse files
authored
Merge pull request #30 from aumetra/minimize-api
Minimize API and add some documentation
2 parents a78e8c6 + d29076c commit 6833ad5

File tree

2 files changed

+151
-131
lines changed

2 files changed

+151
-131
lines changed

.clippy.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
doc-valid-idents = ["PostgreSQL"]

src/lib.rs

Lines changed: 150 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,36 @@
1-
use std::{
2-
convert::TryFrom,
3-
future::Future,
4-
io,
5-
pin::Pin,
6-
sync::Arc,
7-
task::{Context, Poll},
8-
};
9-
10-
use const_oid::db::{
11-
rfc5912::{
12-
ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
13-
SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
14-
SHA_512_WITH_RSA_ENCRYPTION,
15-
},
16-
rfc8410::ID_ED_25519,
17-
};
18-
use ring::digest;
19-
use rustls::pki_types::ServerName;
20-
use rustls::ClientConfig;
21-
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
22-
use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect};
23-
use tokio_rustls::{client::TlsStream, TlsConnector};
24-
use x509_cert::{der::Decode, TbsCertificate};
1+
#![doc = include_str!("../README.md")]
2+
#![forbid(rust_2018_idioms)]
3+
#![deny(missing_docs, unsafe_code)]
4+
#![warn(clippy::all, clippy::pedantic)]
5+
6+
use std::{convert::TryFrom, sync::Arc};
7+
8+
use rustls::{pki_types::ServerName, ClientConfig};
9+
use tokio::io::{AsyncRead, AsyncWrite};
10+
use tokio_postgres::tls::MakeTlsConnect;
2511

2612
mod private {
27-
use super::*;
13+
use std::{
14+
future::Future,
15+
io,
16+
pin::Pin,
17+
task::{Context, Poll},
18+
};
19+
20+
use const_oid::db::{
21+
rfc5912::{
22+
ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
23+
SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
24+
SHA_512_WITH_RSA_ENCRYPTION,
25+
},
26+
rfc8410::ID_ED_25519,
27+
};
28+
use ring::digest;
29+
use rustls::pki_types::ServerName;
30+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
31+
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
32+
use tokio_rustls::{client::TlsStream, TlsConnector};
33+
use x509_cert::{der::Decode, TbsCertificate};
2834

2935
pub struct TlsConnectFuture<S> {
3036
pub inner: tokio_rustls::Connect<S>,
@@ -36,20 +42,134 @@ mod private {
3642
{
3743
type Output = io::Result<RustlsStream<S>>;
3844

39-
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
45+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
4046
// SAFETY: If `self` is pinned, so is `inner`.
47+
#[allow(unsafe_code)]
4148
let fut = unsafe { self.map_unchecked_mut(|this| &mut this.inner) };
4249
fut.poll(cx).map_ok(RustlsStream)
4350
}
4451
}
52+
53+
pub struct RustlsConnect(pub RustlsConnectData);
54+
55+
pub struct RustlsConnectData {
56+
pub hostname: ServerName<'static>,
57+
pub connector: TlsConnector,
58+
}
59+
60+
impl<S> TlsConnect<S> for RustlsConnect
61+
where
62+
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
63+
{
64+
type Stream = RustlsStream<S>;
65+
type Error = io::Error;
66+
type Future = TlsConnectFuture<S>;
67+
68+
fn connect(self, stream: S) -> Self::Future {
69+
TlsConnectFuture {
70+
inner: self.0.connector.connect(self.0.hostname, stream),
71+
}
72+
}
73+
}
74+
75+
pub struct RustlsStream<S>(TlsStream<S>);
76+
77+
impl<S> RustlsStream<S> {
78+
pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream<S>> {
79+
// SAFETY: When `Self` is pinned, so is the inner `TlsStream`.
80+
#[allow(unsafe_code)]
81+
unsafe {
82+
self.map_unchecked_mut(|this| &mut this.0)
83+
}
84+
}
85+
}
86+
87+
impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
88+
where
89+
S: AsyncRead + AsyncWrite + Unpin,
90+
{
91+
fn channel_binding(&self) -> ChannelBinding {
92+
let (_, session) = self.0.get_ref();
93+
match session.peer_certificates() {
94+
Some(certs) if !certs.is_empty() => TbsCertificate::from_der(&certs[0])
95+
.ok()
96+
.and_then(|cert| {
97+
let digest = match cert.signature.oid {
98+
// Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
99+
ID_SHA_1
100+
| ID_SHA_256
101+
| SHA_1_WITH_RSA_ENCRYPTION
102+
| SHA_256_WITH_RSA_ENCRYPTION
103+
| ECDSA_WITH_SHA_256 => &digest::SHA256,
104+
ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
105+
&digest::SHA384
106+
}
107+
ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => {
108+
&digest::SHA512
109+
}
110+
_ => return None,
111+
};
112+
113+
Some(digest)
114+
})
115+
.map_or_else(ChannelBinding::none, |algorithm| {
116+
let hash = digest::digest(algorithm, certs[0].as_ref());
117+
ChannelBinding::tls_server_end_point(hash.as_ref().into())
118+
}),
119+
_ => ChannelBinding::none(),
120+
}
121+
}
122+
}
123+
124+
impl<S> AsyncRead for RustlsStream<S>
125+
where
126+
S: AsyncRead + AsyncWrite + Unpin,
127+
{
128+
fn poll_read(
129+
self: Pin<&mut Self>,
130+
cx: &mut Context<'_>,
131+
buf: &mut ReadBuf<'_>,
132+
) -> Poll<tokio::io::Result<()>> {
133+
self.project_stream().poll_read(cx, buf)
134+
}
135+
}
136+
137+
impl<S> AsyncWrite for RustlsStream<S>
138+
where
139+
S: AsyncRead + AsyncWrite + Unpin,
140+
{
141+
fn poll_write(
142+
self: Pin<&mut Self>,
143+
cx: &mut Context<'_>,
144+
buf: &[u8],
145+
) -> Poll<tokio::io::Result<usize>> {
146+
self.project_stream().poll_write(cx, buf)
147+
}
148+
149+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
150+
self.project_stream().poll_flush(cx)
151+
}
152+
153+
fn poll_shutdown(
154+
self: Pin<&mut Self>,
155+
cx: &mut Context<'_>,
156+
) -> Poll<tokio::io::Result<()>> {
157+
self.project_stream().poll_shutdown(cx)
158+
}
159+
}
45160
}
46161

162+
/// A `MakeTlsConnect` implementation using `rustls`.
163+
///
164+
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
47165
#[derive(Clone)]
48166
pub struct MakeRustlsConnect {
49167
config: Arc<ClientConfig>,
50168
}
51169

52170
impl MakeRustlsConnect {
171+
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
172+
#[must_use]
53173
pub fn new(config: ClientConfig) -> Self {
54174
Self {
55175
config: Arc::new(config),
@@ -61,121 +181,20 @@ impl<S> MakeTlsConnect<S> for MakeRustlsConnect
61181
where
62182
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
63183
{
64-
type Stream = RustlsStream<S>;
65-
type TlsConnect = RustlsConnect;
184+
type Stream = private::RustlsStream<S>;
185+
type TlsConnect = private::RustlsConnect;
66186
type Error = rustls::pki_types::InvalidDnsNameError;
67187

68-
fn make_tls_connect(&mut self, hostname: &str) -> Result<RustlsConnect, Self::Error> {
188+
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
69189
ServerName::try_from(hostname).map(|dns_name| {
70-
RustlsConnect(RustlsConnectData {
190+
private::RustlsConnect(private::RustlsConnectData {
71191
hostname: dns_name.to_owned(),
72192
connector: Arc::clone(&self.config).into(),
73193
})
74194
})
75195
}
76196
}
77197

78-
pub struct RustlsConnect(RustlsConnectData);
79-
80-
struct RustlsConnectData {
81-
hostname: ServerName<'static>,
82-
connector: TlsConnector,
83-
}
84-
85-
impl<S> TlsConnect<S> for RustlsConnect
86-
where
87-
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
88-
{
89-
type Stream = RustlsStream<S>;
90-
type Error = io::Error;
91-
type Future = private::TlsConnectFuture<S>;
92-
93-
fn connect(self, stream: S) -> Self::Future {
94-
private::TlsConnectFuture {
95-
inner: self.0.connector.connect(self.0.hostname, stream),
96-
}
97-
}
98-
}
99-
100-
pub struct RustlsStream<S>(TlsStream<S>);
101-
102-
impl<S> RustlsStream<S> {
103-
pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream<S>> {
104-
// SAFETY: When `Self` is pinned, so is the inner `TlsStream`.
105-
unsafe { self.map_unchecked_mut(|this| &mut this.0) }
106-
}
107-
}
108-
109-
impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
110-
where
111-
S: AsyncRead + AsyncWrite + Unpin,
112-
{
113-
fn channel_binding(&self) -> ChannelBinding {
114-
let (_, session) = self.0.get_ref();
115-
match session.peer_certificates() {
116-
Some(certs) if !certs.is_empty() => TbsCertificate::from_der(&certs[0])
117-
.ok()
118-
.and_then(|cert| {
119-
let digest = match cert.signature.oid {
120-
// Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
121-
ID_SHA_1
122-
| ID_SHA_256
123-
| SHA_1_WITH_RSA_ENCRYPTION
124-
| SHA_256_WITH_RSA_ENCRYPTION
125-
| ECDSA_WITH_SHA_256 => &digest::SHA256,
126-
ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
127-
&digest::SHA384
128-
}
129-
ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => &digest::SHA512,
130-
_ => return None,
131-
};
132-
133-
Some(digest)
134-
})
135-
.map(|algorithm| {
136-
let hash = digest::digest(algorithm, certs[0].as_ref());
137-
ChannelBinding::tls_server_end_point(hash.as_ref().into())
138-
})
139-
.unwrap_or(ChannelBinding::none()),
140-
_ => ChannelBinding::none(),
141-
}
142-
}
143-
}
144-
145-
impl<S> AsyncRead for RustlsStream<S>
146-
where
147-
S: AsyncRead + AsyncWrite + Unpin,
148-
{
149-
fn poll_read(
150-
self: Pin<&mut Self>,
151-
cx: &mut Context,
152-
buf: &mut ReadBuf<'_>,
153-
) -> Poll<tokio::io::Result<()>> {
154-
self.project_stream().poll_read(cx, buf)
155-
}
156-
}
157-
158-
impl<S> AsyncWrite for RustlsStream<S>
159-
where
160-
S: AsyncRead + AsyncWrite + Unpin,
161-
{
162-
fn poll_write(
163-
self: Pin<&mut Self>,
164-
cx: &mut Context,
165-
buf: &[u8],
166-
) -> Poll<tokio::io::Result<usize>> {
167-
self.project_stream().poll_write(cx, buf)
168-
}
169-
170-
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
171-
self.project_stream().poll_flush(cx)
172-
}
173-
174-
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
175-
self.project_stream().poll_shutdown(cx)
176-
}
177-
}
178-
179198
#[cfg(test)]
180199
mod tests {
181200
use super::*;

0 commit comments

Comments
 (0)