@@ -151,6 +151,20 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
151151 self ._closing = False
152152 self .pipeline_factory = PipelineFactory (SecretsManager ())
153153 self .context_tracking : Optional [PipelineContext ] = None
154+ self .idle_timeout = 10
155+ self .idle_timer = None
156+
157+ def _reset_idle_timer (self ) -> None :
158+ if self .idle_timer :
159+ self .idle_timer .cancel ()
160+ self .idle_timer = asyncio .get_event_loop ().call_later (
161+ self .idle_timeout , self ._handle_idle_timeout
162+ )
163+
164+ def _handle_idle_timeout (self ) -> None :
165+ logger .warning ("Idle timeout reached, closing connection" )
166+ if self .transport and not self .transport .is_closing ():
167+ self .transport .close ()
154168
155169 def _select_pipeline (self , method : str , path : str ) -> Optional [CopilotPipeline ]:
156170 if method == "POST" and path == "v1/engines/copilot-codex/completions" :
@@ -215,6 +229,7 @@ def connection_made(self, transport: asyncio.Transport) -> None:
215229 self .transport = transport
216230 self .peername = transport .get_extra_info ("peername" )
217231 logger .debug (f"Client connected from { self .peername } " )
232+ self ._reset_idle_timer ()
218233
219234 def get_headers_dict (self ) -> Dict [str , str ]:
220235 """Convert raw headers to dictionary format"""
@@ -350,8 +365,10 @@ async def _forward_data_to_target(self, data: bytes) -> None:
350365 pipeline_output = pipeline_output .reconstruct ()
351366 self .target_transport .write (pipeline_output )
352367
368+
353369 def data_received (self , data : bytes ) -> None :
354370 """Handle received data from client"""
371+ self ._reset_idle_timer ()
355372 try :
356373 if not self ._check_buffer_size (data ):
357374 self .send_error_response (413 , b"Request body too large" )
@@ -556,6 +573,7 @@ async def connect_to_target(self) -> None:
556573 logger .error (f"Error during TLS handshake: { e } " )
557574 self .send_error_response (502 , b"TLS handshake failed" )
558575
576+
559577 def send_error_response (self , status : int , message : bytes ) -> None :
560578 """Send error response to client"""
561579 if self ._closing :
@@ -593,6 +611,37 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
593611 self .buffer .clear ()
594612 self .ssl_context = None
595613
614+ if self .idle_timer :
615+ self .idle_timer .cancel ()
616+
617+ def eof_received (self ) -> None :
618+ print ("in eof received" )
619+ """Handle connection loss"""
620+ if self ._closing :
621+ return
622+
623+ self ._closing = True
624+ logger .debug (f"EOF received from { self .peername } " )
625+
626+ # Close target transport if it exists and isn't already closing
627+ if self .target_transport and not self .target_transport .is_closing ():
628+ try :
629+ self .target_transport .close ()
630+ except Exception as e :
631+ logger .error (f"Error closing target transport when EOF: { e } " )
632+
633+ # Clear references to help with cleanup
634+ self .transport = None
635+ self .target_transport = None
636+ self .buffer .clear ()
637+ self .ssl_context = None
638+
639+ def pause_writing (self ) -> None :
640+ print ("Transport buffer full, pausing writing" )
641+
642+ def resume_writing (self ) -> None :
643+ print ("Transport buffer ready, resuming writing" )
644+
596645 @classmethod
597646 async def create_proxy_server (
598647 cls , host : str , port : int , ssl_context : Optional [ssl .SSLContext ] = None
0 commit comments