From 1d12c64e98707307352b9c3cc04bc10605f2b8a4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Nov 2025 15:48:41 +0100 Subject: [PATCH 1/3] Align the asyncio and sync client and server modules. --- src/websockets/asyncio/client.py | 109 ++++++++++++++++--------------- src/websockets/asyncio/server.py | 37 ++++++----- src/websockets/sync/client.py | 35 +++++----- src/websockets/sync/server.py | 33 ++++++---- tests/asyncio/test_client.py | 13 ++++ 5 files changed, 128 insertions(+), 99 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 05947f3a..da66c9c2 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -216,8 +216,8 @@ class connect: compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. - additional_headers (HeadersLike | None): Arbitrary HTTP headers to add - to the handshake request. + additional_headers: Arbitrary HTTP headers to add to the handshake + request. user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. @@ -328,6 +328,9 @@ def __init__( **kwargs: Any, ) -> None: self.uri = uri + self.ws_uri = parse_uri(uri) + if not self.ws_uri.secure and kwargs.get("ssl") is not None: + raise ValueError("ssl argument is incompatible with a ws:// URI") if subprotocols is not None: validate_subprotocols(subprotocols) @@ -343,7 +346,7 @@ def __init__( if create_connection is None: create_connection = ClientConnection - def protocol_factory(uri: WebSocketURI) -> ClientConnection: + def factory(uri: WebSocketURI) -> ClientConnection: # This is a protocol in the Sans-I/O implementation of websockets. protocol = ClientProtocol( uri, @@ -365,20 +368,18 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection: return connection self.proxy = proxy - self.protocol_factory = protocol_factory + self.factory = factory self.additional_headers = additional_headers self.user_agent_header = user_agent_header self.process_exception = process_exception self.open_timeout = open_timeout self.logger = logger - self.connection_kwargs = kwargs + self.create_connection_kwargs = kwargs - async def create_connection(self) -> ClientConnection: - """Create TCP or Unix connection.""" + async def open_tcp_connection(self) -> ClientConnection: + """Create TCP or Unix connection to the server, possibly through a proxy.""" loop = asyncio.get_running_loop() - kwargs = self.connection_kwargs.copy() - - ws_uri = parse_uri(self.uri) + kwargs = self.create_connection_kwargs.copy() proxy = self.proxy if kwargs.get("unix", False): @@ -386,19 +387,16 @@ async def create_connection(self) -> ClientConnection: if kwargs.get("sock") is not None: proxy = None if proxy is True: - proxy = get_proxy(ws_uri) + proxy = get_proxy(self.ws_uri) def factory() -> ClientConnection: - return self.protocol_factory(ws_uri) + return self.factory(self.ws_uri) - if ws_uri.secure: + if self.ws_uri.secure: kwargs.setdefault("ssl", True) - kwargs.setdefault("server_hostname", ws_uri.host) if kwargs.get("ssl") is None: raise ValueError("ssl=None is incompatible with a wss:// URI") - else: - if kwargs.get("ssl") is not None: - raise ValueError("ssl argument is incompatible with a ws:// URI") + kwargs.setdefault("server_hostname", self.ws_uri.host) if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) @@ -408,7 +406,7 @@ def factory() -> ClientConnection: # Connect to the server through the proxy. sock = await connect_socks_proxy( proxy_parsed, - ws_uri, + self.ws_uri, local_addr=kwargs.pop("local_addr", None), ) # Initialize WebSocket connection via the proxy. @@ -442,7 +440,7 @@ def factory() -> ClientConnection: # Connect to the server through the proxy. transport = await connect_http_proxy( proxy_parsed, - ws_uri, + self.ws_uri, user_agent_header=self.user_agent_header, **proxy_kwargs, ) @@ -459,18 +457,18 @@ def factory() -> ClientConnection: assert new_transport is not None # help mypy transport = new_transport connection.connection_made(transport) - else: - raise AssertionError("unsupported proxy") + else: # pragma: no cover + raise NotImplementedError(f"unsupported proxy: {proxy}") else: # Connect to the server directly. if kwargs.get("sock") is None: - kwargs.setdefault("host", ws_uri.host) - kwargs.setdefault("port", ws_uri.port) + kwargs.setdefault("host", self.ws_uri.host) + kwargs.setdefault("port", self.ws_uri.port) # Initialize WebSocket connection. _, connection = await loop.create_connection(factory, **kwargs) return connection - def process_redirect(self, exc: Exception) -> Exception | str: + def process_redirect(self, exc: Exception) -> Exception | tuple[str, WebSocketURI]: """ Determine whether a connection error is a redirect that can be followed. @@ -492,12 +490,12 @@ def process_redirect(self, exc: Exception) -> Exception | str: ): return exc - old_ws_uri = parse_uri(self.uri) + old_ws_uri = self.ws_uri new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) new_ws_uri = parse_uri(new_uri) # If connect() received a socket, it is closed and cannot be reused. - if self.connection_kwargs.get("sock") is not None: + if self.create_connection_kwargs.get("sock") is not None: return ValueError( f"cannot follow redirect to {new_uri} with a preexisting socket" ) @@ -513,7 +511,7 @@ def process_redirect(self, exc: Exception) -> Exception | str: or old_ws_uri.port != new_ws_uri.port ): # Cross-origin redirects on Unix sockets don't quite make sense. - if self.connection_kwargs.get("unix", False): + if self.create_connection_kwargs.get("unix", False): return ValueError( f"cannot follow cross-origin redirect to {new_uri} " f"with a Unix socket" @@ -521,15 +519,15 @@ def process_redirect(self, exc: Exception) -> Exception | str: # Cross-origin redirects when host and port are overridden are ill-defined. if ( - self.connection_kwargs.get("host") is not None - or self.connection_kwargs.get("port") is not None + self.create_connection_kwargs.get("host") is not None + or self.create_connection_kwargs.get("port") is not None ): return ValueError( f"cannot follow cross-origin redirect to {new_uri} " f"with an explicit host or port" ) - return new_uri + return new_uri, new_ws_uri # ... = await connect(...) @@ -541,14 +539,14 @@ async def __await_impl__(self) -> ClientConnection: try: async with asyncio_timeout(self.open_timeout): for _ in range(MAX_REDIRECTS): - self.connection = await self.create_connection() + connection = await self.open_tcp_connection() try: - await self.connection.handshake( + await connection.handshake( self.additional_headers, self.user_agent_header, ) except asyncio.CancelledError: - self.connection.transport.abort() + connection.transport.abort() raise except Exception as exc: # Always close the connection even though keep-alive is @@ -557,22 +555,23 @@ async def __await_impl__(self) -> ClientConnection: # protocol. In the current design of connect(), there is # no easy way to reuse the network connection that works # in every case nor to reinitialize the protocol. - self.connection.transport.abort() + connection.transport.abort() - uri_or_exc = self.process_redirect(exc) - # Response is a valid redirect; follow it. - if isinstance(uri_or_exc, str): - self.uri = uri_or_exc - continue + exc_or_uri = self.process_redirect(exc) # Response isn't a valid redirect; raise the exception. - if uri_or_exc is exc: - raise + if isinstance(exc_or_uri, Exception): + if exc_or_uri is exc: + raise + else: + raise exc_or_uri from exc + # Response is a valid redirect; follow it. else: - raise uri_or_exc from exc + self.uri, self.ws_uri = exc_or_uri + continue else: - self.connection.start_keepalive() - return self.connection + connection.start_keepalive() + return connection else: raise SecurityError(f"more than {MAX_REDIRECTS} redirects") @@ -587,7 +586,10 @@ async def __await_impl__(self) -> ClientConnection: # async with connect(...) as ...: ... async def __aenter__(self) -> ClientConnection: - return await self + if hasattr(self, "connection"): + raise RuntimeError("connect() isn't reentrant") + self.connection = await self + return self.connection async def __aexit__( self, @@ -595,7 +597,10 @@ async def __aexit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: - await self.connection.close() + try: + await self.connection.close() + finally: + del self.connection # async for ... in connect(...): @@ -603,8 +608,8 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: delays: Generator[float] | None = None while True: try: - async with self as protocol: - yield protocol + async with self as connection: + yield connection except Exception as exc: # Determine whether the exception is retryable or fatal. # The API of process_exception is "return an exception or None"; @@ -633,7 +638,6 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: traceback.format_exception_only(exc)[0].strip(), ) await asyncio.sleep(delay) - continue else: # The connection succeeded. Reset backoff. @@ -777,8 +781,7 @@ def eof_received(self) -> None: def connection_lost(self, exc: Exception | None) -> None: self.reader.feed_eof() - if exc is not None: - self.response.set_exception(exc) + self.run_parser() async def connect_http_proxy( @@ -797,8 +800,8 @@ async def connect_http_proxy( try: # This raises exceptions if the connection to the proxy fails. await protocol.response - except Exception: - transport.close() + except (asyncio.CancelledError, Exception): + transport.abort() raise return transport diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index ef9bd807..018d891d 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -169,7 +169,7 @@ async def handshake( assert isinstance(response, Response) # help mypy self.response = response - if server_header: + if server_header is not None: self.response.headers["Server"] = server_header response = None @@ -231,12 +231,9 @@ class Server: This class mirrors the API of :class:`asyncio.Server`. - It keeps track of WebSocket connections in order to close them properly - when shutting down. - Args: handler: Connection handler. It receives the WebSocket connection, - which is a :class:`ServerConnection`, in argument. + which is a :class:`ServerConnection`. process_request: Intercept the request during the opening handshake. Return an HTTP response to force the response. Return :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the @@ -310,7 +307,11 @@ def connections(self) -> set[ServerConnection]: It can be useful in combination with :func:`~broadcast`. """ - return {connection for connection in self.handlers if connection.state is OPEN} + return { + connection + for connection in self.handlers + if connection.protocol.state is OPEN + } def wrap(self, server: asyncio.Server) -> None: """ @@ -351,6 +352,8 @@ async def conn_handler(self, connection: ServerConnection) -> None: """ try: + # Apply open_timeout to the WebSocket handshake. + # Use ssl_handshake_timeout for the TLS handshake. async with asyncio_timeout(self.open_timeout): try: await connection.handshake( @@ -425,7 +428,7 @@ def close( ``code`` and ``reason`` can be customized, for example to use code 1012 (service restart). - * Wait until all connection handlers terminate. + * Wait until all connection handlers have returned. :meth:`close` is idempotent. @@ -452,6 +455,7 @@ async def _close( self.logger.info("server closing") # Stop accepting new connections. + # Reject OPENING connections with HTTP 503 -- see handshake(). self.server.close() # Wait until all accepted connections reach connection_made() and call @@ -459,15 +463,12 @@ async def _close( # details. This workaround can be removed when dropping Python < 3.11. await asyncio.sleep(0) - # After server.close(), handshake() closes OPENING connections with an - # HTTP 503 error. - + # Close OPEN connections. if close_connections: - # Close OPEN connections with code 1001 by default. close_tasks = [ asyncio.create_task(connection.close(code, reason)) for connection in self.handlers - if connection.protocol.state is not CONNECTING + if connection.protocol.state is OPEN ] # asyncio.wait doesn't accept an empty first argument. if close_tasks: @@ -476,7 +477,7 @@ async def _close( # Wait until all TCP connections are closed. await self.server.wait_closed() - # Wait until all connection handlers terminate. + # Wait until all connection handlers have returned. # asyncio.wait doesn't accept an empty first argument. if self.handlers: await asyncio.wait(self.handlers.values()) @@ -590,18 +591,18 @@ class serve: This coroutine returns a :class:`Server` whose API mirrors :class:`asyncio.Server`. Treat it as an asynchronous context manager to - ensure that the server will be closed:: + ensure that the server will be closed gracefully:: from websockets.asyncio.server import serve - def handler(websocket): + async def handler(websocket): ... - # set this future to exit the server - stop = asyncio.get_running_loop().create_future() + # set this event to exit the server + stop = asyncio.Event() async with serve(handler, host, port): - await stop + await stop.wait() Alternatively, call :meth:`~Server.serve_forever` to serve requests and cancel it to stop the server:: diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index b3fff44e..5d92bcc1 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -157,6 +157,7 @@ def connect( logger: LoggerLike | None = None, # Escape hatch for advanced customization create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to socket.create_connection **kwargs: Any, ) -> ClientConnection: """ @@ -190,8 +191,8 @@ def connect( compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. - additional_headers (HeadersLike | None): Arbitrary HTTP headers to add - to the handshake request. + additional_headers: Arbitrary HTTP headers to add to the handshake + request. user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. @@ -230,6 +231,7 @@ def connect( Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidProxy: If ``proxy`` isn't a valid proxy. OSError: If the TCP connection fails. InvalidHandshake: If the opening handshake fails. TimeoutError: If the opening handshake times out. @@ -250,6 +252,17 @@ def connect( if not ws_uri.secure and ssl is not None: raise ValueError("ssl argument is incompatible with a ws:// URI") + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ClientConnection + # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) path: str | None = kwargs.pop("path", None) @@ -260,14 +273,6 @@ def connect( elif path is not None and sock is not None: raise ValueError("path and sock arguments are incompatible") - if subprotocols is not None: - validate_subprotocols(subprotocols) - - if compression == "deflate": - extensions = enable_client_permessage_deflate(extensions) - elif compression is not None: - raise ValueError(f"unsupported compression: {compression}") - if unix: proxy = None if sock is not None: @@ -280,9 +285,6 @@ def connect( # to avoid conflicting with the WebSocket timeout in handshake(). deadline = Deadline(open_timeout) - if create_connection is None: - create_connection = ClientConnection - try: # Connect socket @@ -320,8 +322,8 @@ def connect( server_hostname=proxy_server_hostname, **kwargs, ) - else: - raise AssertionError("unsupported proxy") + else: # pragma: no cover + raise NotImplementedError("unsupported proxy") else: kwargs.setdefault("timeout", deadline.timeout()) sock = socket.create_connection( @@ -539,7 +541,8 @@ def connect_http_proxy( # Send CONNECT request to the proxy and read response. - sock.sendall(prepare_connect_request(proxy, ws_uri, user_agent_header)) + request = prepare_connect_request(proxy, ws_uri, user_agent_header) + sock.sendall(request) try: read_connect_response(sock, deadline) except Exception: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index ffd82fba..f3dfbdb6 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -32,7 +32,13 @@ from .utils import Deadline -__all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"] +__all__ = [ + "serve", + "unix_serve", + "ServerConnection", + "Server", + "basic_auth", +] class ServerConnection(Connection): @@ -154,7 +160,7 @@ def handshake( else: self.response = response - if server_header: + if server_header is not None: self.response.headers["Server"] = server_header response = None @@ -218,14 +224,12 @@ class Server: """ WebSocket server returned by :func:`serve`. - This class mirrors the API of :class:`~socketserver.BaseServer`, notably the - :meth:`~socketserver.BaseServer.serve_forever` and - :meth:`~socketserver.BaseServer.shutdown` methods, as well as the context - manager protocol. + This class mirrors partially the API of :class:`~socketserver.BaseServer`. - Args: - socket: Server socket listening for new connections. - handler: Handler for one connection. Receives the socket and address + + Args: + socket: Server socket accepting new connections. + handler: Handler for one connection. It receives the socket and address returned by :meth:`~socket.socket.accept`. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. @@ -387,8 +391,8 @@ def serve( This function returns a :class:`Server` whose API mirrors :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure - that it will be closed and call :meth:`~Server.serve_forever` to serve - requests:: + that it will be closed gracefully and call :meth:`~Server.serve_forever` to + serve requests:: from websockets.sync.server import serve @@ -605,7 +609,12 @@ def protocol_select_subprotocol( connection.recv_events_thread.join() return - assert connection.protocol.state is OPEN + if connection.protocol.state is not OPEN: + # process_request or process_response rejected the handshake. + connection.close_socket() + connection.recv_events_thread.join() + return + try: connection.start_keepalive() handler(connection) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index a83074ae..9e18f52b 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -1000,3 +1000,16 @@ async def test_unsupported_compression(self): str(raised.exception), "unsupported compression: False", ) + + async def test_reentrancy(self): + """Client isn't reentrant.""" + async with serve(*args) as server: + connecter = connect(get_uri(server)) + async with connecter: + with self.assertRaises(RuntimeError) as raised: + async with connecter: + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connect() isn't reentrant", + ) From f1310b90ac15de90c5023ed277bcbc928a1cedb3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Nov 2025 15:49:46 +0100 Subject: [PATCH 2/3] Clean up the asyncio and sync client and server tests. --- tests/asyncio/test_client.py | 12 +++++++----- tests/asyncio/test_server.py | 4 ++-- tests/sync/test_client.py | 12 +++++++----- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 9e18f52b..eff02623 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -75,7 +75,7 @@ async def test_existing_socket(self): """Client connects using a pre-existing socket.""" async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") @@ -341,7 +341,7 @@ def redirect(connection, request): async with serve(*args, process_request=redirect) as server: with socket.create_connection(get_host_port(server)) as sock: with self.assertRaises(ValueError) as raised: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. async with connect("ws://invalid/redirect", sock=sock): self.fail("did not raise") @@ -446,9 +446,11 @@ async def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" async def junk(reader, writer): - await asyncio.sleep(MS) # wait for the client to send the handshake request + # Wait for the client to send the handshake request. + await asyncio.sleep(MS) writer.write(b"220 smtp.invalid ESMTP Postfix\r\n") - await reader.read(4096) # wait for the client to close the connection + # Wait for the client to close the connection. + await reader.read(4096) writer.close() server = await asyncio.start_server(junk, "localhost", 0) @@ -652,7 +654,7 @@ async def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 00dcb301..fe225067 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -568,7 +568,7 @@ async def test_connection(self): async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") - await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + await self.assertEval(client, f"{SSL_OBJECT}.version()[:3]", "TLS") async def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" @@ -604,7 +604,7 @@ async def test_connection(self): async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") - await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + await self.assertEval(client, f"{SSL_OBJECT}.version()[:3]", "TLS") class ServerUsageErrorsTests(unittest.IsolatedAsyncioTestCase): diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 41534391..cc5949c9 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -44,7 +44,7 @@ def test_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") @@ -225,9 +225,11 @@ def test_junk_handshake(self): class JunkHandler(socketserver.BaseRequestHandler): def handle(self): - time.sleep(MS) # wait for the client to send the handshake request + # Wait for the client to send the handshake request. + time.sleep(MS) self.request.send(b"220 smtp.invalid ESMTP Postfix\r\n") - self.request.recv(4096) # wait for the client to close the connection + # Wait for the client to close the connection. + self.request.recv(4096) self.request.close() server = socketserver.TCPServer(("localhost", 0), JunkHandler) @@ -401,7 +403,7 @@ def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @@ -648,7 +650,7 @@ def test_proxy_ssl_without_https_proxy(self): connect( "ws://localhost/", proxy="http://localhost:8080", - proxy_ssl=True, + proxy_ssl=CLIENT_CONTEXT, ) self.assertEqual( str(raised.exception), From 92d59f900c2d62ba85648c2b436330ddaeead9e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Nov 2025 15:52:08 +0100 Subject: [PATCH 3/3] Add trio client and server. --- src/websockets/trio/client.py | 731 +++++++++++++++++++++++++++ src/websockets/trio/server.py | 649 ++++++++++++++++++++++++ tests/test_localhost.cnf | 3 +- tests/test_localhost.pem | 89 ++-- tests/trio/server.py | 63 +++ tests/trio/test_client.py | 927 ++++++++++++++++++++++++++++++++++ tests/trio/test_server.py | 831 ++++++++++++++++++++++++++++++ 7 files changed, 3248 insertions(+), 45 deletions(-) create mode 100644 src/websockets/trio/client.py create mode 100644 src/websockets/trio/server.py create mode 100644 tests/trio/server.py create mode 100644 tests/trio/test_client.py create mode 100644 tests/trio/test_server.py diff --git a/src/websockets/trio/client.py b/src/websockets/trio/client.py new file mode 100644 index 00000000..4af78af3 --- /dev/null +++ b/src/websockets/trio/client.py @@ -0,0 +1,731 @@ +from __future__ import annotations + +import logging +import os +import ssl as ssl_module +import sys +import traceback +import urllib.parse +from collections.abc import AsyncIterator, Generator, Sequence +from types import TracebackType +from typing import Any, Callable, Literal + +import trio + +from ..asyncio.client import process_exception +from ..client import ClientProtocol, backoff +from ..datastructures import HeadersLike +from ..exceptions import ( + InvalidProxyMessage, + InvalidProxyStatus, + InvalidStatus, + ProxyError, + SecurityError, +) +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http11 import USER_AGENT, Response +from ..protocol import CONNECTING, Event +from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request +from ..streams import StreamReader +from ..typing import LoggerLike, Origin, Subprotocol +from ..uri import WebSocketURI, parse_uri +from .connection import Connection +from .utils import race_events + + +if sys.version_info[:2] < (3, 11): # pragma: no cover + from exceptiongroup import BaseExceptionGroup + + +__all__ = ["connect", "ClientConnection"] + +MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) + + +class ClientConnection(Connection): + """ + :mod:`trio` implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines + for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`connect`. + + Args: + nursery: Trio nursery. + stream: Trio stream connected to a WebSocket server. + protocol: Sans-I/O connection. + + """ + + def __init__( + self, + nursery: trio.Nursery, + stream: trio.abc.Stream, + protocol: ClientProtocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.protocol: ClientProtocol + super().__init__( + nursery, + stream, + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + self.response_rcvd = trio.Event() + + async def handshake( + self, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + async with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers.setdefault("User-Agent", user_agent_header) + self.protocol.send_request(self.request) + + await race_events(self.response_rcvd, self.stream_closed) + + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + +# This is spelled in lower case because it's exposed as a callable in the API. +class connect: + """ + Connect to the WebSocket server at ``uri``. + + This coroutine returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as an asynchronous context manager:: + + from websockets.trio.client import connect + + async with connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + :func:`connect` can be used as an infinite asynchronous iterator to + reconnect automatically on errors:: + + async for websocket in connect(...): + try: + ... + except websockets.exceptions.ConnectionClosed: + continue + + If the connection fails with a transient error, it is retried with + exponential backoff. If it fails with a fatal error, the exception is + raised, breaking out of the loop. + + The connection is closed automatically after each iteration of the loop. + + Args: + uri: URI of the WebSocket server. + stream: Preexisting TCP stream. ``stream`` overrides the host and port + from ``uri``. You may call :func:`~trio.open_tcp_stream` to create a + suitable TCP stream. + ssl: Configuration for enabling TLS on the connection. + server_hostname: Host name for the TLS handshake. ``server_hostname`` + overrides the host name from ``uri``. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + additional_headers: Arbitrary HTTP headers to add to the handshake + request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + proxy: If a proxy is configured, it is used by default. Set ``proxy`` + to :obj:`None` to disable the proxy or to the address of a proxy + to override the system configuration. See the :doc:`proxy docs + <../../topics/proxies>` for details. + proxy_ssl: Configuration for enabling TLS on the proxy connection. + proxy_server_hostname: Host name for the TLS handshake with the proxy. + ``proxy_server_hostname`` overrides the host name from ``proxy``. + process_exception: When reconnecting automatically, tell whether an + error is transient or fatal. The default behavior is defined by + :func:`process_exception`. Refer to its documentation for details. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to :func:`~trio.open_tcp_stream`. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidProxy: If ``proxy`` isn't a valid proxy. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + # Arguments of type SSLContext don't render correctly in the documentation + # because of https://github.com/sphinx-doc/sphinx/issues/13838. + + def __init__( + self, + uri: str, + *, + # TCP/TLS + stream: trio.abc.Stream | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, + # WebSocket + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + compression: str | None = "deflate", + # HTTP + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + proxy: str | Literal[True] | None = True, + proxy_ssl: ssl_module.SSLContext | None = None, + proxy_server_hostname: str | None = None, + process_exception: Callable[[Exception], Exception | None] = process_exception, + # Timeouts + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + # Limits + max_size: int | None | tuple[int | None, int | None] = 2**20, + max_queue: int | None | tuple[int | None, int | None] = 16, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to trio.open_tcp_stream + **kwargs: Any, + ) -> None: + self.uri = uri + self.ws_uri = parse_uri(uri) + if not self.ws_uri.secure and ssl is not None: + raise ValueError("ssl argument is incompatible with a ws:// URI") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if logger is None: + logger = logging.getLogger("websockets.client") + + if create_connection is None: + create_connection = ClientConnection + + self.stream = stream + self.ssl = ssl + self.server_hostname = server_hostname + self.proxy = proxy + self.proxy_ssl = proxy_ssl + self.proxy_server_hostname = proxy_server_hostname + self.additional_headers = additional_headers + self.user_agent_header = user_agent_header + self.process_exception = process_exception + self.open_timeout = open_timeout + self.logger = logger + self.create_connection = create_connection + self.open_tcp_stream_kwargs = kwargs + self.protocol_kwargs = dict( + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + max_size=max_size, + logger=logger, + ) + self.connection_kwargs = dict( + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + + async def open_tcp_stream(self) -> trio.abc.Stream: + """Open a TCP connection to the server, possibly through a proxy.""" + # TCP connection is already established. + if self.stream is not None: + return self.stream + + if self.proxy is True: + proxy = get_proxy(self.ws_uri) + else: + proxy = self.proxy + + # Connect to the server through a proxy. + if proxy is not None: + proxy_parsed = parse_proxy(proxy) + + if proxy_parsed.scheme[:5] == "socks": + return await connect_socks_proxy( + proxy_parsed, + self.ws_uri, + local_address=self.open_tcp_stream_kwargs.get("local_address"), + ) + + elif proxy_parsed.scheme[:4] == "http": + if proxy_parsed.scheme != "https" and self.proxy_ssl is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + return await connect_http_proxy( + proxy_parsed, + self.ws_uri, + user_agent_header=self.user_agent_header, + ssl=self.proxy_ssl, + server_hostname=self.proxy_server_hostname, + local_address=self.open_tcp_stream_kwargs.get("local_address"), + ) + + else: + raise NotImplementedError(f"unsupported proxy: {self.proxy}") + + # Connect to the server directly. + kwargs = self.open_tcp_stream_kwargs.copy() + kwargs.setdefault("host", self.ws_uri.host) + kwargs.setdefault("port", self.ws_uri.port) + return await trio.open_tcp_stream(**kwargs) + + async def enable_tls(self, stream: trio.abc.Stream) -> trio.abc.Stream: + """Enable TLS on the connection.""" + if self.ssl is None: + ssl = ssl_module.create_default_context() + else: + ssl = self.ssl + if self.server_hostname is None: + server_hostname = self.ws_uri.host + else: + server_hostname = self.server_hostname + ssl_stream = trio.SSLStream( + stream, + ssl, + server_hostname=server_hostname, + https_compatible=True, + ) + await ssl_stream.do_handshake() + return ssl_stream + + async def open_connection(self, nursery: trio.Nursery) -> ClientConnection: + """Create a WebSocket connection.""" + stream: trio.abc.Stream + stream = await self.open_tcp_stream() + + try: + if self.ws_uri.secure: + stream = await self.enable_tls(stream) + + protocol = ClientProtocol( + self.ws_uri, + **self.protocol_kwargs, # type: ignore + ) + + connection = self.create_connection( # default is ClientConnection + nursery, + stream, + protocol, + **self.connection_kwargs, # type: ignore + ) + + await connection.handshake( + self.additional_headers, + self.user_agent_header, + ) + + return connection + + except trio.Cancelled: + await trio.aclose_forcefully(stream) + # The nursery running this coroutine was canceled. + # The next checkpoint raises trio.Cancelled. + # aclose_forcefully() never returns. + raise AssertionError("nursery should be canceled") + except Exception: + # Always close the connection even though keep-alive is the default + # in HTTP/1.1 because the current implementation ties opening the + # TCP/TLS connection with initializing the WebSocket protocol. + await trio.aclose_forcefully(stream) + raise + + def process_redirect(self, exc: Exception) -> Exception | tuple[str, WebSocketURI]: + """ + Determine whether a connection error is a redirect that can be followed. + + Return the new URI if it's a valid redirect. Else, return an exception. + + """ + if not ( + isinstance(exc, InvalidStatus) + and exc.response.status_code + in [ + 300, # Multiple Choices + 301, # Moved Permanently + 302, # Found + 303, # See Other + 307, # Temporary Redirect + 308, # Permanent Redirect + ] + and "Location" in exc.response.headers + ): + return exc + + old_ws_uri = self.ws_uri + new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) + new_ws_uri = parse_uri(new_uri) + + # If connect() received a stream, it is closed and cannot be reused. + if self.stream is not None: + return ValueError( + f"cannot follow redirect to {new_uri} with a preexisting stream" + ) + + # TLS downgrade is forbidden. + if old_ws_uri.secure and not new_ws_uri.secure: + return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}") + + # Apply restrictions to cross-origin redirects. + if ( + old_ws_uri.secure != new_ws_uri.secure + or old_ws_uri.host != new_ws_uri.host + or old_ws_uri.port != new_ws_uri.port + ): + # Cross-origin redirects when host and port are overridden are ill-defined. + if ( + self.open_tcp_stream_kwargs.get("host") is not None + or self.open_tcp_stream_kwargs.get("port") is not None + ): + return ValueError( + f"cannot follow cross-origin redirect to {new_uri} " + f"with an explicit host or port" + ) + + return new_uri, new_ws_uri + + async def connect(self, nursery: trio.Nursery) -> ClientConnection: + try: + with ( + trio.CancelScope() + if self.open_timeout is None + else trio.fail_after(self.open_timeout) + ): + for _ in range(MAX_REDIRECTS): + try: + connection = await self.open_connection(nursery) + except Exception as exc: + exc_or_uri = self.process_redirect(exc) + # Response isn't a valid redirect; raise the exception. + if isinstance(exc_or_uri, Exception): + if exc_or_uri is exc: + raise + else: + raise exc_or_uri from exc + # Response is a valid redirect; follow it. + else: + self.uri, self.ws_uri = exc_or_uri + continue + + else: + connection.start_keepalive() + return connection + else: + raise SecurityError(f"more than {MAX_REDIRECTS} redirects") + + except trio.TooSlowError as exc: + # Re-raise exception with an informative error message. + raise TimeoutError("timed out during opening handshake") from exc + + # Do not define __await__ for... = await nursery.start(connect, ...) + # because it doesn't look idiomatic in Trio. + + # async with connect(...) as ...: ... + + async def __aenter__(self) -> ClientConnection: + await self.__aenter_nursery__() + try: + self.connection = await self.connect(self.nursery) + return self.connection + except BaseException as exc: + await self.__aexit_nursery__(type(exc), exc, exc.__traceback__) + raise AssertionError("expected __aexit_nursery__ to re-raise the exception") + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + try: + await self.connection.aclose() + del self.connection + finally: + await self.__aexit_nursery__(exc_type, exc_value, traceback) + + async def __aenter_nursery__(self) -> None: + if hasattr(self, "nursery_manager"): # pragma: no cover + raise RuntimeError("connect() isn't reentrant") + self.nursery_manager = trio.open_nursery() + self.nursery = await self.nursery_manager.__aenter__() + + async def __aexit_nursery__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + # We need a nursery to start the recv_events and keepalive coroutines. + # They aren't expected to raise exceptions; instead they catch and log + # all unexpected errors. To keep the nursery an implementation detail, + # unwrap exceptions raised by user code -- per the second option here: + # https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors + try: + await self.nursery_manager.__aexit__(exc_type, exc_value, traceback) + except BaseException as exc: + assert isinstance(exc, BaseExceptionGroup) + try: + trio._util.raise_single_exception_from_group(exc) + except trio._util.MultipleExceptionError: # pragma: no cover + raise AssertionError( + "unexpected multiple exceptions; please file a bug report" + ) from exc + finally: + del self.nursery_manager + + # async for ... in connect(...): + + async def __aiter__(self) -> AsyncIterator[ClientConnection]: + delays: Generator[float] | None = None + while True: + try: + async with self as connection: + yield connection + except Exception as exc: + # Determine whether the exception is retryable or fatal. + # The API of process_exception is "return an exception or None"; + # "raise an exception" is also supported because it's a frequent + # mistake. It isn't documented in order to keep the API simple. + try: + new_exc = self.process_exception(exc) + except Exception as raised_exc: + new_exc = raised_exc + + # The connection failed with a fatal error. + # Raise the exception and exit the loop. + if new_exc is exc: + raise + if new_exc is not None: + raise new_exc from exc + + # The connection failed with a retryable error. + # Start or continue backoff and reconnect. + if delays is None: + delays = backoff() + delay = next(delays) + self.logger.info( + "connect failed; reconnecting in %.1f seconds: %s", + delay, + traceback.format_exception_only(exc)[0].strip(), + ) + await trio.sleep(delay) + + else: + # The connection succeeded. Reset backoff. + delays = None + + +try: + from python_socks import ProxyType + from python_socks.async_.trio import Proxy as SocksProxy + +except ImportError: + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> trio.abc.Stream: + raise ImportError("connecting through a SOCKS proxy requires python-socks") + +else: + SOCKS_PROXY_TYPES = { + "socks5h": ProxyType.SOCKS5, + "socks5": ProxyType.SOCKS5, + "socks4a": ProxyType.SOCKS4, + "socks4": ProxyType.SOCKS4, + } + + SOCKS_PROXY_RDNS = { + "socks5h": True, + "socks5": False, + "socks4a": True, + "socks4": False, + } + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> trio.abc.Stream: + """Connect via a SOCKS proxy and return the socket.""" + socks_proxy = SocksProxy( + SOCKS_PROXY_TYPES[proxy.scheme], + proxy.host, + proxy.port, + proxy.username, + proxy.password, + SOCKS_PROXY_RDNS[proxy.scheme], + ) + # connect() is documented to raise OSError. + # socks_proxy.connect() re-raises trio.TooSlowError as ProxyTimeoutError. + # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake. + try: + return trio.SocketStream( + await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + ) + except OSError: + raise + except Exception as exc: + raise ProxyError("failed to connect to SOCKS proxy") from exc + + +async def read_connect_response(stream: trio.abc.Stream) -> Response: + reader = StreamReader() + parser = Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + proxy=True, + ) + try: + while True: + data = await stream.receive_some(4096) + if data: + reader.feed_data(data) + else: + reader.feed_eof() + next(parser) + except StopIteration as exc: + assert isinstance(exc.value, Response) # help mypy + response = exc.value + if 200 <= response.status_code < 300: + return response + else: + raise InvalidProxyStatus(response) + except Exception as exc: + raise InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) from exc + + +async def connect_http_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + *, + user_agent_header: str | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, + **kwargs: Any, +) -> trio.abc.Stream: + stream: trio.abc.Stream + stream = await trio.open_tcp_stream(proxy.host, proxy.port, **kwargs) + + try: + # Initialize TLS wrapper and perform TLS handshake + if proxy.scheme == "https": + if ssl is None: + ssl = ssl_module.create_default_context() + if server_hostname is None: + server_hostname = proxy.host + ssl_stream = trio.SSLStream( + stream, + ssl, + server_hostname=server_hostname, + https_compatible=True, + ) + await ssl_stream.do_handshake() + stream = ssl_stream + + # Send CONNECT request to the proxy and read response. + request = prepare_connect_request(proxy, ws_uri, user_agent_header) + await stream.send_all(request) + await read_connect_response(stream) + + except (trio.Cancelled, Exception): + await trio.aclose_forcefully(stream) + raise + + return stream diff --git a/src/websockets/trio/server.py b/src/websockets/trio/server.py new file mode 100644 index 00000000..b2cec4b8 --- /dev/null +++ b/src/websockets/trio/server.py @@ -0,0 +1,649 @@ +from __future__ import annotations + +import functools +import http +import logging +import re +import ssl as ssl_module +from collections.abc import Awaitable, Sequence +from types import TracebackType +from typing import Any, Callable, Mapping + +import trio +import trio.abc + +from ..asyncio.server import basic_auth +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..frames import CloseCode +from ..headers import validate_subprotocols +from ..http11 import SERVER, Request, Response +from ..protocol import CONNECTING, OPEN, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, StatusLike, Subprotocol +from .connection import Connection +from .utils import race_events + + +__all__ = [ + "serve", + "ServerConnection", + "Server", + "basic_auth", +] + + +class ServerConnection(Connection): + """ + :mod:`trio` implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`serve`. + + Args: + nursery: Trio nursery. + stream: Trio stream connected to a WebSocket client. + protocol: Sans-I/O connection. + server: Server that manages this connection. + + """ + + def __init__( + self, + nursery: trio.Nursery, + stream: trio.abc.Stream, + protocol: ServerProtocol, + server: Server, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.protocol: ServerProtocol + super().__init__( + nursery, + stream, + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + self.server = server + self.request_rcvd: trio.Event = trio.Event() + self.username: str # see basic_auth() + self.handler: Callable[[ServerConnection], Awaitable[None]] # see route() + self.handler_kwargs: Mapping[str, Any] # see route() + + def respond(self, status: StatusLike, text: str) -> Response: + """ + Create a plain text HTTP response. + + ``process_request`` and ``process_response`` may call this method to + return an HTTP response instead of performing the WebSocket opening + handshake. + + You can modify the response before returning it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + return self.protocol.reject(status, text) + + async def handshake( + self, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = SERVER, + ) -> None: + """ + Perform the opening handshake. + + """ + await race_events(self.request_rcvd, self.stream_closed) + + if self.request is not None: + async with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + if not self.server.closing: + self.response = self.protocol.accept(self.request) + else: + self.response = self.protocol.reject( + http.HTTPStatus.SERVICE_UNAVAILABLE, + "Server is shutting down.\n", + ) + else: + assert isinstance(response, Response) # help mypy + self.response = response + + if server_header is not None: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + assert isinstance(response, Response) # help mypy + self.response = response + + self.protocol.send_response(self.response) + + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + +class Server(trio.abc.AsyncResource): + """ + WebSocket server returned by :func:`serve`. + + Args: + open_listeners: Factory for Trio listeners accepting new connections. + stream_handler: Handler for one connection. It receives a Trio stream. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + open_listeners: Callable[[], Awaitable[list[trio.SocketListener]]], + stream_handler: Callable[[trio.abc.Stream, Server], Awaitable[None]], + logger: LoggerLike | None = None, + ) -> None: + self.open_listeners = open_listeners + self.stream_handler = stream_handler + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + + self.listeners: list[trio.SocketListener] = [] + """Trio listeners.""" + + self.closing = False + self.closed_waiters: dict[ServerConnection, trio.Event] = {} + + @property + def connections(self) -> set[ServerConnection]: + """ + Set of active connections. + + This property contains all connections that completed the opening + handshake successfully and didn't start the closing handshake yet. + + .. It can be useful in combination with :func:`~broadcast`. + + """ + return { + connection + for connection in self.closed_waiters + if connection.protocol.state is OPEN + } + + async def serve_forever( + self, + task_status: trio.TaskStatus[Server] = trio.TASK_STATUS_IGNORED, + ) -> None: + self.listeners = await self.open_listeners() # used in tests + # Running handlers in a dedicated nursery makes it possible to close + # listeners while handlers finish running. The nursery for listeners + # is created in trio.serve_listeners(). + async with trio.open_nursery() as self.handler_nursery: + # Wrap trio.serve_listeners() in another nursery to return the + # Server object in task_status instead of a list of listeners. + async with trio.open_nursery() as self.serve_nursery: + await self.serve_nursery.start( + functools.partial( + trio.serve_listeners, + functools.partial(self.stream_handler, server=self), # type: ignore + self.listeners, + handler_nursery=self.handler_nursery, + ) + ) + task_status.started(self) + + # Shutting down the server cleanly when serve_forever() is canceled would be + # the most idiomatic in Trio. However, that would require shielding too many + # asynchronous operations, including the TLS & WebSocket opening handshakes. + + async def aclose( + self, + close_connections: bool = True, + code: CloseCode | int = CloseCode.GOING_AWAY, + reason: str = "", + ) -> None: + """ + Close the server. + + * Close the TCP listeners. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: + + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + ``code`` and ``reason`` can be customized, for example to use code + 1012 (service restart). + + * Wait until all connection handlers have returned. + + :meth:`aclose` is idempotent. + + """ + self.logger.info("server closing") + + # Stop accepting new connections. + self.serve_nursery.cancel_scope.cancel() + + # Reject OPENING connections with HTTP 503 -- see handshake(). + self.closing = True + + # Close OPEN connections. + if close_connections: + for connection in self.closed_waiters: + if connection.protocol.state is not OPEN: # pragma: no cover + continue + self.handler_nursery.start_soon(connection.aclose, code, reason) + + # Wait until all connection handlers have returned. + while self.closed_waiters: + await next(iter(self.closed_waiters.values())).wait() + + self.logger.info("server closed") + + async def __aenter__(self) -> Server: # pragma: no cover + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: # pragma: no cover + await self.aclose() + + +async def serve( + handler: Callable[[ServerConnection], Awaitable[None]], + port: int | None = None, + *, + # TCP/TLS + host: str | bytes | None = None, + backlog: int | None = None, + listeners: list[trio.SocketListener] | None = None, + ssl: ssl_module.SSLContext | None = None, + # WebSocket + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( + Callable[ + [ServerConnection, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None, + compression: str | None = "deflate", + # HTTP + process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Response | None, + ] + | None + ) = None, + server_header: str | None = SERVER, + # Timeouts + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + # Limits + max_size: int | None | tuple[int | None, int | None] = 2**20, + max_queue: int | None | tuple[int | None, int | None] = 16, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ServerConnection] | None = None, + # Trio + task_status: trio.TaskStatus[Server] = trio.TASK_STATUS_IGNORED, +) -> None: + """ + Create a WebSocket server listening on ``port``. + + Whenever a client connects, the server creates a :class:`ServerConnection`, + performs the opening handshake, and delegates to the ``handler`` coroutine. + + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. + + Once the handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + When using :func:`serve` with :meth:`nursery.start `, + you get back a :class:`Server` object. Call its :meth:`~Server.aclose` + method to stop the server gracefully:: + + from websockets.trio.server import serve + + async def handler(websocket): + ... + + # set this event to exit the server + stop = trio.Event() + + with trio.open_nursery() as nursery: + server = await nursery.start(serve, handler, port) + try: + await stop.wait() + finally: + await server.aclose() + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + port: TCP port the server listens on. + See :func:`~trio.open_tcp_listeners` for details. + host: Network interfaces the server binds to. + See :func:`~trio.open_tcp_listeners` for details. + backlog: Listen backlog. See :func:`~trio.open_tcp_listeners` for + details. + listeners: Preexisting TCP listeners. ``listeners`` replaces ``port``, + ``host``, and ``backlog``. See :func:`trio.serve_listeners` for + details. + ssl: Configuration for enabling TLS on the connection. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It receives a + :class:`ServerConnection` (not a + :class:`~websockets.server.ServerProtocol`!) instance and a list of + subprotocols offered by the client. Other than the first argument, + it has the same behavior as the + :meth:`ServerProtocol.select_subprotocol + ` method. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response. Return :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + process_response: Intercept the response during the opening handshake. + Modify the response or return a new HTTP response to force the + response. Return :obj:`None` to continue normally. When you force an + HTTP 101 Continue response, the handshake is successful. Else, the + connection is aborted. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. See the + :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ServerConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + task_status: For compatibility with :meth:`nursery.start + `. + + """ + + # Process parameters + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + # Create listeners + + if listeners is None: + if port is None: + raise ValueError("port is required when listeners is not provided") + + async def open_listeners() -> list[trio.SocketListener]: + return await trio.open_tcp_listeners(port, host=host, backlog=backlog) + else: + if port is not None: + raise ValueError("port is incompatible with listeners") + if host is not None: + raise ValueError("host is incompatible with listeners") + if backlog is not None: + raise ValueError("backlog is incompatible with listeners") + + async def open_listeners() -> list[trio.SocketListener]: + return listeners + + async def stream_handler(stream: trio.abc.Stream, server: Server) -> None: + async with trio.open_nursery() as nursery: + try: + # Apply open_timeout to the TLS and WebSocket handshake. + with ( + trio.CancelScope() + if open_timeout is None + else trio.move_on_after(open_timeout) + ): + # Enable TLS. + if ssl is not None: + # Wrap with SSLStream here rather than with TLSListener + # in order to include the TLS handshake within open_timeout. + stream = trio.SSLStream( + stream, + ssl, + server_side=True, + https_compatible=True, + ) + assert isinstance(stream, trio.SSLStream) # help mypy + try: + await stream.do_handshake() + except trio.BrokenResourceError: + return + + # Create a closure to give select_subprotocol access to connection. + protocol_select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # Initialize WebSocket protocol. + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + max_size=max_size, + logger=logger, + ) + + # Initialize WebSocket connection. + assert create_connection is not None # help mypy + connection = create_connection( + nursery, + stream, + protocol, + server, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + + try: + await connection.handshake( + process_request, + process_response, + server_header, + ) + except trio.Cancelled: + # The nursery running this coroutine was canceled. + # The next checkpoint raises trio.Cancelled. + # aclose_forcefully() never returns. + await trio.aclose_forcefully(stream) + raise AssertionError("nursery should be canceled") + except Exception: + connection.logger.error( + "opening handshake failed", exc_info=True + ) + await trio.aclose_forcefully(stream) + return + + if connection.protocol.state is not OPEN: + # process_request or process_response rejected the handshake. + await connection.close_stream() + return + + try: + server.closed_waiters[connection] = trio.Event() + connection.start_keepalive() + await handler(connection) + except Exception: + connection.logger.error("connection handler failed", exc_info=True) + await connection.aclose(CloseCode.INTERNAL_ERROR) + else: + await connection.aclose() + finally: + server.closed_waiters.pop(connection).set() + + except Exception: # pragma: no cover + # Don't leak connections on unexpected errors. + await trio.aclose_forcefully(stream) + + server = Server(open_listeners, stream_handler, logger) + await server.serve_forever(task_status=task_status) diff --git a/tests/test_localhost.cnf b/tests/test_localhost.cnf index 4069e396..15d49228 100644 --- a/tests/test_localhost.cnf +++ b/tests/test_localhost.cnf @@ -24,4 +24,5 @@ subjectAltName = @san DNS.1 = localhost DNS.2 = overridden IP.3 = 127.0.0.1 -IP.4 = ::1 +IP.4 = 0.0.0.0 +IP.5 = ::1 diff --git a/tests/test_localhost.pem b/tests/test_localhost.pem index 8df63ec8..1f26df71 100644 --- a/tests/test_localhost.pem +++ b/tests/test_localhost.pem @@ -1,48 +1,49 @@ -----BEGIN PRIVATE KEY----- -MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDYOOQyq8yYtn5x -K3yRborFxTFse16JIVb4x/ZhZgGm49eARCi09fmczQxJdQpHz81Ij6z0xi7AUYH7 -9wS8T0Lh3uGFDDS1GzITUVPIqSUi0xim2T6XPzXFVQYI1D/OjUxlHm+3/up+WwbL -sBgBO/lDmzoa3ZN7kt9HQoGc/14oQz1Qsv1QTDQs69r+o7mmBJr/hf/g7S0Csyy3 -iC6aaq+yCUyzDbjXceTI7WJqbTGNnK0/DjdFD/SJS/uSDNEg0AH53eqcCSjm+Ei/ -UF8qR5Pu4sSsNwToOW2MVgjtHFazc+kG3rzD6+3Dp+t6x6uI/npyuudOMCmOtd6z -kX0UPQaNAgMBAAECggEAS4eMBztGC+5rusKTEAZKSY15l0h9HG/d/qdzJFDKsO6T -/8VPZu8pk6F48kwFHFK1hexSYWq9OAcA3fBK4jDZzybZJm2+F6l5U5AsMUMMqt6M -lPP8Tj8RXG433muuIkvvbL82DVLpvNu1Qv+vUvcNOpWFtY7DDv6eKjlMJ3h4/pzh -89MNt26VMCYOlq1NSjuZBzFohL2u9nsFehlOpcVsqNfNfcYCq9+5yoH8fWJP90Op -hqhvqUoGLN7DRKV1f+AWHSA4nmGgvVviV5PQgMhtk5exlN7kG+rDc3LbzhefS1Sp -Tat1qIgm8fK2n+Q/obQPjHOGOGuvE5cIF7E275ZKgQKBgQDt87BqALKWnbkbQnb7 -GS1h6LRcKyZhFbxnO2qbviBWSo15LEF8jPGV33Dj+T56hqufa/rUkbZiUbIR9yOX -dnOwpAVTo+ObAwZfGfHvrnufiIbHFqJBumaYLqjRZ7AC0QtS3G+kjS9dbllrr7ok -fO4JdfKRXzBJKrkQdCn8hR22rQKBgQDon0b49Dxs1EfdSDbDode2TSwE83fI3vmR -SKUkNY8ma6CRbomVRWijhBM458wJeuhpjPZOvjNMsnDzGwrtdAp2VfFlMIDnA8ZC -fEWIAAH2QYKXKGmkoXOcWB2QbvbI154zCm6zFGtzvRKOCGmTXuhFajO8VPwOyJVt -aSJA3bLrYQKBgQDJM2/tAfAAKRdW9GlUwqI8Ep9G+/l0yANJqtTnIemH7XwYhJJO -9YJlPszfB2aMBgliQNSUHy1/jyKpzDYdITyLlPUoFwEilnkxuud2yiuf5rpH51yF -hU6wyWtXvXv3tbkEdH42PmdZcjBMPQeBSN2hxEi6ISncBDL9tau26PwJ9QKBgQCs -cNYl2reoXTzgtpWSNDk6NL769JjJWTFcF6QD0YhKjOI8rNpkw00sWc3+EybXqDr9 -c7dq6+gPZQAB1vwkxi6zRkZqIqiLl+qygnjwtkC+EhYCg7y8g8q2DUPtO7TJcb0e -TQ9+xRZad8B3dZj93A8G1hF//OfU9bB/qL3xo+bsQQKBgC/9YJvgLIWA/UziLcB2 -29Ai0nbPkN5df7z4PifUHHSlbQJHKak8UKbMP+8S064Ul0F7g8UCjZMk2LzSbaNY -XU5+2j0sIOnGUFoSlvcpdowzYrD2LN5PkKBot7AOq/v7HlcOoR8J8RGWAMpCrHsI -a/u/dlZs+/K16RcavQwx8rag +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDKiNs9JHIq5I2c +GjupVn8QJ3oi+lSpEwdUu6aw/q1H9mVzv1dFtp7hT8kuhclNf1tlBBFiB+NWbRZc +uyBRq+mIIWfepcHRHpquxyopesD+CdeC0rogq3vry94FJNmN8257WZiraNl3v9ht +eBqTy0xYDsDtl8iYLfT4xPDfJVOMq0R6SQEljWi6jSbR3b74wiLpXoWjvx7KJahH +hd/p48meuq95tGfxDEb7r/h02RpZF5rq2zRqBOcO4nL5drWYBh1I4+RFp+AbCixX +MqWh1e0vl/wXiKwYTPIgqH2DIXxS3m8dn4O74zO0ktRqPkIXMyKAZQkdUNLngE7v +pNeDcQatAgMBAAECggEACRc/WtZvBt7YYu9IgP0btWBF9hoa0yOwA8P97FpQ8YkI +rpa0bVZrnjz2fkZNdwodLd43YBlKZe1ZbhxD1S1+uuYEY3TvpvWC7A78pPz86IEN +TPu/Jt1AMeo4d5vtLoS7fSYLBwl2H7OI03Y0ROeS8FJXfrKixdp2OmLmVcOAXDDj +Eq0Xs2tSXXPVZ8KKGMidKqvfxcVAhOZvJfHvkMJ+tS/FRAn7Qxc1tn7OTUOg+glr +sHdMwImfzDCbyhP5gZXL/MP35UqnKUBAGdJmfp3BkFxk0yGLhlCOefs1/a9PhVOt +Q83+kjWnuYeP3R4jB7fuWtEu0/gPZT/P1iJF4MIhjwKBgQDqPtT+7G7KMThGjdm6 +bu77VDsW10T5uDU55G3LvXHoFTZUnleSOtWrh2mdR3KVj5PdHDR4VSuA0d65S39n +LYVul82FMgjCWKL4odgssPcLD6SsybdF9xXSXJKtQ96eJjW0o7vMu0/CHrhF6whA +EmCeDcD81Bzvj8DbkSyHpIaolwKBgQDdWBn43eVBt8FStAXx3J49pMyw83AXyqNA +3taHTGjG9BnjgsRgQeYmZG82xpD/Yu6dYyzF+rI4iODkSzF1FN+j64ElDRJbAMvS +yThbAKAb+xegh0EQm43+kYG1sDavWT4pvzh6DCltN82eHwJ5utDuneiAB66DeAqY +ttXmw+fPWwKBgHYEoBWsE4mlUMAjWc5Xc+qGnpq8bNEQISkA0Ny0nv4aKdxqRp6z +K9IXEHwgcjeuNgZR3pG9/4QQuRFMW20lfzOgIfj4o3cfZ0SzbhHeOymEgShZHRCQ +E5t/7pqDNlch0y8my0i0GtQn3BnF98soNyuKrG/1gnqkR7uYIgJZP0sTAoGAGHLt +0353H04zzXXTHkcXN4nnjjgljos0gyraGXHINQmrfmToWhWNXXpEipFeXMdJwhq9 +TFUHsJT1+mGP4fXfShTuW/BYsbKh0POnBO5JwS14C6RE/JeiFJdv82i2caHy6tuT +Wm/Td5vtW2Tjehy3jVPl5ZZzoVP2H646bFYBWfcCgYEAkWJLFzvXsF9SW9Ku6cc0 +7Yhuoolad/AWCXe5Q3+k+icgOQFnMsOkuEPIlRHPgjaOnXMq76VyO4a66vK+ucgr +R3O8/h5QZiuxE3dfqXsDrGr/6W2kmDWWXXK9r5oJQ1J4ndj65ZaGcAuw/77hf5K8 +PnN3beykcf5xxuaPNpq0cbg= -----END PRIVATE KEY----- -----BEGIN CERTIFICATE----- -MIIDWTCCAkGgAwIBAgIJAOL9UKiOOxupMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV -BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp -bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTIyMTAxNTE5Mjg0MVoYDzIwNjQxMDE0 -MTkyODQxWjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM -EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI -hvcNAQEBBQADggEPADCCAQoCggEBANg45DKrzJi2fnErfJFuisXFMWx7XokhVvjH -9mFmAabj14BEKLT1+ZzNDEl1CkfPzUiPrPTGLsBRgfv3BLxPQuHe4YUMNLUbMhNR -U8ipJSLTGKbZPpc/NcVVBgjUP86NTGUeb7f+6n5bBsuwGAE7+UObOhrdk3uS30dC -gZz/XihDPVCy/VBMNCzr2v6juaYEmv+F/+DtLQKzLLeILppqr7IJTLMNuNdx5Mjt -YmptMY2crT8ON0UP9IlL+5IM0SDQAfnd6pwJKOb4SL9QXypHk+7ixKw3BOg5bYxW -CO0cVrNz6QbevMPr7cOn63rHq4j+enK6504wKY613rORfRQ9Bo0CAwEAAaM8MDow -OAYDVR0RBDEwL4IJbG9jYWxob3N0ggpvdmVycmlkZGVuhwR/AAABhxAAAAAAAAAA -AAAAAAAAAAABMA0GCSqGSIb3DQEBCwUAA4IBAQBPNDGDdl4wsCRlDuyCHBC8o+vW -Vb14thUw9Z6UrlsQRXLONxHOXbNAj1sYQACNwIWuNz36HXu5m8Xw/ID/bOhnIg+b -Y6l/JU/kZQYB7SV1aR3ZdbCK0gjfkE0POBHuKOjUFIOPBCtJ4tIBUX94zlgJrR9v -2rqJC3TIYrR7pVQumHZsI5GZEMpM5NxfreWwxcgltgxmGdm7elcizHfz7k5+szwh -4eZ/rxK9bw1q8BIvVBWelRvUR55mIrCjzfZp5ZObSYQTZlW7PzXBe5Jk+1w31YHM -RSBA2EpPhYlGNqPidi7bg7rnQcsc6+hE0OqzTL/hWxPm9Vbp9dj3HFTik1wa +MIIDiTCCAnGgAwIBAgIURQDnIfsMPAhuq9Uq1dka01Qoc9IwDQYJKoZIhvcNAQEL +BQAwTDELMAkGA1UEBhMCRlIxDjAMBgNVBAcMBVBhcmlzMRkwFwYDVQQKDBBBeW1l +cmljIEF1Z3VzdGluMRIwEAYDVQQDDAlsb2NhbGhvc3QwIBcNMjUwNTMxMjAxMDU1 +WhgPMjA2NzA1MzEyMDEwNTVaMEwxCzAJBgNVBAYTAkZSMQ4wDAYDVQQHDAVQYXJp +czEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3RpbjESMBAGA1UEAwwJbG9jYWxob3N0 +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyojbPSRyKuSNnBo7qVZ/ +ECd6IvpUqRMHVLumsP6tR/Zlc79XRbae4U/JLoXJTX9bZQQRYgfjVm0WXLsgUavp +iCFn3qXB0R6arscqKXrA/gnXgtK6IKt768veBSTZjfNue1mYq2jZd7/YbXgak8tM +WA7A7ZfImC30+MTw3yVTjKtEekkBJY1ouo0m0d2++MIi6V6Fo78eyiWoR4Xf6ePJ +nrqvebRn8QxG+6/4dNkaWRea6ts0agTnDuJy+Xa1mAYdSOPkRafgGwosVzKlodXt +L5f8F4isGEzyIKh9gyF8Ut5vHZ+Du+MztJLUaj5CFzMigGUJHVDS54BO76TXg3EG +rQIDAQABo2EwXzA+BgNVHREENzA1gglsb2NhbGhvc3SCCm92ZXJyaWRkZW6HBH8A +AAGHBAAAAACHEAAAAAAAAAAAAAAAAAAAAAEwHQYDVR0OBBYEFB7eswhXVVmG32UR +MGtc2vewZjM0MA0GCSqGSIb3DQEBCwUAA4IBAQBt9KGnnrtn15H9wz4fWHzPTGaO +laJQE5RnqlzyQ3aDLRtZIc/OA+0L6rW7+xiiN0v1irqCD/M0YGYGomy//3J444bT +SxciJQarZPtNRaLJx17geQOwbY5NpTsfEKmvhwCnMLx9Wy6kyHx0NyD3e1MJwH47 +QdJDmKCVF2R10AKGlnsp6zYaoOvoY48MvCBOnaZEVXPypta0N3XXrASsllw5QJSb +XXPIdNbwA22necSoa7PchMXIbyDXIhygf+tXVBAKvNaSNCzQPehTmepENYJPFEh/ +NJrYPB769uRPgZxIvivo1QjNik4ywcZlvEU6LC6JPUasUcGY6FTnipLL6lD0 -----END CERTIFICATE----- diff --git a/tests/trio/server.py b/tests/trio/server.py new file mode 100644 index 00000000..d2172af2 --- /dev/null +++ b/tests/trio/server.py @@ -0,0 +1,63 @@ +import contextlib +import functools +import socket +import urllib.parse + +import trio + +from websockets.trio.server import * + + +def get_host_port(listeners): + for listener in listeners: + if listener.socket.family == socket.AF_INET: # pragma: no branch + return listener.socket.getsockname() + raise AssertionError("expected at least one IPv4 socket") + + +def get_uri(server, secure=False): + protocol = "wss" if secure else "ws" + host, port = get_host_port(server.listeners) + return f"{protocol}://{host}:{port}" + + +async def handler(ws): + path = urllib.parse.urlparse(ws.request.path).path + if path == "/": + # The default path is an eval shell. + async for expr in ws: + value = eval(expr) + await ws.send(str(value)) + elif path == "/crash": + raise RuntimeError + elif path == "/no-op": + pass + elif path == "/delay": + delay = float(await ws.recv()) + await ws.aclose() + await trio.sleep(delay) + else: + raise AssertionError(f"unexpected path: {path}") + + +kwargs = {"handler": handler, "port": 0, "host": "localhost"} + + +@contextlib.asynccontextmanager +async def run_server(**overrides): + merged_kwargs = {**kwargs, **overrides} + async with trio.open_nursery() as nursery: + server = await nursery.start(functools.partial(serve, **merged_kwargs)) + try: + yield server + finally: + # Run all tasks to guarantee that any exceptions are raised. + # Otherwise, canceling the nursery could hide errors. + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + +class EvalShellMixin: + async def assertEval(self, client, expr, value): + await client.send(expr) + self.assertEqual(await client.recv(), value) diff --git a/tests/trio/test_client.py b/tests/trio/test_client.py new file mode 100644 index 00000000..7448b5fd --- /dev/null +++ b/tests/trio/test_client.py @@ -0,0 +1,927 @@ +import contextlib +import http +import logging +import os +import socket +import ssl +import sys +import unittest +from unittest.mock import patch + +import trio + +from websockets.client import backoff +from websockets.exceptions import ( + InvalidHandshake, + InvalidMessage, + InvalidProxy, + InvalidProxyMessage, + InvalidStatus, + InvalidURI, + ProxyError, + SecurityError, +) +from websockets.extensions.permessage_deflate import PerMessageDeflate +from websockets.trio.client import * + +from ..proxy import ProxyMixin +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT +from .server import get_host_port, get_uri, run_server +from .utils import IsolatedTrioTestCase + + +@contextlib.asynccontextmanager +async def short_backoff_delay(): + defaults = backoff.__defaults__ + backoff.__defaults__ = ( + defaults[0] * MS, + defaults[1] * MS, + defaults[2] * MS, + defaults[3], + ) + try: + yield + finally: + backoff.__defaults__ = defaults + + +@contextlib.asynccontextmanager +async def few_redirects(): + from websockets.trio import client + + max_redirects = client.MAX_REDIRECTS + client.MAX_REDIRECTS = 2 + try: + yield + finally: + client.MAX_REDIRECTS = max_redirects + + +class ClientTests(IsolatedTrioTestCase): + async def test_connection(self): + """Client connects to server.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_explicit_host_port(self): + """Client connects using an explicit host / port.""" + async with run_server() as server: + host, port = get_host_port(server.listeners) + async with connect("ws://overridden/", host=host, port=port) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_existing_stream(self): + """Client connects using a pre-existing stream.""" + async with run_server() as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + # Use a non-existing domain to ensure we connect via stream. + async with connect("ws://invalid/", stream=stream) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with run_server() as server: + async with connect(get_uri(server), compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + async def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + async with run_server() as server: + async with connect( + get_uri(server), additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + async def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with run_server() as server: + async with connect(get_uri(server), user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with run_server() as server: + async with connect(get_uri(server), user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + async def test_legacy_user_agent(self): + """Client can override User-Agent header with additional_headers.""" + async with run_server() as server: + async with connect( + get_uri(server), additional_headers={"User-Agent": "Smith"} + ) as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + async with run_server() as server: + async with connect(get_uri(server), ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + await trio.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with run_server() as server: + async with connect(get_uri(server), ping_interval=None) as client: + await trio.sleep(2 * MS) + self.assertEqual(client.latency, 0) + + async def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server() as server: + async with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + + async def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + async with run_server() as server: + async with connect( + get_uri(server), create_connection=create_connection + ) as client: + self.assertTrue(client.create_connection_ran) + + @short_backoff_delay() + async def test_reconnect(self): + """Client reconnects to server.""" + iterations = 0 + successful = 0 + + async def process_request(connection, request): + nonlocal iterations + iterations += 1 + # Retriable errors + if iterations == 1: + await trio.sleep(3 * MS) + elif iterations == 2: + await connection.stream.aclose() + elif iterations == 3: + return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") + # Fatal error + elif iterations == 6: + return connection.respond(http.HTTPStatus.PAYMENT_REQUIRED, "💸") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async for client in connect(get_uri(server), open_timeout=3 * MS): + self.assertEqual(client.protocol.state.name, "OPEN") + successful += 1 + + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 402", + ) + self.assertEqual(iterations, 6) + self.assertEqual(successful, 2) + + @short_backoff_delay() + async def test_reconnect_with_custom_process_exception(self): + """Client runs process_exception to tell if errors are retryable or fatal.""" + iteration = 0 + + def process_request(connection, request): + nonlocal iteration + iteration += 1 + if iteration == 1: + return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") + return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") + + def process_exception(exc): + if isinstance(exc, InvalidStatus): + if 500 <= exc.response.status_code < 600: + return None + if exc.response.status_code == 418: + return Exception("🫖 💔 ☕️") + self.fail("unexpected exception") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(Exception) as raised: + async for _ in connect( + get_uri(server), process_exception=process_exception + ): + self.fail("did not raise") + + self.assertEqual(iteration, 2) + self.assertEqual( + str(raised.exception), + "🫖 💔 ☕️", + ) + + @short_backoff_delay() + async def test_reconnect_with_custom_process_exception_raising_exception(self): + """Client supports raising an exception in process_exception.""" + + def process_request(connection, request): + return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") + + def process_exception(exc): + if isinstance(exc, InvalidStatus) and exc.response.status_code == 418: + raise Exception("🫖 💔 ☕️") + self.fail("unexpected exception") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(Exception) as raised: + async for _ in connect( + get_uri(server), process_exception=process_exception + ): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "🫖 💔 ☕️", + ) + + async def test_redirect(self): + """Client follows redirect.""" + + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with run_server(process_request=redirect) as server: + async with connect(get_uri(server) + "/redirect") as client: + self.assertEqual(client.protocol.uri.path, "/") + + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server) + return response + + async with run_server(process_request=redirect) as server: + async with run_server() as other_server: + async with connect(get_uri(server)): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) + + @few_redirects() + async def test_redirect_limit(self): + """Client stops following redirects after limit is reached.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = request.path + return response + + async with run_server(process_request=redirect) as server: + with self.assertRaises(SecurityError) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "more than 2 redirects", + ) + + async def test_redirect_with_explicit_host_port(self): + """Client follows redirect with an explicit host / port.""" + + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with run_server(process_request=redirect) as server: + host, port = get_host_port(server.listeners) + async with connect( + "ws://overridden/redirect", host=host, port=port + ) as client: + self.assertEqual(client.protocol.uri.path, "/") + + async def test_cross_origin_redirect_with_explicit_host_port(self): + """Client doesn't follow cross-origin redirect with an explicit host / port.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "ws://other/" + return response + + async with run_server(process_request=redirect) as server: + host, port = get_host_port(server.listeners) + with self.assertRaises(ValueError) as raised: + async with connect("ws://overridden/", host=host, port=port): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow cross-origin redirect to ws://other/ " + "with an explicit host or port", + ) + + async def test_redirect_with_existing_stream(self): + """Client doesn't follow redirect when using a pre-existing stream.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with run_server(process_request=redirect) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + with self.assertRaises(ValueError) as raised: + # Use a non-existing domain to ensure we connect via sock. + async with connect("ws://invalid/redirect", stream=stream): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow redirect to ws://invalid/ with a preexisting stream", + ) + + async def test_invalid_uri(self): + """Client receives an invalid URI.""" + with self.assertRaises(InvalidURI): + async with connect("http://localhost"): # invalid scheme + self.fail("did not raise") + + async def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + async with connect("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + async def test_handshake_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server(process_response=remove_accept_header) as server: + with self.assertRaises(InvalidHandshake) as raised: + async with connect(get_uri(server) + "/no-op", close_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) + + async def test_timeout_during_handshake(self): + """Client times out before receiving handshake response from server.""" + # Replace the WebSocket server with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with self.assertRaises(TimeoutError) as raised: + async with connect(f"ws://{host}:{port}", open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during opening handshake", + ) + + async def test_connection_closed_during_handshake(self): + """Client reads EOF before receiving handshake response from server.""" + + async def close_connection(self, request): + await self.stream.aclose() + + async with run_server(process_request=close_connection) as server: + with self.assertRaises(InvalidMessage) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, EOFError) + self.assertEqual( + str(raised.exception.__cause__), + "connection closed while reading HTTP status line", + ) + + async def test_http_response(self): + """Client reads HTTP response.""" + + def http_response(connection, request): + return connection.respond(http.HTTPStatus.OK, "👌") + + async with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + async def test_http_response_without_content_length(self): + """Client reads HTTP response without a Content-Length header.""" + + def http_response(connection, request): + response = connection.respond(http.HTTPStatus.OK, "👌") + del response.headers["Content-Length"] + return response + + async with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + async def test_junk_handshake(self): + """Client closes the connection when receiving non-HTTP response from server.""" + + async def junk(stream): + # Wait for the client to send the handshake request. + await trio.testing.wait_all_tasks_blocked() + await stream.send_all(b"220 smtp.invalid ESMTP Postfix\r\n") + # Wait for the client to close the connection. + await stream.receive_some() + await stream.aclose() + + async with trio.open_nursery() as nursery: + try: + listeners = await nursery.start(trio.serve_tcp, junk, 0) + host, port = get_host_port(listeners) + with self.assertRaises(InvalidMessage) as raised: + async with connect(f"ws://{host}:{port}"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, ValueError) + self.assertEqual( + str(raised.exception.__cause__), + "unsupported protocol; expected HTTP/1.1: " + "220 smtp.invalid ESMTP Postfix", + ) + finally: + nursery.cancel_scope.cancel() + + +class SecureClientTests(IsolatedTrioTestCase): + async def test_connection(self): + """Client connects to server securely.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname_implicitly(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + host, port = get_host_port(server.listeners) + async with connect( + "wss://overridden/", host=host, port=port, ssl=CLIENT_CONTEXT + ) as client: + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_set_server_hostname_explicitly(self): + """Client sets server_hostname to the value provided in argument.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), + ssl=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_reject_invalid_server_certificate(self): + """Client rejects certificate where server certificate isn't trusted.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(trio.BrokenResourceError) as raised: + # The test certificate is self-signed. + async with connect(get_uri(server, secure=True)): + self.fail("did not raise") + self.assertIsInstance( + raised.exception.__cause__, + ssl.SSLCertVerificationError, + ) + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception.__cause__).replace("-", " "), + ) + + async def test_reject_invalid_server_hostname(self): + """Client rejects certificate where server hostname doesn't match.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(trio.BrokenResourceError) as raised: + # This hostname isn't included in the test certificate. + async with connect( + get_uri(server, secure=True), + ssl=CLIENT_CONTEXT, + server_hostname="invalid", + ): + self.fail("did not raise") + self.assertIsInstance( + raised.exception.__cause__, + ssl.SSLCertVerificationError, + ) + self.assertIn( + "certificate verify failed: Hostname mismatch", + str(raised.exception.__cause__), + ) + + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server, secure=True) + return response + + async with run_server(ssl=SERVER_CONTEXT, process_request=redirect) as server: + async with run_server(ssl=SERVER_CONTEXT) as other_server: + async with connect(get_uri(server, secure=True), ssl=CLIENT_CONTEXT): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) + + async def test_redirect_to_insecure_uri(self): + """Client doesn't follow redirect from secure URI to non-secure URI.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = insecure_uri + return response + + async with run_server(ssl=SERVER_CONTEXT, process_request=redirect) as server: + with self.assertRaises(SecurityError) as raised: + secure_uri = get_uri(server, secure=True) + insecure_uri = secure_uri.replace("wss://", "ws://") + async with connect(secure_uri, ssl=CLIENT_CONTEXT): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + f"cannot follow redirect to non-secure URI {insecure_uri}", + ) + + +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class SocksProxyClientTests(ProxyMixin, IsolatedTrioTestCase): + proxy_mode = "socks5@51080" + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_secure_socks_proxy(self): + """Client connects to server securely through a SOCKS5 proxy.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) + async def test_authenticated_socks_proxy(self): + """Client connects to server through an authenticated SOCKS5 proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_authenticated_socks_proxy_error(self): + """Client fails to authenticate to the SOCKS5 proxy.""" + from python_socks import ProxyError as SocksProxyError + + try: + self.proxy_options.update(proxyauth="any") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "failed to connect to SOCKS proxy", + ) + self.assertIsInstance(raised.exception.__cause__, SocksProxyError) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port + async def test_socks_proxy_connection_failure(self): + """Client fails to connect to the SOCKS5 proxy.""" + from python_socks import ProxyConnectionError as SocksProxyConnectionError + + with self.assertRaises(OSError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertIsInstance(raised.exception, SocksProxyConnectionError) + self.assertNumFlows(0) + + async def test_socks_proxy_connection_timeout(self): + """Client times out while connecting to the SOCKS5 proxy.""" + # Replace the proxy with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during opening handshake", + ) + self.assertNumFlows(0) + + async def test_explicit_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy set explicitly.""" + async with run_server() as server: + async with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:51080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_ignore_proxy_with_existing_stream(self): + """Cli ent connects using a pre-existing stream.""" + async with run_server() as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + # Use a non-existing domain to ensure we connect via stream. + async with connect("ws://invalid/", stream=stream) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(0) + + +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class HTTPProxyClientTests(ProxyMixin, IsolatedTrioTestCase): + proxy_mode = "regular@58080" + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy(self): + """Client connects to server through an HTTP proxy.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_secure_http_proxy(self): + """Client connects to server securely through an HTTP proxy.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) + async def test_authenticated_http_proxy(self): + """Client connects to server through an authenticated HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_authenticated_http_proxy_error(self): + """Client fails to authenticate to the HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="any") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "proxy rejected connection: HTTP 407", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_protocol_error(self): + """Client receives invalid data when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(break_http_connect=True) + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(break_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_connection_error(self): + """Client receives no response when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(close_http_connect=True) + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(close_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:48080"}) # bad port + async def test_http_proxy_connection_failure(self): + """Client fails to connect to the HTTP proxy.""" + with self.assertRaises(OSError): + async with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertNumFlows(0) + + async def test_http_proxy_connection_timeout(self): + """Client times out while connecting to the HTTP proxy.""" + # Replace the proxy with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during opening handshake", + ) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy(self): + """Client connects to server through an HTTPS proxy.""" + async with run_server() as server: + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_secure_https_proxy(self): + """Client connects to server securely through an HTTPS proxy.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_server_hostname(self): + """Client sets server_hostname to the value of proxy_server_hostname.""" + async with run_server() as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.server_hostname, "overridden") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy_invalid_proxy_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with self.assertRaises(trio.BrokenResourceError) as raised: + # The proxy certificate isn't trusted. + async with connect("wss://example.com/"): + self.fail("did not raise") + self.assertIsInstance(raised.exception.__cause__, ssl.SSLCertVerificationError) + self.assertIn( + "certificate verify failed: unable to get local issuer certificate", + str(raised.exception.__cause__), + ) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy_invalid_server_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(trio.BrokenResourceError) as raised: + # The test certificate is self-signed. + async with connect( + get_uri(server, secure=True), proxy_ssl=self.proxy_context + ): + self.fail("did not raise") + self.assertIsInstance(raised.exception.__cause__, ssl.SSLCertVerificationError) + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception.__cause__).replace("-", " "), + ) + self.assertNumFlows(1) + + +class ClientUsageErrorsTests(IsolatedTrioTestCase): + async def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" + with self.assertRaises(ValueError) as raised: + async with connect("ws://localhost/", ssl=CLIENT_CONTEXT): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) + + async def test_proxy_ssl_without_https_proxy(self): + """Client rejects proxy_ssl when proxy isn't HTTPS.""" + with self.assertRaises(ValueError) as raised: + async with connect( + "ws://localhost/", + proxy="http://localhost:8080", + proxy_ssl=True, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "proxy_ssl argument is incompatible with an http:// proxy", + ) + + async def test_unsupported_proxy(self): + """Client rejects unsupported proxy.""" + with self.assertRaises(InvalidProxy) as raised: + async with connect("ws://example.com/", proxy="other://localhost:51080"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", + ) + + async def test_invalid_subprotocol(self): + """Client rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + async with connect("ws://localhost/", subprotocols="chat"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Client rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + async with connect("ws://localhost/", compression=False): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) + + async def test_reentrancy(self): + """Client isn't reentrant.""" + async with run_server() as server: + connecter = connect(get_uri(server)) + async with connecter: + with self.assertRaises(RuntimeError) as raised: + async with connecter: + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connect() isn't reentrant", + ) diff --git a/tests/trio/test_server.py b/tests/trio/test_server.py new file mode 100644 index 00000000..12dcafc7 --- /dev/null +++ b/tests/trio/test_server.py @@ -0,0 +1,831 @@ +import dataclasses +import hmac +import http +import logging + +import trio + +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidStatus, + NegotiationError, +) +from websockets.http11 import Request, Response +from websockets.trio.client import connect +from websockets.trio.server import * + +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, +) +from .server import ( + EvalShellMixin, + get_host_port, + get_uri, + handler, + run_server, +) +from .utils import IsolatedTrioTestCase + + +class ServerTests(EvalShellMixin, IsolatedTrioTestCase): + async def test_connection(self): + """Server receives connection from client and the handshake succeeds.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_connection_handler_returns(self): + """Connection handler returns.""" + async with run_server() as server: + async with connect(get_uri(server) + "/no-op") as client: + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1000 (OK); then sent 1000 (OK)", + ) + + async def test_connection_handler_raises_exception(self): + """Connection handler raises an exception.""" + async with run_server() as server: + async with connect(get_uri(server) + "/crash") as client: + with self.assertRaises(ConnectionClosedError) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1011 (internal error); then sent 1011 (internal error)", + ) + + async def test_existing_listeners(self): + """Server receives connection using pre-existing listeners.""" + listeners = await trio.open_tcp_listeners(0, host="localhost") + host, port = get_host_port(listeners) + async with run_server(port=None, host=None, listeners=listeners): + async with connect(f"ws://{host}:{port}/") as client: # type: ignore + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_select_subprotocol(self): + """Server selects a subprotocol with the select_subprotocol callable.""" + + def select_subprotocol(ws, subprotocols): + ws.select_subprotocol_ran = True + assert "chat" in subprotocols + return "chat" + + async with run_server( + subprotocols=["chat"], select_subprotocol=select_subprotocol + ) as server: + async with connect(get_uri(server), subprotocols=["chat"]) as client: + await self.assertEval(client, "ws.select_subprotocol_ran", "True") + await self.assertEval(client, "ws.subprotocol", "chat") + + async def test_select_subprotocol_rejects_handshake(self): + """Server rejects handshake if select_subprotocol raises NegotiationError.""" + + def select_subprotocol(ws, subprotocols): + raise NegotiationError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_select_subprotocol_raises_exception(self): + """Server returns an error if select_subprotocol raises an exception.""" + + def select_subprotocol(ws, subprotocols): + raise RuntimeError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_compression_is_enabled(self): + """Server enables compression by default.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + async def test_disable_compression(self): + """Server disables compression.""" + async with run_server(compression=None) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.protocol.extensions", "[]") + + async def test_process_request_returns_none(self): + """Server runs process_request and continues the handshake.""" + + def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_async_process_request_returns_none(self): + """Server runs async process_request and continues the handshake.""" + + async def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_returns_response(self): + """Server aborts handshake if process_request returns a response.""" + + def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async def handler(ws): + self.fail("handler must not run") + + with self.assertNoLogs("websockets", logging.ERROR): + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_async_process_request_returns_response(self): + """Server aborts handshake if async process_request returns a response.""" + + async def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async def handler(ws): + self.fail("handler must not run") + + with self.assertNoLogs("websockets", logging.ERROR): + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_process_request_raises_exception(self): + """Server returns an error if process_request raises an exception.""" + + def process_request(ws, request): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + async def test_async_process_request_raises_exception(self): + """Server returns an error if async process_request raises an exception.""" + + async def process_request(ws, request): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + async def test_process_response_returns_none(self): + """Server runs process_response but keeps the handshake response.""" + + def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_async_process_response_returns_none(self): + """Server runs async process_response but keeps the handshake response.""" + + async def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_process_response_modifies_response(self): + """Server runs process_response and modifies the handshake response.""" + + def process_response(ws, request, response): + response.headers["X-ProcessResponse"] = "OK" + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_async_process_response_modifies_response(self): + """Server runs async process_response and modifies the handshake response.""" + + async def process_response(ws, request, response): + response.headers["X-ProcessResponse"] = "OK" + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_process_response_replaces_response(self): + """Server runs process_response and replaces the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_async_process_response_replaces_response(self): + """Server runs async process_response and replaces the handshake response.""" + + async def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_process_response_raises_exception(self): + """Server returns an error if process_response raises an exception.""" + + def process_response(ws, request, response): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + async def test_async_process_response_raises_exception(self): + """Server returns an error if async process_response raises an exception.""" + + async def process_response(ws, request, response): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + async def test_override_server(self): + """Server can override Server header with server_header.""" + async with run_server(server_header="Neo") as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.response.headers['Server']", "Neo") + + async def test_remove_server(self): + """Server can remove Server header with server_header.""" + async with run_server(server_header=None) as server: + async with connect(get_uri(server)) as client: + await self.assertEval( + client, "'Server' in ws.response.headers", "False" + ) + + async def test_keepalive_is_enabled(self): + """Server enables keepalive and measures latency.""" + async with run_server(ping_interval=MS) as server: + async with connect(get_uri(server)) as client: + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + await trio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertGreater(latency, 0) + + async def test_disable_keepalive(self): + """Server disables keepalive.""" + async with run_server(ping_interval=None) as server: + async with connect(get_uri(server)) as client: + await trio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + + async def test_logger(self): + """Server accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server(logger=logger) as server: + self.assertEqual(server.logger.name, logger.name) + + async def test_custom_connection_factory(self): + """Server runs ServerConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + server = ServerConnection(*args, **kwargs) + server.create_connection_ran = True + return server + + async with run_server(create_connection=create_connection) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.create_connection_ran", "True") + + async def test_connections(self): + """Server provides a connections property.""" + async with run_server() as server: + self.assertEqual(server.connections, set()) + async with connect(get_uri(server)) as client: + self.assertEqual(len(server.connections), 1) + ws_id = str(next(iter(server.connections)).id) + await self.assertEval(client, "ws.id", ws_id) + self.assertEqual(server.connections, set()) + + async def test_handshake_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + async with run_server(process_request=remove_key_header) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_timeout_during_handshake(self): + """Server times out before receiving handshake request from client.""" + async with run_server(open_timeout=MS) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + try: + self.assertEqual(await stream.receive_some(4096), b"") + finally: + await stream.aclose() + + async def test_connection_closed_during_handshake(self): + """Server reads EOF before receiving handshake request from client.""" + async with run_server() as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + await stream.aclose() + + async def test_junk_handshake(self): + """Server closes the connection when receiving non-HTTP request from client.""" + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server() as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + await stream.send_all(b"HELO relay.invalid\r\n") + try: + # Wait for the server to close the connection. + self.assertEqual(await stream.receive_some(4096), b"") + finally: + await stream.aclose() + + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["did not receive a valid HTTP request"], + ) + self.assertEqual( + [str(record.exc_info[1].__cause__) for record in logs.records], + ["invalid HTTP request line: HELO relay.invalid"], + ) + + async def test_close_server_rejects_connecting_connections(self): + """Server rejects connecting connections with HTTP 503 when closing.""" + + async def process_request(ws, _request): + while not ws.server.closing: + await trio.sleep(0) # pragma: no cover + + async with run_server(process_request=process_request) as server: + + async def close_server(server): + await trio.sleep(MS) + await server.aclose() + + self.nursery.start_soon(close_server, server) + + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 503", + ) + + async def test_close_server_closes_open_connections(self): + """Server closes open connections with close code 1001 when closing.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await server.aclose() + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1001 (going away); then sent 1001 (going away)", + ) + + async def test_close_server_closes_open_connections_with_code_and_reason(self): + """Server closes open connections with custom code and reason when closing.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await server.aclose(code=1012, reason="restarting") + with self.assertRaises(ConnectionClosedError) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1012 (service restart) restarting; " + "then sent 1012 (service restart) restarting", + ) + + async def test_close_server_keeps_connections_open(self): + """Server waits for client to close open connections when closing.""" + + async with run_server() as server: + server_closed = trio.Event() + + async def close_server(): + await server.aclose(close_connections=False) + server_closed.set() + + async with connect(get_uri(server)) as client: + self.nursery.start_soon(close_server) + + # Server cannot receive new connections. + with self.assertRaises(OSError): + async with connect(get_uri(server)): + self.fail("did not raise") + + # The server waits for the client to close the connection. + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(MS): + await server_closed.wait() + + # Once the client closes the connection, the server terminates. + await client.aclose() + with trio.fail_after(MS): + await server_closed.wait() + + async def test_close_server_keeps_handlers_running(self): + """Server waits for connection handlers to terminate.""" + async with run_server() as server: + server_closed = trio.Event() + + async def close_server(): + await server.aclose(close_connections=False) + server_closed.set() + + async with connect(get_uri(server) + "/delay") as client: + # Delay termination of connection handler. + await client.send(str(3 * MS)) + + self.nursery.start_soon(close_server) + + # The server waits for the connection handler to terminate. + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(2 * MS): + await server_closed.wait() + + # Set a large timeout here, else the test becomes flaky. + with trio.fail_after(5 * MS): + await server_closed.wait() + + +SSL_OBJECT = "ws.stream._ssl_object" + + +class SecureServerTests(EvalShellMixin, IsolatedTrioTestCase): + async def test_connection(self): + """Server receives secure connection from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, f"{SSL_OBJECT}.version()[:3]", "TLS") + + async def test_timeout_during_tls_handshake(self): + """Server times out before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + try: + self.assertEqual(await stream.receive_some(4096), b"") + finally: + await stream.aclose() + + async def test_connection_closed_during_tls_handshake(self): + """Server reads EOF before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + await stream.aclose() + + +class ServerUsageErrorsTests(IsolatedTrioTestCase): + async def test_missing_port(self): + """Server requires port.""" + with self.assertRaises(ValueError) as raised: + await serve(handler, None) + self.assertEqual( + str(raised.exception), + "port is required when listeners is not provided", + ) + + async def test_port_and_listeners(self): + """Server rejects port when listeners is provided.""" + listeners = await trio.open_tcp_listeners(0) + try: + with self.assertRaises(ValueError) as raised: + await serve(handler, port=0, listeners=listeners) + self.assertEqual( + str(raised.exception), + "port is incompatible with listeners", + ) + finally: + for listener in listeners: + await listener.aclose() + + async def test_host_and_listeners(self): + """Server rejects host when listeners is provided.""" + listeners = await trio.open_tcp_listeners(0) + try: + with self.assertRaises(ValueError) as raised: + await serve(handler, host="localhost", listeners=listeners) + self.assertEqual( + str(raised.exception), + "host is incompatible with listeners", + ) + finally: + for listener in listeners: + await listener.aclose() + + async def test_backlog_and_listeners(self): + """Server rejects backlog when listeners is provided.""" + listeners = await trio.open_tcp_listeners(0) + try: + with self.assertRaises(ValueError) as raised: + await serve(handler, backlog=65535, listeners=listeners) + self.assertEqual( + str(raised.exception), + "backlog is incompatible with listeners", + ) + finally: + for listener in listeners: + await listener.aclose() + + async def test_invalid_subprotocol(self): + """Server rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await serve(handler, subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Server rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await serve(handler, compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) + + +class BasicAuthTests(EvalShellMixin, IsolatedTrioTestCase): + async def test_valid_authorization(self): + """basic_auth authenticates client with HTTP Basic Authentication.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_missing_authorization(self): + """basic_auth rejects client without credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_unsupported_authorization(self): + """basic_auth rejects client with unsupported credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Negotiate ..."}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_unknown_username(self): + """basic_auth rejects client with unknown username.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_incorrect_password(self): + """basic_auth rejects client with incorrect password.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "changeme")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_list_of_credentials(self): + """basic_auth accepts a list of hard coded credentials.""" + async with run_server( + process_request=basic_auth( + credentials=[ + ("hello", "iloveyou"), + ("bye", "youloveme"), + ] + ), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ) as client: + await self.assertEval(client, "ws.username", "bye") + + async def test_check_credentials_function(self): + """basic_auth accepts a check_credentials function.""" + + def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_check_credentials_coroutine(self): + """basic_auth accepts a check_credentials coroutine.""" + + async def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_without_credentials_or_check_credentials(self): + """basic_auth requires either credentials or check_credentials.""" + with self.assertRaises(ValueError) as raised: + basic_auth() + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_with_credentials_and_check_credentials(self): + """basic_auth requires only one of credentials and check_credentials.""" + with self.assertRaises(ValueError) as raised: + basic_auth( + credentials=("hello", "iloveyou"), + check_credentials=lambda: False, # pragma: no cover + ) + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_bad_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=42) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: 42", + ) + + async def test_bad_list_of_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=[42]) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: [42]", + )