@@ -481,7 +481,7 @@ def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
481481class KMSBuffer :
482482 buffer : memoryview
483483 start_index : int
484- length : int
484+ end_index : int
485485
486486
487487class PyMongoKMSProtocol (PyMongoBaseProtocol ):
@@ -524,11 +524,21 @@ def get_buffer(self, sizehint: int) -> memoryview:
524524 If any data does not fit into the returned buffer, this method will be called again until
525525 either no data remains or an empty buffer is returned.
526526 """
527- sizehint = max (sizehint , 1024 )
528- buffer = KMSBuffer (memoryview (bytearray (sizehint )), 0 , 0 )
527+ # Reuse the active buffer if it has space.
528+ if len (self ._buffers ):
529+ buffer = self ._buffers [- 1 ]
530+ if len (buffer .buffer ) - buffer .end_index > sizehint :
531+ return buffer .buffer [buffer .end_index :]
532+ # Allocate a bit more than the max response size for an AWS KMS response.
533+ buffer = KMSBuffer (memoryview (bytearray (16384 )), 0 , 0 )
529534 self ._buffers .append (buffer )
530535 return buffer .buffer
531536
537+ def _resolve_pending (self , exc : Optional [Exception ] = None ) -> None :
538+ while self ._pending_listeners :
539+ fut = self ._pending_listeners .popleft ()
540+ fut .set_result (b"" )
541+
532542 def buffer_updated (self , nbytes : int ) -> None :
533543 """Called when the buffer was updated with the received data"""
534544 # Wrote 0 bytes into a non-empty buffer, signal connection closed
@@ -540,9 +550,7 @@ def buffer_updated(self, nbytes: int) -> None:
540550 self ._bytes_ready += nbytes
541551
542552 # Update the length of the current buffer.
543- current_buffer = self ._buffers .pop ()
544- current_buffer .length += nbytes
545- self ._buffers .append (current_buffer )
553+ self ._buffers [- 1 ].end_index += nbytes
546554
547555 if not len (self ._pending_reads ):
548556 return
@@ -564,7 +572,7 @@ def _read(self, bytes_needed: int) -> memoryview:
564572 out_index = 0
565573 while n_remaining > 0 :
566574 buffer = self ._buffers .popleft ()
567- buffer_remaining = buffer .length - buffer .start_index
575+ buffer_remaining = buffer .end_index - buffer .start_index
568576 # if we didn't exhaust the buffer, read the partial data and return the buffer.
569577 if buffer_remaining > n_remaining :
570578 output_buf [out_index : n_remaining + out_index ] = buffer .buffer [
@@ -576,10 +584,14 @@ def _read(self, bytes_needed: int) -> memoryview:
576584 # otherwise exhaust the buffer.
577585 else :
578586 output_buf [out_index : out_index + buffer_remaining ] = buffer .buffer [
579- buffer .start_index : buffer .length
587+ buffer .start_index : buffer .end_index
580588 ]
581589 out_index += buffer_remaining
582590 n_remaining -= buffer_remaining
591+ # if this is the only buffer, add it back to the queue.
592+ if not len (self ._buffers ):
593+ buffer .start_index = buffer .end_index
594+ self ._buffers .appendleft (buffer )
583595 return memoryview (output_buf )
584596
585597
0 commit comments