@@ -357,7 +357,7 @@ cdef class Protocol(BaseProtocol):
357357 if not use_tcps and (params._token is not None
358358 or params.access_token_callback is not None ):
359359 errors._raise_err(errors.ERR_ACCESS_TOKEN_REQUIRES_TCPS)
360- if description.use_tcp_fast_open:
360+ if not use_proxy and description.use_tcp_fast_open:
361361 sock = socket.socket(address.ip_family, socket.SOCK_STREAM)
362362 sock.sendto(connect_string.encode(), socket.MSG_FASTOPEN,
363363 connect_info)
@@ -506,6 +506,9 @@ cdef class Protocol(BaseProtocol):
506506
507507cdef class BaseAsyncProtocol(BaseProtocol):
508508
509+ cdef:
510+ object _proxy_waiter
511+
509512 def __init__ (self ):
510513 BaseProtocol.__init__ (self )
511514 self ._request_lock = asyncio.Lock()
@@ -727,9 +730,10 @@ cdef class BaseAsyncProtocol(BaseProtocol):
727730
728731 # complete connection through proxy, if applicable
729732 if use_proxy:
733+ self ._proxy_waiter = self ._read_buf._loop.create_future()
730734 data = f" CONNECT {host}:{port} HTTP/1.0\r \n \r \n "
731735 transport.write(data.encode())
732- reply = transport.read( 1024 )
736+ reply = await self ._proxy_waiter
733737 m = re.search(' HTTP/1.[01]\\ s+(\\ d+)\\ s+' , reply.decode())
734738 if m is None or m.groups()[0 ] != ' 200' :
735739 errors._raise_err(errors.ERR_PROXY_FAILURE,
@@ -913,12 +917,16 @@ cdef class BaseAsyncProtocol(BaseProtocol):
913917 cdef:
914918 bint notify_waiter = False
915919 Packet packet
916- packet = self ._transport.extract_packet(data)
917- while packet is not None :
918- self ._read_buf._process_packet(packet, & notify_waiter, False )
919- if notify_waiter:
920- self ._read_buf.notify_packet_received()
921- packet = self ._transport.extract_packet()
920+ if self ._proxy_waiter is not None :
921+ self ._proxy_waiter.set_result(data)
922+ self ._proxy_waiter = None
923+ else :
924+ packet = self ._transport.extract_packet(data)
925+ while packet is not None :
926+ self ._read_buf._process_packet(packet, & notify_waiter, False )
927+ if notify_waiter:
928+ self ._read_buf.notify_packet_received()
929+ packet = self ._transport.extract_packet()
922930
923931 async def end_pipeline(self , BaseThinConnImpl conn_impl, list messages,
924932 bint continue_on_error):
0 commit comments