1- use std:: {
2- convert:: TryFrom ,
3- future:: Future ,
4- io,
5- pin:: Pin ,
6- sync:: Arc ,
7- task:: { Context , Poll } ,
8- } ;
9-
10- use const_oid:: db:: {
11- rfc5912:: {
12- ECDSA_WITH_SHA_256 , ECDSA_WITH_SHA_384 , ID_SHA_1 , ID_SHA_256 , ID_SHA_384 , ID_SHA_512 ,
13- SHA_1_WITH_RSA_ENCRYPTION , SHA_256_WITH_RSA_ENCRYPTION , SHA_384_WITH_RSA_ENCRYPTION ,
14- SHA_512_WITH_RSA_ENCRYPTION ,
15- } ,
16- rfc8410:: ID_ED_25519 ,
17- } ;
18- use ring:: digest;
19- use rustls:: pki_types:: ServerName ;
20- use rustls:: ClientConfig ;
21- use tokio:: io:: { AsyncRead , AsyncWrite , ReadBuf } ;
22- use tokio_postgres:: tls:: { ChannelBinding , MakeTlsConnect , TlsConnect } ;
23- use tokio_rustls:: { client:: TlsStream , TlsConnector } ;
24- use x509_cert:: { der:: Decode , TbsCertificate } ;
1+ #![ doc = include_str ! ( "../README.md" ) ]
2+ #![ forbid( rust_2018_idioms) ]
3+ #![ deny( missing_docs, unsafe_code) ]
4+ #![ warn( clippy:: all, clippy:: pedantic) ]
5+
6+ use std:: { convert:: TryFrom , sync:: Arc } ;
7+
8+ use rustls:: { pki_types:: ServerName , ClientConfig } ;
9+ use tokio:: io:: { AsyncRead , AsyncWrite } ;
10+ use tokio_postgres:: tls:: MakeTlsConnect ;
2511
2612mod private {
27- use super :: * ;
13+ use std:: {
14+ future:: Future ,
15+ io,
16+ pin:: Pin ,
17+ task:: { Context , Poll } ,
18+ } ;
19+
20+ use const_oid:: db:: {
21+ rfc5912:: {
22+ ECDSA_WITH_SHA_256 , ECDSA_WITH_SHA_384 , ID_SHA_1 , ID_SHA_256 , ID_SHA_384 , ID_SHA_512 ,
23+ SHA_1_WITH_RSA_ENCRYPTION , SHA_256_WITH_RSA_ENCRYPTION , SHA_384_WITH_RSA_ENCRYPTION ,
24+ SHA_512_WITH_RSA_ENCRYPTION ,
25+ } ,
26+ rfc8410:: ID_ED_25519 ,
27+ } ;
28+ use ring:: digest;
29+ use rustls:: pki_types:: ServerName ;
30+ use tokio:: io:: { AsyncRead , AsyncWrite , ReadBuf } ;
31+ use tokio_postgres:: tls:: { ChannelBinding , TlsConnect } ;
32+ use tokio_rustls:: { client:: TlsStream , TlsConnector } ;
33+ use x509_cert:: { der:: Decode , TbsCertificate } ;
2834
2935 pub struct TlsConnectFuture < S > {
3036 pub inner : tokio_rustls:: Connect < S > ,
@@ -36,20 +42,134 @@ mod private {
3642 {
3743 type Output = io:: Result < RustlsStream < S > > ;
3844
39- fn poll ( self : Pin < & mut Self > , cx : & mut Context ) -> Poll < Self :: Output > {
45+ fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
4046 // SAFETY: If `self` is pinned, so is `inner`.
47+ #[ allow( unsafe_code) ]
4148 let fut = unsafe { self . map_unchecked_mut ( |this| & mut this. inner ) } ;
4249 fut. poll ( cx) . map_ok ( RustlsStream )
4350 }
4451 }
52+
53+ pub struct RustlsConnect ( pub RustlsConnectData ) ;
54+
55+ pub struct RustlsConnectData {
56+ pub hostname : ServerName < ' static > ,
57+ pub connector : TlsConnector ,
58+ }
59+
60+ impl < S > TlsConnect < S > for RustlsConnect
61+ where
62+ S : AsyncRead + AsyncWrite + Unpin + Send + ' static ,
63+ {
64+ type Stream = RustlsStream < S > ;
65+ type Error = io:: Error ;
66+ type Future = TlsConnectFuture < S > ;
67+
68+ fn connect ( self , stream : S ) -> Self :: Future {
69+ TlsConnectFuture {
70+ inner : self . 0 . connector . connect ( self . 0 . hostname , stream) ,
71+ }
72+ }
73+ }
74+
75+ pub struct RustlsStream < S > ( TlsStream < S > ) ;
76+
77+ impl < S > RustlsStream < S > {
78+ pub fn project_stream ( self : Pin < & mut Self > ) -> Pin < & mut TlsStream < S > > {
79+ // SAFETY: When `Self` is pinned, so is the inner `TlsStream`.
80+ #[ allow( unsafe_code) ]
81+ unsafe {
82+ self . map_unchecked_mut ( |this| & mut this. 0 )
83+ }
84+ }
85+ }
86+
87+ impl < S > tokio_postgres:: tls:: TlsStream for RustlsStream < S >
88+ where
89+ S : AsyncRead + AsyncWrite + Unpin ,
90+ {
91+ fn channel_binding ( & self ) -> ChannelBinding {
92+ let ( _, session) = self . 0 . get_ref ( ) ;
93+ match session. peer_certificates ( ) {
94+ Some ( certs) if !certs. is_empty ( ) => TbsCertificate :: from_der ( & certs[ 0 ] )
95+ . ok ( )
96+ . and_then ( |cert| {
97+ let digest = match cert. signature . oid {
98+ // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
99+ ID_SHA_1
100+ | ID_SHA_256
101+ | SHA_1_WITH_RSA_ENCRYPTION
102+ | SHA_256_WITH_RSA_ENCRYPTION
103+ | ECDSA_WITH_SHA_256 => & digest:: SHA256 ,
104+ ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
105+ & digest:: SHA384
106+ }
107+ ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => {
108+ & digest:: SHA512
109+ }
110+ _ => return None ,
111+ } ;
112+
113+ Some ( digest)
114+ } )
115+ . map_or_else ( ChannelBinding :: none, |algorithm| {
116+ let hash = digest:: digest ( algorithm, certs[ 0 ] . as_ref ( ) ) ;
117+ ChannelBinding :: tls_server_end_point ( hash. as_ref ( ) . into ( ) )
118+ } ) ,
119+ _ => ChannelBinding :: none ( ) ,
120+ }
121+ }
122+ }
123+
124+ impl < S > AsyncRead for RustlsStream < S >
125+ where
126+ S : AsyncRead + AsyncWrite + Unpin ,
127+ {
128+ fn poll_read (
129+ self : Pin < & mut Self > ,
130+ cx : & mut Context < ' _ > ,
131+ buf : & mut ReadBuf < ' _ > ,
132+ ) -> Poll < tokio:: io:: Result < ( ) > > {
133+ self . project_stream ( ) . poll_read ( cx, buf)
134+ }
135+ }
136+
137+ impl < S > AsyncWrite for RustlsStream < S >
138+ where
139+ S : AsyncRead + AsyncWrite + Unpin ,
140+ {
141+ fn poll_write (
142+ self : Pin < & mut Self > ,
143+ cx : & mut Context < ' _ > ,
144+ buf : & [ u8 ] ,
145+ ) -> Poll < tokio:: io:: Result < usize > > {
146+ self . project_stream ( ) . poll_write ( cx, buf)
147+ }
148+
149+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < tokio:: io:: Result < ( ) > > {
150+ self . project_stream ( ) . poll_flush ( cx)
151+ }
152+
153+ fn poll_shutdown (
154+ self : Pin < & mut Self > ,
155+ cx : & mut Context < ' _ > ,
156+ ) -> Poll < tokio:: io:: Result < ( ) > > {
157+ self . project_stream ( ) . poll_shutdown ( cx)
158+ }
159+ }
45160}
46161
162+ /// A `MakeTlsConnect` implementation using `rustls`.
163+ ///
164+ /// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
47165#[ derive( Clone ) ]
48166pub struct MakeRustlsConnect {
49167 config : Arc < ClientConfig > ,
50168}
51169
52170impl MakeRustlsConnect {
171+ /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
172+ #[ must_use]
53173 pub fn new ( config : ClientConfig ) -> Self {
54174 Self {
55175 config : Arc :: new ( config) ,
@@ -61,121 +181,20 @@ impl<S> MakeTlsConnect<S> for MakeRustlsConnect
61181where
62182 S : AsyncRead + AsyncWrite + Unpin + Send + ' static ,
63183{
64- type Stream = RustlsStream < S > ;
65- type TlsConnect = RustlsConnect ;
184+ type Stream = private :: RustlsStream < S > ;
185+ type TlsConnect = private :: RustlsConnect ;
66186 type Error = rustls:: pki_types:: InvalidDnsNameError ;
67187
68- fn make_tls_connect ( & mut self , hostname : & str ) -> Result < RustlsConnect , Self :: Error > {
188+ fn make_tls_connect ( & mut self , hostname : & str ) -> Result < Self :: TlsConnect , Self :: Error > {
69189 ServerName :: try_from ( hostname) . map ( |dns_name| {
70- RustlsConnect ( RustlsConnectData {
190+ private :: RustlsConnect ( private :: RustlsConnectData {
71191 hostname : dns_name. to_owned ( ) ,
72192 connector : Arc :: clone ( & self . config ) . into ( ) ,
73193 } )
74194 } )
75195 }
76196}
77197
78- pub struct RustlsConnect ( RustlsConnectData ) ;
79-
80- struct RustlsConnectData {
81- hostname : ServerName < ' static > ,
82- connector : TlsConnector ,
83- }
84-
85- impl < S > TlsConnect < S > for RustlsConnect
86- where
87- S : AsyncRead + AsyncWrite + Unpin + Send + ' static ,
88- {
89- type Stream = RustlsStream < S > ;
90- type Error = io:: Error ;
91- type Future = private:: TlsConnectFuture < S > ;
92-
93- fn connect ( self , stream : S ) -> Self :: Future {
94- private:: TlsConnectFuture {
95- inner : self . 0 . connector . connect ( self . 0 . hostname , stream) ,
96- }
97- }
98- }
99-
100- pub struct RustlsStream < S > ( TlsStream < S > ) ;
101-
102- impl < S > RustlsStream < S > {
103- pub fn project_stream ( self : Pin < & mut Self > ) -> Pin < & mut TlsStream < S > > {
104- // SAFETY: When `Self` is pinned, so is the inner `TlsStream`.
105- unsafe { self . map_unchecked_mut ( |this| & mut this. 0 ) }
106- }
107- }
108-
109- impl < S > tokio_postgres:: tls:: TlsStream for RustlsStream < S >
110- where
111- S : AsyncRead + AsyncWrite + Unpin ,
112- {
113- fn channel_binding ( & self ) -> ChannelBinding {
114- let ( _, session) = self . 0 . get_ref ( ) ;
115- match session. peer_certificates ( ) {
116- Some ( certs) if !certs. is_empty ( ) => TbsCertificate :: from_der ( & certs[ 0 ] )
117- . ok ( )
118- . and_then ( |cert| {
119- let digest = match cert. signature . oid {
120- // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
121- ID_SHA_1
122- | ID_SHA_256
123- | SHA_1_WITH_RSA_ENCRYPTION
124- | SHA_256_WITH_RSA_ENCRYPTION
125- | ECDSA_WITH_SHA_256 => & digest:: SHA256 ,
126- ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
127- & digest:: SHA384
128- }
129- ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => & digest:: SHA512 ,
130- _ => return None ,
131- } ;
132-
133- Some ( digest)
134- } )
135- . map ( |algorithm| {
136- let hash = digest:: digest ( algorithm, certs[ 0 ] . as_ref ( ) ) ;
137- ChannelBinding :: tls_server_end_point ( hash. as_ref ( ) . into ( ) )
138- } )
139- . unwrap_or ( ChannelBinding :: none ( ) ) ,
140- _ => ChannelBinding :: none ( ) ,
141- }
142- }
143- }
144-
145- impl < S > AsyncRead for RustlsStream < S >
146- where
147- S : AsyncRead + AsyncWrite + Unpin ,
148- {
149- fn poll_read (
150- self : Pin < & mut Self > ,
151- cx : & mut Context ,
152- buf : & mut ReadBuf < ' _ > ,
153- ) -> Poll < tokio:: io:: Result < ( ) > > {
154- self . project_stream ( ) . poll_read ( cx, buf)
155- }
156- }
157-
158- impl < S > AsyncWrite for RustlsStream < S >
159- where
160- S : AsyncRead + AsyncWrite + Unpin ,
161- {
162- fn poll_write (
163- self : Pin < & mut Self > ,
164- cx : & mut Context ,
165- buf : & [ u8 ] ,
166- ) -> Poll < tokio:: io:: Result < usize > > {
167- self . project_stream ( ) . poll_write ( cx, buf)
168- }
169-
170- fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context ) -> Poll < tokio:: io:: Result < ( ) > > {
171- self . project_stream ( ) . poll_flush ( cx)
172- }
173-
174- fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context ) -> Poll < tokio:: io:: Result < ( ) > > {
175- self . project_stream ( ) . poll_shutdown ( cx)
176- }
177- }
178-
179198#[ cfg( test) ]
180199mod tests {
181200 use super :: * ;
0 commit comments