@@ -705,10 +705,15 @@ def __init__(self, proxy: CopilotProvider):
         self.stream_queue: Optional[asyncio.Queue] = None
         self.processing_task: Optional[asyncio.Task] = None
 
+        self.finish_stream = False
+
+        # For debugging only
+        # self.data_sent = []
+
     def connection_made(self, transport: asyncio.Transport) -> None:
         """Handle successful connection to target"""
         self.transport = transport
-        logger.debug(f"Target transport peer: {transport.get_extra_info('peername')}")
+        logger.debug(f"Connection established to target: {transport.get_extra_info('peername')}")
         self.proxy.target_transport = transport
 
     def _ensure_output_processor(self) -> None:
@@ -737,7 +742,7 @@ async def _process_stream(self):
         try:
 
             async def stream_iterator():
-                while True:
+                while not self.stream_queue.empty():
                     incoming_record = await self.stream_queue.get()
 
                     record_content = incoming_record.get("content", {})
@@ -750,6 +755,9 @@ async def stream_iterator():
                         else:
                             content = choice.get("delta", {}).get("content")
 
+                        if choice.get("finish_reason", None) == "stop":
+                            self.finish_stream = True
+
                         streaming_choices.append(
                             StreamingChoices(
                                 finish_reason=choice.get("finish_reason", None),
@@ -771,22 +779,18 @@ async def stream_iterator():
                     )
                     yield mr
 
-            async for record in self.output_pipeline_instance.process_stream(stream_iterator()):
+            async for record in self.output_pipeline_instance.process_stream(
+                stream_iterator(), cleanup_sensitive=False
+            ):
                 chunk = record.model_dump_json(exclude_none=True, exclude_unset=True)
                 sse_data = f"data: {chunk}\n\n".encode("utf-8")
                 chunk_size = hex(len(sse_data))[2:] + "\r\n"
                 self._proxy_transport_write(chunk_size.encode())
                 self._proxy_transport_write(sse_data)
                 self._proxy_transport_write(b"\r\n")
 
-            sse_data = b"data: [DONE]\n\n"
-            # Add chunk size for DONE message too
-            chunk_size = hex(len(sse_data))[2:] + "\r\n"
-            self._proxy_transport_write(chunk_size.encode())
-            self._proxy_transport_write(sse_data)
-            self._proxy_transport_write(b"\r\n")
-            # Now send the final zero chunk
-            self._proxy_transport_write(b"0\r\n\r\n")
+            if self.finish_stream:
+                self.finish_data()
 
         except asyncio.CancelledError:
             logger.debug("Stream processing cancelled")
@@ -795,12 +799,37 @@ async def stream_iterator():
             logger.error(f"Error processing stream: {e}")
         finally:
             # Clean up
+            self.stream_queue = None
             if self.processing_task and not self.processing_task.done():
                 self.processing_task.cancel()
-            if self.proxy.context_tracking and self.proxy.context_tracking.sensitive:
-                self.proxy.context_tracking.sensitive.secure_cleanup()
+
+    def finish_data(self):
+        logger.debug("Finishing data stream")
+        sse_data = b"data: [DONE]\n\n"
+        # Add chunk size for DONE message too
+        chunk_size = hex(len(sse_data))[2:] + "\r\n"
+        self._proxy_transport_write(chunk_size.encode())
+        self._proxy_transport_write(sse_data)
+        self._proxy_transport_write(b"\r\n")
+        # Now send the final zero chunk
+        self._proxy_transport_write(b"0\r\n\r\n")
+
+        # For debugging only
+        # print("===========START DATA SENT====================")
+        # for data in self.data_sent:
+        #     print(data)
+        # self.data_sent = []
+        # print("===========START DATA SENT====================")
+
+        self.finish_stream = False
+        self.headers_sent = False
 
     def _process_chunk(self, chunk: bytes):
+        # For debugging only
+        # print("===========START DATA RECVD====================")
+        # print(chunk)
+        # print("===========END DATA RECVD======================")
+
         records = self.sse_processor.process_chunk(chunk)
 
         for record in records:
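Note: the new finish_data() helper closes the response using HTTP/1.1 chunked transfer encoding, the same framing the streaming loop uses for each SSE event: a hex length line, CRLF, the payload, a trailing CRLF, and finally a zero-length chunk to end the body. A minimal sketch of that framing, separate from the commit (the frame_sse_chunk helper name is hypothetical):

# Sketch only: mirrors the framing that finish_data() and the streaming loop perform.
def frame_sse_chunk(payload: bytes) -> bytes:
    """Wrap an SSE payload as one HTTP/1.1 chunk: hex size, CRLF, payload, CRLF."""
    size_line = hex(len(payload))[2:].encode() + b"\r\n"
    return size_line + payload + b"\r\n"

# Termination: the [DONE] marker framed as a normal chunk, then the
# zero-length chunk that ends the chunked body.
done_event = frame_sse_chunk(b"data: [DONE]\n\n")
final_zero_chunk = b"0\r\n\r\n"

assert done_event == b"e\r\ndata: [DONE]\n\n\r\n"  # payload is 14 bytes -> "e" in hex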
@@ -812,13 +841,12 @@ def _process_chunk(self, chunk: bytes):
             self.stream_queue.put_nowait(record)
 
     def _proxy_transport_write(self, data: bytes):
+        # For debugging only
+        # self.data_sent.append(data)
         if not self.proxy.transport or self.proxy.transport.is_closing():
             logger.error("Proxy transport not available")
             return
         self.proxy.transport.write(data)
-        # print("DEBUG =================================")
-        # print(data)
-        # print("DEBUG =================================")
 
     def data_received(self, data: bytes) -> None:
         """Handle data received from target"""
@@ -848,15 +876,13 @@ def data_received(self, data: bytes) -> None:
                     logger.debug(f"Headers sent: {headers}")
 
                     data = data[header_end + 4 :]
-                    # print("DEBUG =================================")
-                    # print(data)
-                    # print("DEBUG =================================")
 
             self._process_chunk(data)
 
     def connection_lost(self, exc: Optional[Exception]) -> None:
         """Handle connection loss to target"""
 
+        logger.debug("Lost connection to target")
         if (
             not self.proxy._closing
             and self.proxy.transport
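For reference, a reader on the editor side sees the stream these changes produce as chunked SSE that ends with the [DONE] marker followed by the zero-length chunk. A rough, self-contained sketch of decoding that wire format (decode_chunked_sse and the sample bytes are illustrative, not part of the codebase):

def decode_chunked_sse(raw: bytes):
    """Yield SSE data payloads from an HTTP/1.1 chunked body, stopping at [DONE]."""
    while raw:
        size_line, _, rest = raw.partition(b"\r\n")
        size = int(size_line, 16)
        if size == 0:  # zero-length chunk: end of the chunked body
            return
        chunk, raw = rest[:size], rest[size + 2 :]  # skip the chunk's trailing CRLF
        for line in chunk.split(b"\n"):
            if line.startswith(b"data: "):
                payload = line[len(b"data: ") :]
                if payload == b"[DONE]":
                    return
                yield payload

sample = b'13\r\ndata: {"id": "x"}\n\n\r\ne\r\ndata: [DONE]\n\n\r\n0\r\n\r\n'
print(list(decode_chunked_sse(sample)))  # [b'{"id": "x"}']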