@@ -9,6 +9,7 @@ package topology
99import (
1010 "context"
1111 "crypto/tls"
12+ "encoding/binary"
1213 "errors"
1314 "fmt"
1415 "io"
@@ -79,9 +80,9 @@ type connection struct {
7980 driverConnectionID uint64
8081 generation uint64
8182
82- // awaitingResponse indicates that the server response was not completely
83+ // awaitRemainingBytes indicates the size of server response that was not completely
8384 // read before returning the connection to the pool.
84- awaitingResponse bool
85+ awaitRemainingBytes * int32
8586
8687 // oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate
8788 // accessTokens in the OIDC authenticator cache.
@@ -115,12 +116,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
115116 return c
116117}
117118
118- // DriverConnectionID returns the driver connection ID.
119- // TODO(GODRIVER-2824): change return type to int64.
120- func (c * connection ) DriverConnectionID () uint64 {
121- return c .driverConnectionID
122- }
123-
124119// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
125120// configuration.
126121func (c * connection ) setGenerationNumber () {
@@ -142,6 +137,39 @@ func (c *connection) hasGenerationNumber() bool {
142137 return c .desc .LoadBalanced ()
143138}
144139
140+ func configureTLS (ctx context.Context ,
141+ tlsConnSource tlsConnectionSource ,
142+ nc net.Conn ,
143+ addr address.Address ,
144+ config * tls.Config ,
145+ ocspOpts * ocsp.VerifyOptions ,
146+ ) (net.Conn , error ) {
147+ // Ensure config.ServerName is always set for SNI.
148+ if config .ServerName == "" {
149+ hostname := addr .String ()
150+ colonPos := strings .LastIndex (hostname , ":" )
151+ if colonPos == - 1 {
152+ colonPos = len (hostname )
153+ }
154+
155+ hostname = hostname [:colonPos ]
156+ config .ServerName = hostname
157+ }
158+
159+ client := tlsConnSource .Client (nc , config )
160+ if err := clientHandshake (ctx , client ); err != nil {
161+ return nil , err
162+ }
163+
164+ // Only do OCSP verification if TLS verification is requested.
165+ if ! config .InsecureSkipVerify {
166+ if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
167+ return nil , ocspErr
168+ }
169+ }
170+ return client , nil
171+ }
172+
145173// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
146174// handshakes. All errors returned by connect are considered "before the handshake completes" and
147175// must be handled by calling the appropriate SDAM handshake error handler.
@@ -317,6 +345,10 @@ func (c *connection) closeConnectContext() {
317345 }
318346}
319347
348+ func (c * connection ) cancellationListenerCallback () {
349+ _ = c .close ()
350+ }
351+
320352func transformNetworkError (ctx context.Context , originalError error , contextDeadlineUsed bool ) error {
321353 if originalError == nil {
322354 return nil
@@ -339,10 +371,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
339371 return originalError
340372}
341373
342- func (c * connection ) cancellationListenerCallback () {
343- _ = c .close ()
344- }
345-
346374func (c * connection ) writeWireMessage (ctx context.Context , wm []byte ) error {
347375 var err error
348376 if atomic .LoadInt64 (& c .state ) != connConnected {
@@ -423,15 +451,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
423451
424452 dst , errMsg , err := c .read (ctx )
425453 if err != nil {
426- if nerr := net .Error (nil ); errors .As (err , & nerr ) && nerr .Timeout () && csot .IsTimeoutContext (ctx ) {
427- // If the error was a timeout error and CSOT is enabled, instead of
428- // closing the connection mark it as awaiting response so the pool
429- // can read the response before making it available to other
430- // operations.
431- c .awaitingResponse = true
432- } else {
433- // Otherwise, use the pre-CSOT behavior and close the connection
434- // because we don't know if there are other bytes left to read.
454+ if c .awaitRemainingBytes == nil {
455+ // If the connection was not marked as awaiting response, use the
456+ // pre-CSOT behavior and close the connection because we don't know
457+ // if there are other bytes left to read.
435458 c .close ()
436459 }
437460 message := errMsg
@@ -448,6 +471,26 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
448471 return dst , nil
449472}
450473
474+ func (c * connection ) parseWmSizeBytes (wmSizeBytes [4 ]byte ) (int32 , error ) {
475+ // read the length as an int32
476+ size := int32 (binary .LittleEndian .Uint32 (wmSizeBytes [:]))
477+
478+ if size < 4 {
479+ return 0 , fmt .Errorf ("malformed message length: %d" , size )
480+ }
481+ // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
482+ // defaultMaxMessageSize instead.
483+ maxMessageSize := c .desc .MaxMessageSize
484+ if maxMessageSize == 0 {
485+ maxMessageSize = defaultMaxMessageSize
486+ }
487+ if uint32 (size ) > maxMessageSize {
488+ return 0 , errResponseTooLarge
489+ }
490+
491+ return size , nil
492+ }
493+
451494func (c * connection ) read (ctx context.Context ) (bytesRead []byte , errMsg string , err error ) {
452495 go c .cancellationListener .Listen (ctx , c .cancellationListenerCallback )
453496 defer func () {
@@ -461,36 +504,43 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
461504 }
462505 }()
463506
507+ isCSOTTimeout := func (err error ) bool {
508+ // If the error was a timeout error and CSOT is enabled, instead of
509+ // closing the connection mark it as awaiting response so the pool
510+ // can read the response before making it available to other
511+ // operations.
512+ nerr := net .Error (nil )
513+ return errors .As (err , & nerr ) && nerr .Timeout () && csot .IsTimeoutContext (ctx )
514+ }
515+
464516 // We use an array here because it only costs 4 bytes on the stack and means we'll only need to
465517 // reslice dst once instead of twice.
466518 var sizeBuf [4 ]byte
467519
468520 // We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
469521 // because there might be more than one wire message waiting to be read, for example when
470522 // reading messages from an exhaust cursor.
471- _ , err = io .ReadFull (c .nc , sizeBuf [:])
523+ n , err : = io .ReadFull (c .nc , sizeBuf [:])
472524 if err != nil {
525+ if l := int32 (n ); l == 0 && isCSOTTimeout (err ) {
526+ c .awaitRemainingBytes = & l
527+ }
473528 return nil , "incomplete read of message header" , err
474529 }
475-
476- // read the length as an int32
477- size := (int32 (sizeBuf [0 ])) | (int32 (sizeBuf [1 ]) << 8 ) | (int32 (sizeBuf [2 ]) << 16 ) | (int32 (sizeBuf [3 ]) << 24 )
478-
479- // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
480- // defaultMaxMessageSize instead.
481- maxMessageSize := c .desc .MaxMessageSize
482- if maxMessageSize == 0 {
483- maxMessageSize = defaultMaxMessageSize
484- }
485- if uint32 (size ) > maxMessageSize {
486- return nil , errResponseTooLarge .Error (), errResponseTooLarge
530+ size , err := c .parseWmSizeBytes (sizeBuf )
531+ if err != nil {
532+ return nil , err .Error (), err
487533 }
488534
489535 dst := make ([]byte , size )
490536 copy (dst , sizeBuf [:])
491537
492- _ , err = io .ReadFull (c .nc , dst [4 :])
538+ n , err = io .ReadFull (c .nc , dst [4 :])
493539 if err != nil {
540+ remainingBytes := size - 4 - int32 (n )
541+ if remainingBytes > 0 && isCSOTTimeout (err ) {
542+ c .awaitRemainingBytes = & remainingBytes
543+ }
494544 return dst , "incomplete read of full message" , err
495545 }
496546
@@ -537,10 +587,6 @@ func (c *connection) setCanStream(canStream bool) {
537587 c .canStream = canStream
538588}
539589
540- func (c initConnection ) supportsStreaming () bool {
541- return c .canStream
542- }
543-
544590func (c * connection ) setStreaming (streaming bool ) {
545591 c .currentlyStreaming = streaming
546592}
@@ -554,6 +600,12 @@ func (c *connection) setSocketTimeout(timeout time.Duration) {
554600 c .writeTimeout = timeout
555601}
556602
603+ // DriverConnectionID returns the driver connection ID.
604+ // TODO(GODRIVER-2824): change return type to int64.
605+ func (c * connection ) DriverConnectionID () uint64 {
606+ return c .driverConnectionID
607+ }
608+
557609func (c * connection ) ID () string {
558610 return c .id
559611}
@@ -562,6 +614,14 @@ func (c *connection) ServerConnectionID() *int64 {
562614 return c .serverConnectionID
563615}
564616
617+ func (c * connection ) OIDCTokenGenID () uint64 {
618+ return c .oidcTokenGenID
619+ }
620+
621+ func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
622+ c .oidcTokenGenID = genID
623+ }
624+
565625// initConnection is an adapter used during connection initialization. It has the minimum
566626// functionality necessary to implement the driver.Connection interface, which is required to pass a
567627// *connection to a Handshaker.
@@ -599,7 +659,7 @@ func (c initConnection) CurrentlyStreaming() bool {
599659 return c .getCurrentlyStreaming ()
600660}
601661func (c initConnection ) SupportsStreaming () bool {
602- return c .supportsStreaming ()
662+ return c .canStream
603663}
604664
605665// Connection implements the driver.Connection interface to allow reading and writing wire
@@ -833,39 +893,6 @@ func (c *Connection) DriverConnectionID() uint64 {
833893 return c .connection .DriverConnectionID ()
834894}
835895
836- func configureTLS (ctx context.Context ,
837- tlsConnSource tlsConnectionSource ,
838- nc net.Conn ,
839- addr address.Address ,
840- config * tls.Config ,
841- ocspOpts * ocsp.VerifyOptions ,
842- ) (net.Conn , error ) {
843- // Ensure config.ServerName is always set for SNI.
844- if config .ServerName == "" {
845- hostname := addr .String ()
846- colonPos := strings .LastIndex (hostname , ":" )
847- if colonPos == - 1 {
848- colonPos = len (hostname )
849- }
850-
851- hostname = hostname [:colonPos ]
852- config .ServerName = hostname
853- }
854-
855- client := tlsConnSource .Client (nc , config )
856- if err := clientHandshake (ctx , client ); err != nil {
857- return nil , err
858- }
859-
860- // Only do OCSP verification if TLS verification is requested.
861- if ! config .InsecureSkipVerify {
862- if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
863- return nil , ocspErr
864- }
865- }
866- return client , nil
867- }
868-
869896// OIDCTokenGenID returns the OIDC token generation ID.
870897func (c * Connection ) OIDCTokenGenID () uint64 {
871898 return c .oidcTokenGenID
@@ -919,11 +946,3 @@ func (c *cancellListener) StopListening() bool {
919946 c .done <- struct {}{}
920947 return c .aborted
921948}
922-
923- func (c * connection ) OIDCTokenGenID () uint64 {
924- return c .oidcTokenGenID
925- }
926-
927- func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
928- c .oidcTokenGenID = genID
929- }
0 commit comments