From 13464b268f4ec5e8634c350b4c9927cd93903cac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 27 May 2025 22:41:19 +0200 Subject: [PATCH 1/3] Add trio client and server. Also uniformize code & tests with other implementations. --- docs/conf.py | 1 + docs/index.rst | 9 + docs/project/changelog.rst | 8 + docs/reference/asyncio/server.rst | 2 +- docs/reference/features.rst | 257 ++++----- docs/reference/index.rst | 11 + docs/reference/trio/client.rst | 63 ++ docs/reference/trio/common.rst | 54 ++ docs/reference/trio/server.rst | 84 +++ example/asyncio/client.py | 1 - example/asyncio/server.py | 0 example/sync/client.py | 0 example/sync/server.py | 0 example/trio/client.py | 21 + example/trio/echo.py | 15 + example/trio/hello.py | 17 + example/trio/server.py | 20 + pyproject.toml | 1 + src/websockets/asyncio/client.py | 107 ++-- src/websockets/asyncio/server.py | 37 +- src/websockets/sync/client.py | 35 +- src/websockets/sync/server.py | 35 +- src/websockets/trio/client.py | 728 +++++++++++++++++++++++ src/websockets/trio/server.py | 649 +++++++++++++++++++++ tests/asyncio/test_client.py | 25 +- tests/asyncio/test_server.py | 4 +- tests/sync/test_client.py | 12 +- tests/test_localhost.cnf | 3 +- tests/test_localhost.pem | 89 +-- tests/trio/server.py | 63 ++ tests/trio/test_client.py | 927 ++++++++++++++++++++++++++++++ tests/trio/test_server.py | 831 ++++++++++++++++++++++++++ 32 files changed, 3823 insertions(+), 286 deletions(-) create mode 100644 docs/reference/trio/client.rst create mode 100644 docs/reference/trio/common.rst create mode 100644 docs/reference/trio/server.rst mode change 100644 => 100755 example/asyncio/client.py mode change 100644 => 100755 example/asyncio/server.py mode change 100644 => 100755 example/sync/client.py mode change 100644 => 100755 example/sync/server.py create mode 100755 example/trio/client.py create mode 100755 example/trio/echo.py create mode 100755 example/trio/hello.py create mode 100755 example/trio/server.py create mode 100644 src/websockets/trio/client.py create mode 100644 src/websockets/trio/server.py create mode 100644 tests/trio/server.py create mode 100644 tests/trio/test_client.py create mode 100644 tests/trio/test_server.py diff --git a/docs/conf.py b/docs/conf.py index 798d595db..0b1f64edc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -85,6 +85,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "sesame": ("https://django-sesame.readthedocs.io/en/stable/", None), + "trio": ("https://trio.readthedocs.io/en/stable/", None), "werkzeug": ("https://werkzeug.palletsprojects.com/en/stable/", None), } diff --git a/docs/index.rst b/docs/index.rst index 738258688..0774b91b8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -70,6 +70,10 @@ Here's an echo server and corresponding client. .. literalinclude:: ../example/sync/echo.py +.. tab:: trio + + .. literalinclude:: ../example/trio/echo.py + .. tab:: asyncio :new-set: @@ -79,6 +83,11 @@ Here's an echo server and corresponding client. .. literalinclude:: ../example/sync/hello.py +.. tab:: trio + + .. literalinclude:: ../example/trio/hello.py + + Don't worry about the opening and closing handshakes, pings and pongs, or any other behavior described in the WebSocket specification. websockets takes care of this under the hood so you can focus on your application! diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f6d7abb76..c43c20f71 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -43,6 +43,14 @@ Backwards-incompatible changes New features ............ +.. admonition:: websockets 16.0 introduces a :mod:`trio` implementation. + :class: important + + It is an alternative to the :mod:`asyncio` implementation. + + See :func:`websockets.trio.client.connect` and + :func:`websockets.trio.server.serve` for details. + * Validated compatibility with Python 3.14. Improvements diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index a245929ef..e8d80902b 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -46,7 +46,7 @@ Running a server .. automethod:: serve_forever - .. autoattribute:: sockets + .. autoproperty:: sockets Using a connection ------------------ diff --git a/docs/reference/features.rst b/docs/reference/features.rst index e5f6e0de0..dbeb6b43d 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -16,6 +16,7 @@ Feature support matrices summarize which implementations support which features. .. |aio| replace:: :mod:`asyncio` (new) .. |sync| replace:: :mod:`threading` +.. |trio| replace:: :mod:`trio` .. |sans| replace:: `Sans-I/O`_ .. |leg| replace:: :mod:`asyncio` (legacy) .. _Sans-I/O: https://sans-io.readthedocs.io/ @@ -26,68 +27,68 @@ Both sides .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | |leg| | - +====================================+========+========+========+========+ - | Perform the opening handshake | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Enforce opening timeout | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Send a message | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Broadcast a message | ✅ | ❌ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Receive a message | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Iterate over received messages | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Receive a fragmented message frame | ✅ | ✅ | — | ❌ | - | by frame | | | | | - +------------------------------------+--------+--------+--------+--------+ - | Receive a fragmented message after | ✅ | ✅ | — | ✅ | - | reassembly | | | | | - +------------------------------------+--------+--------+--------+--------+ - | Force sending a message as Text or | ✅ | ✅ | — | ❌ | - | Binary | | | | | - +------------------------------------+--------+--------+--------+--------+ - | Force receiving a message as | ✅ | ✅ | — | ❌ | - | :class:`bytes` or :class:`str` | | | | | - +------------------------------------+--------+--------+--------+--------+ - | Send a ping | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Send a pong | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Keepalive | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Heartbeat | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Measure latency | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Enforce closing timeout | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Report close codes and reasons | ✅ | ✅ | ✅ | ❌ | - | from both sides | | | | | - +------------------------------------+--------+--------+--------+--------+ - | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Tune memory usage for compression | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Negotiate extensions | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Implement custom extensions | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Negotiate a subprotocol | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Enforce security limits | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Log events | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+--------+ + | | |aio| | |sync| | |trio| | |sans| | |leg| | + +====================================+========+========+========+========+========+ + | Perform the opening handshake | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Enforce opening timeout | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Send a message | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Broadcast a message | ✅ | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Receive a message | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Iterate over received messages | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Receive a fragmented message frame | ✅ | ✅ | ✅ | — | ❌ | + | by frame | | | | | | + +------------------------------------+--------+--------+--------+--------+--------+ + | Receive a fragmented message after | ✅ | ✅ | ✅ | — | ✅ | + | reassembly | | | | | | + +------------------------------------+--------+--------+--------+--------+--------+ + | Force sending a message as Text or | ✅ | ✅ | ✅ | — | ❌ | + | Binary | | | | | | + +------------------------------------+--------+--------+--------+--------+--------+ + | Force receiving a message as | ✅ | ✅ | ✅ | — | ❌ | + | :class:`bytes` or :class:`str` | | | | | | + +------------------------------------+--------+--------+--------+--------+--------+ + | Send a ping | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Send a pong | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Keepalive | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Heartbeat | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Measure latency | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Enforce closing timeout | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Report close codes and reasons | ✅ | ✅ | ✅ | ✅ | ❌ | + | from both sides | | | | | | + +------------------------------------+--------+--------+--------+--------+--------+ + | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Tune memory usage for compression | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Negotiate extensions | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Implement custom extensions | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Negotiate a subprotocol | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Enforce security limits | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Log events | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ Server ------ @@ -95,39 +96,39 @@ Server .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | |leg| | - +====================================+========+========+========+========+ - | Listen on a TCP socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Listen on a Unix socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Listen using a preexisting socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Close server on context exit | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Close connection on handler exit | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Shut down server gracefully | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Check ``Origin`` header | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Customize subprotocol selection | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Configure ``Server`` header | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Alter opening handshake request | ✅ | ✅ | ✅ | ❌ | - +------------------------------------+--------+--------+--------+--------+ - | Alter opening handshake response | ✅ | ✅ | ✅ | ❌ | - +------------------------------------+--------+--------+--------+--------+ - | Force an HTTP response | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ✅ | ✅ | ❌ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Dispatch connections to handlers | ✅ | ✅ | — | ❌ | - +------------------------------------+--------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+--------+ + | | |aio| | |sync| | |trio| | |sans| | |leg| | + +====================================+========+========+========+========+========+ + | Listen on a TCP socket | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Listen on a Unix socket | ✅ | ✅ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Listen using a preexisting socket | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Close server on context exit | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Close connection on handler exit | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Shut down server gracefully | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Check ``Origin`` header | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Customize subprotocol selection | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Configure ``Server`` header | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Alter opening handshake request | ✅ | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Alter opening handshake response | ✅ | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Force an HTTP response | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ❌ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Route connections to handlers | ✅ | ✅ | ❌ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ Client ------ @@ -135,39 +136,39 @@ Client .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | |leg| | - +====================================+========+========+========+========+ - | Connect to a TCP socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Connect to a Unix socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Connect using a preexisting socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Close connection on context exit | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Reconnect automatically | ✅ | ❌ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Configure ``Origin`` header | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Modify opening handshake request | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Modify opening handshake response | ✅ | ✅ | ✅ | ❌ | - +------------------------------------+--------+--------+--------+--------+ - | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Follow HTTP redirects | ✅ | ❌ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Connect via HTTP proxy | ✅ | ✅ | — | ❌ | - +------------------------------------+--------+--------+--------+--------+ - | Connect via SOCKS5 proxy | ✅ | ✅ | — | ❌ | - +------------------------------------+--------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+--------+ + | | |aio| | |sync| | |trio| | |sans| | |leg| | + +====================================+========+========+========+========+========+ + | Connect to a TCP socket | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Connect to a Unix socket | ✅ | ✅ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Connect using a preexisting socket | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Close connection on context exit | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Reconnect automatically | ✅ | ❌ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Configure ``Origin`` header | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Modify opening handshake request | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Modify opening handshake response | ✅ | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Follow HTTP redirects | ✅ | ❌ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Connect via HTTP proxy | ✅ | ✅ | ✅ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Connect via SOCKS5 proxy | ✅ | ✅ | ✅ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ Known limitations ----------------- diff --git a/docs/reference/index.rst b/docs/reference/index.rst index cc9542c24..64a393d53 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -37,6 +37,17 @@ This alternative implementation can be a good choice for clients. sync/server sync/client +:mod:`trio` +------------ + +This is another option for servers that handle many clients concurrently. + +.. toctree:: + :titlesonly: + + trio/server + trio/client + `Sans-I/O`_ ----------- diff --git a/docs/reference/trio/client.rst b/docs/reference/trio/client.rst new file mode 100644 index 000000000..cf5643c55 --- /dev/null +++ b/docs/reference/trio/client.rst @@ -0,0 +1,63 @@ +Client (:mod:`trio`) +======================= + +.. automodule:: websockets.trio.client + +Opening a connection +-------------------- + +.. autofunction:: connect + :async: + +.. autofunction:: process_exception + +Using a connection +------------------ + +.. autoclass:: ClientConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: aclose + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoattribute:: latency + + .. autoproperty:: state + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/trio/common.rst b/docs/reference/trio/common.rst new file mode 100644 index 000000000..a1c68e0eb --- /dev/null +++ b/docs/reference/trio/common.rst @@ -0,0 +1,54 @@ +:orphan: + +Both sides (:mod:`trio`) +=========================== + +.. automodule:: websockets.trio.connection + +.. autoclass:: Connection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: aclose + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoattribute:: latency + + .. autoproperty:: state + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/trio/server.rst b/docs/reference/trio/server.rst new file mode 100644 index 000000000..e3d92ed45 --- /dev/null +++ b/docs/reference/trio/server.rst @@ -0,0 +1,84 @@ +Server (:mod:`trio`) +======================= + +.. automodule:: websockets.trio.server + +Creating a server +----------------- + +.. autofunction:: serve + :async: + +.. currentmodule:: websockets.trio.server + +Running a server +---------------- + +.. autoclass:: Server + + .. autoattribute:: connections + + .. automethod:: aclose + + .. autoattribute:: listeners + +Using a connection +------------------ + +.. autoclass:: ServerConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: aclose + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + .. automethod:: respond + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoattribute:: latency + + .. autoproperty:: state + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + +HTTP Basic Authentication +------------------------- + +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: basic_auth diff --git a/example/asyncio/client.py b/example/asyncio/client.py old mode 100644 new mode 100755 index e3562642d..4d40f97c4 --- a/example/asyncio/client.py +++ b/example/asyncio/client.py @@ -3,7 +3,6 @@ """Client example using the asyncio API.""" import asyncio - from websockets.asyncio.client import connect diff --git a/example/asyncio/server.py b/example/asyncio/server.py old mode 100644 new mode 100755 diff --git a/example/sync/client.py b/example/sync/client.py old mode 100644 new mode 100755 diff --git a/example/sync/server.py b/example/sync/server.py old mode 100644 new mode 100755 diff --git a/example/trio/client.py b/example/trio/client.py new file mode 100755 index 000000000..8bb5d9759 --- /dev/null +++ b/example/trio/client.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +"""Client example using the trio API.""" + +import trio +from websockets.trio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + name = input("What's your name? ") + + await websocket.send(name) + print(f">>> {name}") + + greeting = await websocket.recv() + print(f"<<< {greeting}") + + +if __name__ == "__main__": + trio.run(hello) diff --git a/example/trio/echo.py b/example/trio/echo.py new file mode 100755 index 000000000..e995b767e --- /dev/null +++ b/example/trio/echo.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python + +"""Echo server using the trio API.""" + +import trio +from websockets.trio.server import serve + + +async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + +if __name__ == "__main__": + trio.run(serve, echo, 8765) diff --git a/example/trio/hello.py b/example/trio/hello.py new file mode 100755 index 000000000..1accba49e --- /dev/null +++ b/example/trio/hello.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +"""Client using the trio API.""" + +import trio +from websockets.trio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + await websocket.send("Hello world!") + message = await websocket.recv() + print(message) + + +if __name__ == "__main__": + trio.run(hello) diff --git a/example/trio/server.py b/example/trio/server.py new file mode 100755 index 000000000..78a5ab7bd --- /dev/null +++ b/example/trio/server.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +"""Server example using the trio API.""" + +import trio +from websockets.trio.server import serve + + +async def hello(websocket): + name = await websocket.recv() + print(f"<<< {name}") + + greeting = f"Hello {name}!" + + await websocket.send(greeting) + print(f">>> {greeting}") + + +if __name__ == "__main__": + trio.run(serve, hello, 8765) diff --git a/pyproject.toml b/pyproject.toml index fb9b11c83..c77abb1f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ exclude_lines = [ "if sys.platform != \"win32\":", "if TYPE_CHECKING:", "raise AssertionError", + "raise NotImplementedError", "self.fail\\(\".*\"\\)", "@overload", "@unittest.skip", diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 05947f3a0..f8ac2e1ca 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -216,8 +216,8 @@ class connect: compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. - additional_headers (HeadersLike | None): Arbitrary HTTP headers to add - to the handshake request. + additional_headers: Arbitrary HTTP headers to add to the handshake + request. user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. @@ -328,6 +328,9 @@ def __init__( **kwargs: Any, ) -> None: self.uri = uri + self.ws_uri = parse_uri(uri) + if not self.ws_uri.secure and kwargs.get("ssl") is not None: + raise ValueError("ssl argument is incompatible with a ws:// URI") if subprotocols is not None: validate_subprotocols(subprotocols) @@ -343,7 +346,7 @@ def __init__( if create_connection is None: create_connection = ClientConnection - def protocol_factory(uri: WebSocketURI) -> ClientConnection: + def factory(uri: WebSocketURI) -> ClientConnection: # This is a protocol in the Sans-I/O implementation of websockets. protocol = ClientProtocol( uri, @@ -365,20 +368,18 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection: return connection self.proxy = proxy - self.protocol_factory = protocol_factory + self.factory = factory self.additional_headers = additional_headers self.user_agent_header = user_agent_header self.process_exception = process_exception self.open_timeout = open_timeout self.logger = logger - self.connection_kwargs = kwargs + self.create_connection_kwargs = kwargs - async def create_connection(self) -> ClientConnection: - """Create TCP or Unix connection.""" + async def open_tcp_connection(self) -> ClientConnection: + """Create TCP or Unix connection to the server, possibly through a proxy.""" loop = asyncio.get_running_loop() - kwargs = self.connection_kwargs.copy() - - ws_uri = parse_uri(self.uri) + kwargs = self.create_connection_kwargs.copy() proxy = self.proxy if kwargs.get("unix", False): @@ -386,19 +387,16 @@ async def create_connection(self) -> ClientConnection: if kwargs.get("sock") is not None: proxy = None if proxy is True: - proxy = get_proxy(ws_uri) + proxy = get_proxy(self.ws_uri) def factory() -> ClientConnection: - return self.protocol_factory(ws_uri) + return self.factory(self.ws_uri) - if ws_uri.secure: + if self.ws_uri.secure: kwargs.setdefault("ssl", True) - kwargs.setdefault("server_hostname", ws_uri.host) if kwargs.get("ssl") is None: raise ValueError("ssl=None is incompatible with a wss:// URI") - else: - if kwargs.get("ssl") is not None: - raise ValueError("ssl argument is incompatible with a ws:// URI") + kwargs.setdefault("server_hostname", self.ws_uri.host) if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) @@ -408,7 +406,7 @@ def factory() -> ClientConnection: # Connect to the server through the proxy. sock = await connect_socks_proxy( proxy_parsed, - ws_uri, + self.ws_uri, local_addr=kwargs.pop("local_addr", None), ) # Initialize WebSocket connection via the proxy. @@ -442,7 +440,7 @@ def factory() -> ClientConnection: # Connect to the server through the proxy. transport = await connect_http_proxy( proxy_parsed, - ws_uri, + self.ws_uri, user_agent_header=self.user_agent_header, **proxy_kwargs, ) @@ -460,17 +458,17 @@ def factory() -> ClientConnection: transport = new_transport connection.connection_made(transport) else: - raise AssertionError("unsupported proxy") + raise NotImplementedError(f"unsupported proxy: {proxy}") else: # Connect to the server directly. if kwargs.get("sock") is None: - kwargs.setdefault("host", ws_uri.host) - kwargs.setdefault("port", ws_uri.port) + kwargs.setdefault("host", self.ws_uri.host) + kwargs.setdefault("port", self.ws_uri.port) # Initialize WebSocket connection. _, connection = await loop.create_connection(factory, **kwargs) return connection - def process_redirect(self, exc: Exception) -> Exception | str: + def process_redirect(self, exc: Exception) -> Exception | tuple[str, WebSocketURI]: """ Determine whether a connection error is a redirect that can be followed. @@ -492,12 +490,12 @@ def process_redirect(self, exc: Exception) -> Exception | str: ): return exc - old_ws_uri = parse_uri(self.uri) + old_ws_uri = self.ws_uri new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) new_ws_uri = parse_uri(new_uri) # If connect() received a socket, it is closed and cannot be reused. - if self.connection_kwargs.get("sock") is not None: + if self.create_connection_kwargs.get("sock") is not None: return ValueError( f"cannot follow redirect to {new_uri} with a preexisting socket" ) @@ -513,7 +511,7 @@ def process_redirect(self, exc: Exception) -> Exception | str: or old_ws_uri.port != new_ws_uri.port ): # Cross-origin redirects on Unix sockets don't quite make sense. - if self.connection_kwargs.get("unix", False): + if self.create_connection_kwargs.get("unix", False): return ValueError( f"cannot follow cross-origin redirect to {new_uri} " f"with a Unix socket" @@ -521,15 +519,15 @@ def process_redirect(self, exc: Exception) -> Exception | str: # Cross-origin redirects when host and port are overridden are ill-defined. if ( - self.connection_kwargs.get("host") is not None - or self.connection_kwargs.get("port") is not None + self.create_connection_kwargs.get("host") is not None + or self.create_connection_kwargs.get("port") is not None ): return ValueError( f"cannot follow cross-origin redirect to {new_uri} " f"with an explicit host or port" ) - return new_uri + return new_uri, new_ws_uri # ... = await connect(...) @@ -541,14 +539,14 @@ async def __await_impl__(self) -> ClientConnection: try: async with asyncio_timeout(self.open_timeout): for _ in range(MAX_REDIRECTS): - self.connection = await self.create_connection() + connection = await self.open_tcp_connection() try: - await self.connection.handshake( + await connection.handshake( self.additional_headers, self.user_agent_header, ) except asyncio.CancelledError: - self.connection.transport.abort() + connection.transport.abort() raise except Exception as exc: # Always close the connection even though keep-alive is @@ -557,22 +555,23 @@ async def __await_impl__(self) -> ClientConnection: # protocol. In the current design of connect(), there is # no easy way to reuse the network connection that works # in every case nor to reinitialize the protocol. - self.connection.transport.abort() + connection.transport.abort() - uri_or_exc = self.process_redirect(exc) - # Response is a valid redirect; follow it. - if isinstance(uri_or_exc, str): - self.uri = uri_or_exc - continue + exc_or_uri = self.process_redirect(exc) # Response isn't a valid redirect; raise the exception. - if uri_or_exc is exc: - raise + if isinstance(exc_or_uri, Exception): + if exc_or_uri is exc: + raise + else: + raise exc_or_uri from exc + # Response is a valid redirect; follow it. else: - raise uri_or_exc from exc + self.uri, self.ws_uri = exc_or_uri + continue else: - self.connection.start_keepalive() - return self.connection + connection.start_keepalive() + return connection else: raise SecurityError(f"more than {MAX_REDIRECTS} redirects") @@ -587,7 +586,10 @@ async def __await_impl__(self) -> ClientConnection: # async with connect(...) as ...: ... async def __aenter__(self) -> ClientConnection: - return await self + if hasattr(self, "connection"): + raise RuntimeError("connect() isn't reentrant") + self.connection = await self + return self.connection async def __aexit__( self, @@ -595,7 +597,10 @@ async def __aexit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: - await self.connection.close() + try: + await self.connection.close() + finally: + del self.connection # async for ... in connect(...): @@ -603,8 +608,8 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: delays: Generator[float] | None = None while True: try: - async with self as protocol: - yield protocol + async with self as connection: + yield connection except Exception as exc: # Determine whether the exception is retryable or fatal. # The API of process_exception is "return an exception or None"; @@ -633,7 +638,6 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: traceback.format_exception_only(exc)[0].strip(), ) await asyncio.sleep(delay) - continue else: # The connection succeeded. Reset backoff. @@ -777,8 +781,7 @@ def eof_received(self) -> None: def connection_lost(self, exc: Exception | None) -> None: self.reader.feed_eof() - if exc is not None: - self.response.set_exception(exc) + self.run_parser() async def connect_http_proxy( @@ -797,8 +800,8 @@ async def connect_http_proxy( try: # This raises exceptions if the connection to the proxy fails. await protocol.response - except Exception: - transport.close() + except (asyncio.CancelledError, Exception): + transport.abort() raise return transport diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index ef9bd807f..018d891d1 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -169,7 +169,7 @@ async def handshake( assert isinstance(response, Response) # help mypy self.response = response - if server_header: + if server_header is not None: self.response.headers["Server"] = server_header response = None @@ -231,12 +231,9 @@ class Server: This class mirrors the API of :class:`asyncio.Server`. - It keeps track of WebSocket connections in order to close them properly - when shutting down. - Args: handler: Connection handler. It receives the WebSocket connection, - which is a :class:`ServerConnection`, in argument. + which is a :class:`ServerConnection`. process_request: Intercept the request during the opening handshake. Return an HTTP response to force the response. Return :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the @@ -310,7 +307,11 @@ def connections(self) -> set[ServerConnection]: It can be useful in combination with :func:`~broadcast`. """ - return {connection for connection in self.handlers if connection.state is OPEN} + return { + connection + for connection in self.handlers + if connection.protocol.state is OPEN + } def wrap(self, server: asyncio.Server) -> None: """ @@ -351,6 +352,8 @@ async def conn_handler(self, connection: ServerConnection) -> None: """ try: + # Apply open_timeout to the WebSocket handshake. + # Use ssl_handshake_timeout for the TLS handshake. async with asyncio_timeout(self.open_timeout): try: await connection.handshake( @@ -425,7 +428,7 @@ def close( ``code`` and ``reason`` can be customized, for example to use code 1012 (service restart). - * Wait until all connection handlers terminate. + * Wait until all connection handlers have returned. :meth:`close` is idempotent. @@ -452,6 +455,7 @@ async def _close( self.logger.info("server closing") # Stop accepting new connections. + # Reject OPENING connections with HTTP 503 -- see handshake(). self.server.close() # Wait until all accepted connections reach connection_made() and call @@ -459,15 +463,12 @@ async def _close( # details. This workaround can be removed when dropping Python < 3.11. await asyncio.sleep(0) - # After server.close(), handshake() closes OPENING connections with an - # HTTP 503 error. - + # Close OPEN connections. if close_connections: - # Close OPEN connections with code 1001 by default. close_tasks = [ asyncio.create_task(connection.close(code, reason)) for connection in self.handlers - if connection.protocol.state is not CONNECTING + if connection.protocol.state is OPEN ] # asyncio.wait doesn't accept an empty first argument. if close_tasks: @@ -476,7 +477,7 @@ async def _close( # Wait until all TCP connections are closed. await self.server.wait_closed() - # Wait until all connection handlers terminate. + # Wait until all connection handlers have returned. # asyncio.wait doesn't accept an empty first argument. if self.handlers: await asyncio.wait(self.handlers.values()) @@ -590,18 +591,18 @@ class serve: This coroutine returns a :class:`Server` whose API mirrors :class:`asyncio.Server`. Treat it as an asynchronous context manager to - ensure that the server will be closed:: + ensure that the server will be closed gracefully:: from websockets.asyncio.server import serve - def handler(websocket): + async def handler(websocket): ... - # set this future to exit the server - stop = asyncio.get_running_loop().create_future() + # set this event to exit the server + stop = asyncio.Event() async with serve(handler, host, port): - await stop + await stop.wait() Alternatively, call :meth:`~Server.serve_forever` to serve requests and cancel it to stop the server:: diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index b3fff44ee..3dca571e5 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -63,7 +63,6 @@ def __init__( max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ClientProtocol - self.response_rcvd = threading.Event() super().__init__( socket, protocol, @@ -72,6 +71,7 @@ def __init__( close_timeout=close_timeout, max_queue=max_queue, ) + self.response_rcvd = threading.Event() def handshake( self, @@ -157,6 +157,7 @@ def connect( logger: LoggerLike | None = None, # Escape hatch for advanced customization create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to socket.create_connection **kwargs: Any, ) -> ClientConnection: """ @@ -190,8 +191,8 @@ def connect( compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. - additional_headers (HeadersLike | None): Arbitrary HTTP headers to add - to the handshake request. + additional_headers: Arbitrary HTTP headers to add to the handshake + request. user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. @@ -230,6 +231,7 @@ def connect( Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidProxy: If ``proxy`` isn't a valid proxy. OSError: If the TCP connection fails. InvalidHandshake: If the opening handshake fails. TimeoutError: If the opening handshake times out. @@ -250,6 +252,17 @@ def connect( if not ws_uri.secure and ssl is not None: raise ValueError("ssl argument is incompatible with a ws:// URI") + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ClientConnection + # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) path: str | None = kwargs.pop("path", None) @@ -260,14 +273,6 @@ def connect( elif path is not None and sock is not None: raise ValueError("path and sock arguments are incompatible") - if subprotocols is not None: - validate_subprotocols(subprotocols) - - if compression == "deflate": - extensions = enable_client_permessage_deflate(extensions) - elif compression is not None: - raise ValueError(f"unsupported compression: {compression}") - if unix: proxy = None if sock is not None: @@ -280,9 +285,6 @@ def connect( # to avoid conflicting with the WebSocket timeout in handshake(). deadline = Deadline(open_timeout) - if create_connection is None: - create_connection = ClientConnection - try: # Connect socket @@ -321,7 +323,7 @@ def connect( **kwargs, ) else: - raise AssertionError("unsupported proxy") + raise NotImplementedError("unsupported proxy") else: kwargs.setdefault("timeout", deadline.timeout()) sock = socket.create_connection( @@ -539,7 +541,8 @@ def connect_http_proxy( # Send CONNECT request to the proxy and read response. - sock.sendall(prepare_connect_request(proxy, ws_uri, user_agent_header)) + request = prepare_connect_request(proxy, ws_uri, user_agent_header) + sock.sendall(request) try: read_connect_response(sock, deadline) except Exception: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index ffd82fbad..cf6b14c76 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -32,7 +32,13 @@ from .utils import Deadline -__all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"] +__all__ = [ + "serve", + "unix_serve", + "ServerConnection", + "Server", + "basic_auth", +] class ServerConnection(Connection): @@ -72,7 +78,6 @@ def __init__( max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ServerProtocol - self.request_rcvd = threading.Event() super().__init__( socket, protocol, @@ -81,6 +86,7 @@ def __init__( close_timeout=close_timeout, max_queue=max_queue, ) + self.request_rcvd = threading.Event() self.username: str # see basic_auth() self.handler: Callable[[ServerConnection], None] # see route() self.handler_kwargs: Mapping[str, Any] # see route() @@ -154,7 +160,7 @@ def handshake( else: self.response = response - if server_header: + if server_header is not None: self.response.headers["Server"] = server_header response = None @@ -218,14 +224,12 @@ class Server: """ WebSocket server returned by :func:`serve`. - This class mirrors the API of :class:`~socketserver.BaseServer`, notably the - :meth:`~socketserver.BaseServer.serve_forever` and - :meth:`~socketserver.BaseServer.shutdown` methods, as well as the context - manager protocol. + This class mirrors partially the API of :class:`~socketserver.BaseServer`. - Args: - socket: Server socket listening for new connections. - handler: Handler for one connection. Receives the socket and address + + Args: + socket: Server socket accepting new connections. + handler: Handler for one connection. It receives the socket and address returned by :meth:`~socket.socket.accept`. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. @@ -387,8 +391,8 @@ def serve( This function returns a :class:`Server` whose API mirrors :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure - that it will be closed and call :meth:`~Server.serve_forever` to serve - requests:: + that it will be closed gracefully and call :meth:`~Server.serve_forever` to + serve requests:: from websockets.sync.server import serve @@ -605,7 +609,12 @@ def protocol_select_subprotocol( connection.recv_events_thread.join() return - assert connection.protocol.state is OPEN + if connection.protocol.state is not OPEN: + # process_request or process_response rejected the handshake. + connection.close_socket() + connection.recv_events_thread.join() + return + try: connection.start_keepalive() handler(connection) diff --git a/src/websockets/trio/client.py b/src/websockets/trio/client.py new file mode 100644 index 000000000..fa4b36b73 --- /dev/null +++ b/src/websockets/trio/client.py @@ -0,0 +1,728 @@ +from __future__ import annotations + +import logging +import os +import ssl as ssl_module +import sys +import traceback +import urllib.parse +from collections.abc import AsyncIterator, Generator, Sequence +from types import TracebackType +from typing import Any, Callable, Literal + +import trio + +from ..asyncio.client import process_exception +from ..client import ClientProtocol, backoff +from ..datastructures import HeadersLike +from ..exceptions import ( + InvalidProxyMessage, + InvalidProxyStatus, + InvalidStatus, + ProxyError, + SecurityError, +) +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http11 import USER_AGENT, Response +from ..protocol import CONNECTING, Event +from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request +from ..streams import StreamReader +from ..typing import LoggerLike, Origin, Subprotocol +from ..uri import WebSocketURI, parse_uri +from .connection import Connection +from .utils import race_events + + +if sys.version_info[:2] < (3, 11): # pragma: no cover + from exceptiongroup import BaseExceptionGroup + + +__all__ = ["connect", "ClientConnection"] + +MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) + + +class ClientConnection(Connection): + """ + :mod:`trio` implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines + for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`connect`. + + Args: + nursery: Trio nursery. + stream: Trio stream connected to a WebSocket server. + protocol: Sans-I/O connection. + + """ + + def __init__( + self, + nursery: trio.Nursery, + stream: trio.abc.Stream, + protocol: ClientProtocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.protocol: ClientProtocol + super().__init__( + nursery, + stream, + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + self.response_rcvd = trio.Event() + + async def handshake( + self, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + async with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers.setdefault("User-Agent", user_agent_header) + self.protocol.send_request(self.request) + + await race_events(self.response_rcvd, self.stream_closed) + + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + +# This is spelled in lower case because it's exposed as a callable in the API. +class connect: + """ + Connect to the WebSocket server at ``uri``. + + This coroutine returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as an asynchronous context manager:: + + from websockets.trio.client import connect + + async with connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + :func:`connect` can be used as an infinite asynchronous iterator to + reconnect automatically on errors:: + + async for websocket in connect(...): + try: + ... + except websockets.exceptions.ConnectionClosed: + continue + + If the connection fails with a transient error, it is retried with + exponential backoff. If it fails with a fatal error, the exception is + raised, breaking out of the loop. + + The connection is closed automatically after each iteration of the loop. + + Args: + uri: URI of the WebSocket server. + stream: Preexisting TCP stream. ``stream`` overrides the host and port + from ``uri``. You may call :func:`~trio.open_tcp_stream` to create a + suitable TCP stream. + ssl: Configuration for enabling TLS on the connection. + server_hostname: Host name for the TLS handshake. ``server_hostname`` + overrides the host name from ``uri``. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + additional_headers: Arbitrary HTTP headers to add to the handshake + request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + proxy: If a proxy is configured, it is used by default. Set ``proxy`` + to :obj:`None` to disable the proxy or to the address of a proxy + to override the system configuration. See the :doc:`proxy docs + <../../topics/proxies>` for details. + proxy_ssl: Configuration for enabling TLS on the proxy connection. + proxy_server_hostname: Host name for the TLS handshake with the proxy. + ``proxy_server_hostname`` overrides the host name from ``proxy``. + process_exception: When reconnecting automatically, tell whether an + error is transient or fatal. The default behavior is defined by + :func:`process_exception`. Refer to its documentation for details. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to :func:`~trio.open_tcp_stream`. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidProxy: If ``proxy`` isn't a valid proxy. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + def __init__( + self, + uri: str, + *, + # TCP/TLS + stream: trio.abc.Stream | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, + # WebSocket + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + compression: str | None = "deflate", + # HTTP + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + proxy: str | Literal[True] | None = True, + proxy_ssl: ssl_module.SSLContext | None = None, + proxy_server_hostname: str | None = None, + process_exception: Callable[[Exception], Exception | None] = process_exception, + # Timeouts + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + # Limits + max_size: int | None | tuple[int | None, int | None] = 2**20, + max_queue: int | None | tuple[int | None, int | None] = 16, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to trio.open_tcp_stream + **kwargs: Any, + ) -> None: + self.uri = uri + self.ws_uri = parse_uri(uri) + if not self.ws_uri.secure and ssl is not None: + raise ValueError("ssl argument is incompatible with a ws:// URI") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if logger is None: + logger = logging.getLogger("websockets.client") + + if create_connection is None: + create_connection = ClientConnection + + self.stream = stream + self.ssl = ssl + self.server_hostname = server_hostname + self.proxy = proxy + self.proxy_ssl = proxy_ssl + self.proxy_server_hostname = proxy_server_hostname + self.additional_headers = additional_headers + self.user_agent_header = user_agent_header + self.process_exception = process_exception + self.open_timeout = open_timeout + self.logger = logger + self.create_connection = create_connection + self.open_tcp_stream_kwargs = kwargs + self.protocol_kwargs = dict( + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + max_size=max_size, + logger=logger, + ) + self.connection_kwargs = dict( + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + + async def open_tcp_stream(self) -> trio.abc.Stream: + """Open a TCP connection to the server, possibly through a proxy.""" + # TCP connection is already established. + if self.stream is not None: + return self.stream + + if self.proxy is True: + proxy = get_proxy(self.ws_uri) + else: + proxy = self.proxy + + # Connect to the server through a proxy. + if proxy is not None: + proxy_parsed = parse_proxy(proxy) + + if proxy_parsed.scheme[:5] == "socks": + return await connect_socks_proxy( + proxy_parsed, + self.ws_uri, + local_address=self.open_tcp_stream_kwargs.get("local_address"), + ) + + elif proxy_parsed.scheme[:4] == "http": + if proxy_parsed.scheme != "https" and self.proxy_ssl is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + return await connect_http_proxy( + proxy_parsed, + self.ws_uri, + user_agent_header=self.user_agent_header, + ssl=self.proxy_ssl, + server_hostname=self.proxy_server_hostname, + local_address=self.open_tcp_stream_kwargs.get("local_address"), + ) + + else: + raise NotImplementedError(f"unsupported proxy: {self.proxy}") + + # Connect to the server directly. + kwargs = self.open_tcp_stream_kwargs.copy() + kwargs.setdefault("host", self.ws_uri.host) + kwargs.setdefault("port", self.ws_uri.port) + return await trio.open_tcp_stream(**kwargs) + + async def enable_tls(self, stream: trio.abc.Stream) -> trio.abc.Stream: + """Enable TLS on the connection.""" + if self.ssl is None: + ssl = ssl_module.create_default_context() + else: + ssl = self.ssl + if self.server_hostname is None: + server_hostname = self.ws_uri.host + else: + server_hostname = self.server_hostname + ssl_stream = trio.SSLStream( + stream, + ssl, + server_hostname=server_hostname, + https_compatible=True, + ) + await ssl_stream.do_handshake() + return ssl_stream + + async def open_connection(self, nursery: trio.Nursery) -> ClientConnection: + """Create a WebSocket connection.""" + stream: trio.abc.Stream + stream = await self.open_tcp_stream() + + try: + if self.ws_uri.secure: + stream = await self.enable_tls(stream) + + protocol = ClientProtocol( + self.ws_uri, + **self.protocol_kwargs, # type: ignore + ) + + connection = self.create_connection( # default is ClientConnection + nursery, + stream, + protocol, + **self.connection_kwargs, # type: ignore + ) + + await connection.handshake( + self.additional_headers, + self.user_agent_header, + ) + + return connection + + except trio.Cancelled: + await trio.aclose_forcefully(stream) + # The nursery running this coroutine was canceled. + # The next checkpoint raises trio.Cancelled. + # aclose_forcefully() never returns. + raise AssertionError("nursery should be canceled") + except Exception: + # Always close the connection even though keep-alive is the default + # in HTTP/1.1 because the current implementation ties opening the + # TCP/TLS connection with initializing the WebSocket protocol. + await trio.aclose_forcefully(stream) + raise + + def process_redirect(self, exc: Exception) -> Exception | tuple[str, WebSocketURI]: + """ + Determine whether a connection error is a redirect that can be followed. + + Return the new URI if it's a valid redirect. Else, return an exception. + + """ + if not ( + isinstance(exc, InvalidStatus) + and exc.response.status_code + in [ + 300, # Multiple Choices + 301, # Moved Permanently + 302, # Found + 303, # See Other + 307, # Temporary Redirect + 308, # Permanent Redirect + ] + and "Location" in exc.response.headers + ): + return exc + + old_ws_uri = self.ws_uri + new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) + new_ws_uri = parse_uri(new_uri) + + # If connect() received a stream, it is closed and cannot be reused. + if self.stream is not None: + return ValueError( + f"cannot follow redirect to {new_uri} with a preexisting stream" + ) + + # TLS downgrade is forbidden. + if old_ws_uri.secure and not new_ws_uri.secure: + return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}") + + # Apply restrictions to cross-origin redirects. + if ( + old_ws_uri.secure != new_ws_uri.secure + or old_ws_uri.host != new_ws_uri.host + or old_ws_uri.port != new_ws_uri.port + ): + # Cross-origin redirects when host and port are overridden are ill-defined. + if ( + self.open_tcp_stream_kwargs.get("host") is not None + or self.open_tcp_stream_kwargs.get("port") is not None + ): + return ValueError( + f"cannot follow cross-origin redirect to {new_uri} " + f"with an explicit host or port" + ) + + return new_uri, new_ws_uri + + async def connect(self, nursery: trio.Nursery) -> ClientConnection: + try: + with ( + trio.CancelScope() + if self.open_timeout is None + else trio.fail_after(self.open_timeout) + ): + for _ in range(MAX_REDIRECTS): + try: + connection = await self.open_connection(nursery) + except Exception as exc: + exc_or_uri = self.process_redirect(exc) + # Response isn't a valid redirect; raise the exception. + if isinstance(exc_or_uri, Exception): + if exc_or_uri is exc: + raise + else: + raise exc_or_uri from exc + # Response is a valid redirect; follow it. + else: + self.uri, self.ws_uri = exc_or_uri + continue + + else: + connection.start_keepalive() + return connection + else: + raise SecurityError(f"more than {MAX_REDIRECTS} redirects") + + except trio.TooSlowError as exc: + # Re-raise exception with an informative error message. + raise TimeoutError("timed out during opening handshake") from exc + + # Do not define __await__ for... = await nursery.start(connect, ...) + # because it doesn't look idiomatic in Trio. + + # async with connect(...) as ...: ... + + async def __aenter__(self) -> ClientConnection: + await self.__aenter_nursery__() + try: + self.connection = await self.connect(self.nursery) + return self.connection + except BaseException as exc: + await self.__aexit_nursery__(type(exc), exc, exc.__traceback__) + raise AssertionError("expected __aexit_nursery__ to re-raise the exception") + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + try: + await self.connection.aclose() + del self.connection + finally: + await self.__aexit_nursery__(exc_type, exc_value, traceback) + + async def __aenter_nursery__(self) -> None: + if hasattr(self, "nursery_manager"): # pragma: no cover + raise RuntimeError("connect() isn't reentrant") + self.nursery_manager = trio.open_nursery() + self.nursery = await self.nursery_manager.__aenter__() + + async def __aexit_nursery__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + # We need a nursery to start the recv_events and keepalive coroutines. + # They aren't expected to raise exceptions; instead they catch and log + # all unexpected errors. To keep the nursery an implementation detail, + # unwrap exceptions raised by user code -- per the second option here: + # https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors + try: + await self.nursery_manager.__aexit__(exc_type, exc_value, traceback) + except BaseException as exc: + assert isinstance(exc, BaseExceptionGroup) + try: + trio._util.raise_single_exception_from_group(exc) + except trio._util.MultipleExceptionError: # pragma: no cover + raise AssertionError( + "unexpected multiple exceptions; please file a bug report" + ) from exc + finally: + del self.nursery_manager + + # async for ... in connect(...): + + async def __aiter__(self) -> AsyncIterator[ClientConnection]: + delays: Generator[float] | None = None + while True: + try: + async with self as connection: + yield connection + except Exception as exc: + # Determine whether the exception is retryable or fatal. + # The API of process_exception is "return an exception or None"; + # "raise an exception" is also supported because it's a frequent + # mistake. It isn't documented in order to keep the API simple. + try: + new_exc = self.process_exception(exc) + except Exception as raised_exc: + new_exc = raised_exc + + # The connection failed with a fatal error. + # Raise the exception and exit the loop. + if new_exc is exc: + raise + if new_exc is not None: + raise new_exc from exc + + # The connection failed with a retryable error. + # Start or continue backoff and reconnect. + if delays is None: + delays = backoff() + delay = next(delays) + self.logger.info( + "connect failed; reconnecting in %.1f seconds: %s", + delay, + traceback.format_exception_only(exc)[0].strip(), + ) + await trio.sleep(delay) + + else: + # The connection succeeded. Reset backoff. + delays = None + + +try: + from python_socks import ProxyType + from python_socks.async_.trio import Proxy as SocksProxy + +except ImportError: + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> trio.abc.Stream: + raise ImportError("connecting through a SOCKS proxy requires python-socks") + +else: + SOCKS_PROXY_TYPES = { + "socks5h": ProxyType.SOCKS5, + "socks5": ProxyType.SOCKS5, + "socks4a": ProxyType.SOCKS4, + "socks4": ProxyType.SOCKS4, + } + + SOCKS_PROXY_RDNS = { + "socks5h": True, + "socks5": False, + "socks4a": True, + "socks4": False, + } + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> trio.abc.Stream: + """Connect via a SOCKS proxy and return the socket.""" + socks_proxy = SocksProxy( + SOCKS_PROXY_TYPES[proxy.scheme], + proxy.host, + proxy.port, + proxy.username, + proxy.password, + SOCKS_PROXY_RDNS[proxy.scheme], + ) + # connect() is documented to raise OSError. + # socks_proxy.connect() re-raises trio.TooSlowError as ProxyTimeoutError. + # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake. + try: + return trio.SocketStream( + await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + ) + except OSError: + raise + except Exception as exc: + raise ProxyError("failed to connect to SOCKS proxy") from exc + + +async def read_connect_response(stream: trio.abc.Stream) -> Response: + reader = StreamReader() + parser = Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + proxy=True, + ) + try: + while True: + data = await stream.receive_some(4096) + if data: + reader.feed_data(data) + else: + reader.feed_eof() + next(parser) + except StopIteration as exc: + assert isinstance(exc.value, Response) # help mypy + response = exc.value + if 200 <= response.status_code < 300: + return response + else: + raise InvalidProxyStatus(response) + except Exception as exc: + raise InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) from exc + + +async def connect_http_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + *, + user_agent_header: str | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, + **kwargs: Any, +) -> trio.abc.Stream: + stream: trio.abc.Stream + stream = await trio.open_tcp_stream(proxy.host, proxy.port, **kwargs) + + try: + # Initialize TLS wrapper and perform TLS handshake + if proxy.scheme == "https": + if ssl is None: + ssl = ssl_module.create_default_context() + if server_hostname is None: + server_hostname = proxy.host + ssl_stream = trio.SSLStream( + stream, + ssl, + server_hostname=server_hostname, + https_compatible=True, + ) + await ssl_stream.do_handshake() + stream = ssl_stream + + # Send CONNECT request to the proxy and read response. + request = prepare_connect_request(proxy, ws_uri, user_agent_header) + await stream.send_all(request) + await read_connect_response(stream) + + except (trio.Cancelled, Exception): + await trio.aclose_forcefully(stream) + raise + + return stream diff --git a/src/websockets/trio/server.py b/src/websockets/trio/server.py new file mode 100644 index 000000000..7571ca623 --- /dev/null +++ b/src/websockets/trio/server.py @@ -0,0 +1,649 @@ +from __future__ import annotations + +import functools +import http +import logging +import re +import ssl as ssl_module +from collections.abc import Awaitable, Sequence +from types import TracebackType +from typing import Any, Callable, Mapping + +import trio +import trio.abc + +from ..asyncio.server import basic_auth +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..frames import CloseCode +from ..headers import validate_subprotocols +from ..http11 import SERVER, Request, Response +from ..protocol import CONNECTING, OPEN, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, StatusLike, Subprotocol +from .connection import Connection +from .utils import race_events + + +__all__ = [ + "serve", + "ServerConnection", + "Server", + "basic_auth", +] + + +class ServerConnection(Connection): + """ + :mod:`trio` implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`serve`. + + Args: + nursery: Trio nursery. + stream: Trio stream connected to a WebSocket client. + protocol: Sans-I/O connection. + server: Server that manages this connection. + + """ + + def __init__( + self, + nursery: trio.Nursery, + stream: trio.abc.Stream, + protocol: ServerProtocol, + server: Server, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.protocol: ServerProtocol + super().__init__( + nursery, + stream, + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + self.server = server + self.request_rcvd: trio.Event = trio.Event() + self.username: str # see basic_auth() + self.handler: Callable[[ServerConnection], Awaitable[None]] # see route() + self.handler_kwargs: Mapping[str, Any] # see route() + + def respond(self, status: StatusLike, text: str) -> Response: + """ + Create a plain text HTTP response. + + ``process_request`` and ``process_response`` may call this method to + return an HTTP response instead of performing the WebSocket opening + handshake. + + You can modify the response before returning it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + return self.protocol.reject(status, text) + + async def handshake( + self, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = SERVER, + ) -> None: + """ + Perform the opening handshake. + + """ + await race_events(self.request_rcvd, self.stream_closed) + + if self.request is not None: + async with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + if not self.server.closing: + self.response = self.protocol.accept(self.request) + else: + self.response = self.protocol.reject( + http.HTTPStatus.SERVICE_UNAVAILABLE, + "Server is shutting down.\n", + ) + else: + assert isinstance(response, Response) # help mypy + self.response = response + + if server_header is not None: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + assert isinstance(response, Response) # help mypy + self.response = response + + self.protocol.send_response(self.response) + + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + +class Server(trio.abc.AsyncResource): + """ + WebSocket server returned by :func:`serve`. + + Args: + open_listeners: Factory for Trio listeners accepting new connections. + stream_handler: Handler for one connection. It receives a Trio stream. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + open_listeners: Callable[[], Awaitable[list[trio.SocketListener]]], + stream_handler: Callable[[trio.abc.Stream, Server], Awaitable[None]], + logger: LoggerLike | None = None, + ) -> None: + self.open_listeners = open_listeners + self.stream_handler = stream_handler + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + + self.listeners: list[trio.SocketListener] = [] + """Trio listeners.""" + + self.closing = False + self.closed_waiters: dict[ServerConnection, trio.Event] = {} + + @property + def connections(self) -> set[ServerConnection]: + """ + Set of active connections. + + This property contains all connections that completed the opening + handshake successfully and didn't start the closing handshake yet. + + .. It can be useful in combination with :func:`~broadcast`. + + """ + return { + connection + for connection in self.closed_waiters + if connection.protocol.state is OPEN + } + + async def serve_forever( + self, + task_status: trio.TaskStatus[Server] = trio.TASK_STATUS_IGNORED, + ) -> None: + self.listeners = await self.open_listeners() # used in tests + # Running handlers in a dedicated nursery makes it possible to close + # listeners while handlers finish running. The nursery for listeners + # is created in trio.serve_listeners(). + async with trio.open_nursery() as self.handler_nursery: + # Wrap trio.serve_listeners() in another nursery to return the + # Server object in task_status instead of a list of listeners. + async with trio.open_nursery() as self.serve_nursery: + await self.serve_nursery.start( + functools.partial( + trio.serve_listeners, + functools.partial(self.stream_handler, server=self), # type: ignore + self.listeners, + handler_nursery=self.handler_nursery, + ) + ) + task_status.started(self) + + # Shutting down the server cleanly when serve_forever() is canceled would be + # the most idiomatic in Trio. However, that would require shielding too many + # asynchronous operations, including the TLS & WebSocket opening handshakes. + + async def aclose( + self, + close_connections: bool = True, + code: CloseCode | int = CloseCode.GOING_AWAY, + reason: str = "", + ) -> None: + """ + Close the server. + + * Close the TCP listeners. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: + + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + ``code`` and ``reason`` can be customized, for example to use code + 1012 (service restart). + + * Wait until all connection handlers have returned. + + :meth:`aclose` is idempotent. + + """ + self.logger.info("server closing") + + # Stop accepting new connections. + self.serve_nursery.cancel_scope.cancel() + + # Reject OPENING connections with HTTP 503 -- see handshake(). + self.closing = True + + # Close OPEN connections. + if close_connections: + for connection in self.closed_waiters: + if connection.protocol.state is not OPEN: # pragma: no cover + continue + self.handler_nursery.start_soon(connection.aclose, code, reason) + + # Wait until all connection handlers have returned. + while self.closed_waiters: + await next(iter(self.closed_waiters.values())).wait() + + self.logger.info("server closed") + + async def __aenter__(self) -> Server: # pragma: no cover + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: # pragma: no cover + await self.aclose() + + +async def serve( + handler: Callable[[ServerConnection], Awaitable[None]], + port: int | None = None, + *, + # TCP/TLS + host: str | bytes | None = None, + backlog: int | None = None, + listeners: list[trio.SocketListener] | None = None, + ssl: ssl_module.SSLContext | None = None, + # WebSocket + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( + Callable[ + [ServerConnection, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None, + compression: str | None = "deflate", + # HTTP + process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Response | None, + ] + | None + ) = None, + server_header: str | None = SERVER, + # Timeouts + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + # Limits + max_size: int | None | tuple[int | None, int | None] = 2**20, + max_queue: int | None | tuple[int | None, int | None] = 16, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ServerConnection] | None = None, + # Trio + task_status: trio.TaskStatus[Server] = trio.TASK_STATUS_IGNORED, +) -> None: + """ + Create a WebSocket server listening on ``port``. + + Whenever a client connects, the server creates a :class:`ServerConnection`, + performs the opening handshake, and delegates to the ``handler`` coroutine. + + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. + + Once the handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + When using :func:`serve` with :meth:`nursery.start `, + you get back a :class:`Server` object. Call its :meth:`~Server.aclose` + method to stop the server gracefully:: + + from websockets.trio.server import serve + + async def handler(websocket): + ... + + # set this event to exit the server + stop = trio.Event() + + with trio.open_nursery() as nursery: + server = await nursery.start(serve, handler, port) + try: + await stop.wait() + finally: + await server.aclose() + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + port: TCP port the server listens on. + See :func:`~trio.open_tcp_listeners` for details. + host: Network interfaces the server binds to. + See :func:`~trio.open_tcp_listeners` for details. + backlog: Listen backlog. See :func:`~trio.open_tcp_listeners` for + details. + listeners: Preexisting TCP listeners. ``listeners`` replaces ``port``, + ``host``, and ``backlog``. See :func:`trio.serve_listeners` for + details. + ssl: Configuration for enabling TLS on the connection. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It receives a + :class:`ServerConnection` (not a + :class:`~websockets.server.ServerProtocol`!) instance and a list of + subprotocols offered by the client. Other than the first argument, + it has the same behavior as the + :meth:`ServerProtocol.select_subprotocol + ` method. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response. Return :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + process_response: Intercept the response during the opening handshake. + Modify the response or return a new HTTP response to force the + response. Return :obj:`None` to continue normally. When you force an + HTTP 101 Continue response, the handshake is successful. Else, the + connection is aborted. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. See the + :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ServerConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + task_status: For compatibility with :meth:`nursery.start + `. + + """ + + # Process parameters + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + # Create listeners + + if listeners is None: + if port is None: + raise ValueError("port is required when listeners is not provided") + + async def open_listeners() -> list[trio.SocketListener]: + return await trio.open_tcp_listeners(port, host=host, backlog=backlog) + else: + if port is not None: + raise ValueError("port is incompatible with listeners") + if host is not None: + raise ValueError("host is incompatible with listeners") + if backlog is not None: + raise ValueError("backlog is incompatible with listeners") + + async def open_listeners() -> list[trio.SocketListener]: + return listeners + + async def stream_handler(stream: trio.abc.Stream, server: Server) -> None: + async with trio.open_nursery() as nursery: + try: + # Apply open_timeout to the TLS and WebSocket handshake. + with ( + trio.CancelScope() + if open_timeout is None + else trio.move_on_after(open_timeout) + ): + # Enable TLS. + if ssl is not None: + # Wrap with SSLStream here rather than with TLSListener + # in order to include the TLS handshake within open_timeout. + stream = trio.SSLStream( + stream, + ssl, + server_side=True, + https_compatible=True, + ) + assert isinstance(stream, trio.SSLStream) # help mypy + try: + await stream.do_handshake() + except trio.BrokenResourceError: + return + + # Create a closure to give select_subprotocol access to connection. + protocol_select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # Initialize WebSocket protocol. + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + max_size=max_size, + logger=logger, + ) + + # Initialize WebSocket connection. + assert create_connection is not None # help mypy + connection = create_connection( + nursery, + stream, + protocol, + server, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + + try: + await connection.handshake( + process_request, + process_response, + server_header, + ) + except trio.Cancelled: + # The nursery running this coroutine was canceled. + # The next checkpoint raises trio.Cancelled. + # aclose_forcefully() never returns. + await trio.aclose_forcefully(stream) + raise AssertionError("nursery should be canceled") + except Exception: + connection.logger.error( + "opening handshake failed", exc_info=True + ) + await trio.aclose_forcefully(stream) + return + + if connection.protocol.state is not OPEN: + # process_request or process_response rejected the handshake. + await connection.close_stream() + return + + try: + server.waiters[connection] = trio.Event() + connection.start_keepalive() + await handler(connection) + except Exception: + connection.logger.error("connection handler failed", exc_info=True) + await connection.aclose(CloseCode.INTERNAL_ERROR) + else: + await connection.aclose() + finally: + server.waiters.pop(connection).set() + + except Exception: # pragma: no cover + # Don't leak connections on unexpected errors. + await trio.aclose_forcefully(stream) + + server = Server(open_listeners, stream_handler, logger) + await server.serve_forever(task_status=task_status) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index a83074ae8..eff026230 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -75,7 +75,7 @@ async def test_existing_socket(self): """Client connects using a pre-existing socket.""" async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") @@ -341,7 +341,7 @@ def redirect(connection, request): async with serve(*args, process_request=redirect) as server: with socket.create_connection(get_host_port(server)) as sock: with self.assertRaises(ValueError) as raised: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. async with connect("ws://invalid/redirect", sock=sock): self.fail("did not raise") @@ -446,9 +446,11 @@ async def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" async def junk(reader, writer): - await asyncio.sleep(MS) # wait for the client to send the handshake request + # Wait for the client to send the handshake request. + await asyncio.sleep(MS) writer.write(b"220 smtp.invalid ESMTP Postfix\r\n") - await reader.read(4096) # wait for the client to close the connection + # Wait for the client to close the connection. + await reader.read(4096) writer.close() server = await asyncio.start_server(junk, "localhost", 0) @@ -652,7 +654,7 @@ async def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @@ -1000,3 +1002,16 @@ async def test_unsupported_compression(self): str(raised.exception), "unsupported compression: False", ) + + async def test_reentrancy(self): + """Client isn't reentrant.""" + async with serve(*args) as server: + connecter = connect(get_uri(server)) + async with connecter: + with self.assertRaises(RuntimeError) as raised: + async with connecter: + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connect() isn't reentrant", + ) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 00dcb3010..fe225067c 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -568,7 +568,7 @@ async def test_connection(self): async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") - await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + await self.assertEval(client, f"{SSL_OBJECT}.version()[:3]", "TLS") async def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" @@ -604,7 +604,7 @@ async def test_connection(self): async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") - await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + await self.assertEval(client, f"{SSL_OBJECT}.version()[:3]", "TLS") class ServerUsageErrorsTests(unittest.IsolatedAsyncioTestCase): diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 415343911..cc5949c93 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -44,7 +44,7 @@ def test_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") @@ -225,9 +225,11 @@ def test_junk_handshake(self): class JunkHandler(socketserver.BaseRequestHandler): def handle(self): - time.sleep(MS) # wait for the client to send the handshake request + # Wait for the client to send the handshake request. + time.sleep(MS) self.request.send(b"220 smtp.invalid ESMTP Postfix\r\n") - self.request.recv(4096) # wait for the client to close the connection + # Wait for the client to close the connection. + self.request.recv(4096) self.request.close() server = socketserver.TCPServer(("localhost", 0), JunkHandler) @@ -401,7 +403,7 @@ def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @@ -648,7 +650,7 @@ def test_proxy_ssl_without_https_proxy(self): connect( "ws://localhost/", proxy="http://localhost:8080", - proxy_ssl=True, + proxy_ssl=CLIENT_CONTEXT, ) self.assertEqual( str(raised.exception), diff --git a/tests/test_localhost.cnf b/tests/test_localhost.cnf index 4069e3967..15d49228c 100644 --- a/tests/test_localhost.cnf +++ b/tests/test_localhost.cnf @@ -24,4 +24,5 @@ subjectAltName = @san DNS.1 = localhost DNS.2 = overridden IP.3 = 127.0.0.1 -IP.4 = ::1 +IP.4 = 0.0.0.0 +IP.5 = ::1 diff --git a/tests/test_localhost.pem b/tests/test_localhost.pem index 8df63ec8f..1f26df715 100644 --- a/tests/test_localhost.pem +++ b/tests/test_localhost.pem @@ -1,48 +1,49 @@ -----BEGIN PRIVATE KEY----- -MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDYOOQyq8yYtn5x -K3yRborFxTFse16JIVb4x/ZhZgGm49eARCi09fmczQxJdQpHz81Ij6z0xi7AUYH7 -9wS8T0Lh3uGFDDS1GzITUVPIqSUi0xim2T6XPzXFVQYI1D/OjUxlHm+3/up+WwbL -sBgBO/lDmzoa3ZN7kt9HQoGc/14oQz1Qsv1QTDQs69r+o7mmBJr/hf/g7S0Csyy3 -iC6aaq+yCUyzDbjXceTI7WJqbTGNnK0/DjdFD/SJS/uSDNEg0AH53eqcCSjm+Ei/ -UF8qR5Pu4sSsNwToOW2MVgjtHFazc+kG3rzD6+3Dp+t6x6uI/npyuudOMCmOtd6z -kX0UPQaNAgMBAAECggEAS4eMBztGC+5rusKTEAZKSY15l0h9HG/d/qdzJFDKsO6T -/8VPZu8pk6F48kwFHFK1hexSYWq9OAcA3fBK4jDZzybZJm2+F6l5U5AsMUMMqt6M -lPP8Tj8RXG433muuIkvvbL82DVLpvNu1Qv+vUvcNOpWFtY7DDv6eKjlMJ3h4/pzh -89MNt26VMCYOlq1NSjuZBzFohL2u9nsFehlOpcVsqNfNfcYCq9+5yoH8fWJP90Op -hqhvqUoGLN7DRKV1f+AWHSA4nmGgvVviV5PQgMhtk5exlN7kG+rDc3LbzhefS1Sp -Tat1qIgm8fK2n+Q/obQPjHOGOGuvE5cIF7E275ZKgQKBgQDt87BqALKWnbkbQnb7 -GS1h6LRcKyZhFbxnO2qbviBWSo15LEF8jPGV33Dj+T56hqufa/rUkbZiUbIR9yOX -dnOwpAVTo+ObAwZfGfHvrnufiIbHFqJBumaYLqjRZ7AC0QtS3G+kjS9dbllrr7ok -fO4JdfKRXzBJKrkQdCn8hR22rQKBgQDon0b49Dxs1EfdSDbDode2TSwE83fI3vmR -SKUkNY8ma6CRbomVRWijhBM458wJeuhpjPZOvjNMsnDzGwrtdAp2VfFlMIDnA8ZC -fEWIAAH2QYKXKGmkoXOcWB2QbvbI154zCm6zFGtzvRKOCGmTXuhFajO8VPwOyJVt -aSJA3bLrYQKBgQDJM2/tAfAAKRdW9GlUwqI8Ep9G+/l0yANJqtTnIemH7XwYhJJO -9YJlPszfB2aMBgliQNSUHy1/jyKpzDYdITyLlPUoFwEilnkxuud2yiuf5rpH51yF -hU6wyWtXvXv3tbkEdH42PmdZcjBMPQeBSN2hxEi6ISncBDL9tau26PwJ9QKBgQCs -cNYl2reoXTzgtpWSNDk6NL769JjJWTFcF6QD0YhKjOI8rNpkw00sWc3+EybXqDr9 -c7dq6+gPZQAB1vwkxi6zRkZqIqiLl+qygnjwtkC+EhYCg7y8g8q2DUPtO7TJcb0e -TQ9+xRZad8B3dZj93A8G1hF//OfU9bB/qL3xo+bsQQKBgC/9YJvgLIWA/UziLcB2 -29Ai0nbPkN5df7z4PifUHHSlbQJHKak8UKbMP+8S064Ul0F7g8UCjZMk2LzSbaNY -XU5+2j0sIOnGUFoSlvcpdowzYrD2LN5PkKBot7AOq/v7HlcOoR8J8RGWAMpCrHsI -a/u/dlZs+/K16RcavQwx8rag +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDKiNs9JHIq5I2c +GjupVn8QJ3oi+lSpEwdUu6aw/q1H9mVzv1dFtp7hT8kuhclNf1tlBBFiB+NWbRZc +uyBRq+mIIWfepcHRHpquxyopesD+CdeC0rogq3vry94FJNmN8257WZiraNl3v9ht +eBqTy0xYDsDtl8iYLfT4xPDfJVOMq0R6SQEljWi6jSbR3b74wiLpXoWjvx7KJahH +hd/p48meuq95tGfxDEb7r/h02RpZF5rq2zRqBOcO4nL5drWYBh1I4+RFp+AbCixX +MqWh1e0vl/wXiKwYTPIgqH2DIXxS3m8dn4O74zO0ktRqPkIXMyKAZQkdUNLngE7v +pNeDcQatAgMBAAECggEACRc/WtZvBt7YYu9IgP0btWBF9hoa0yOwA8P97FpQ8YkI +rpa0bVZrnjz2fkZNdwodLd43YBlKZe1ZbhxD1S1+uuYEY3TvpvWC7A78pPz86IEN +TPu/Jt1AMeo4d5vtLoS7fSYLBwl2H7OI03Y0ROeS8FJXfrKixdp2OmLmVcOAXDDj +Eq0Xs2tSXXPVZ8KKGMidKqvfxcVAhOZvJfHvkMJ+tS/FRAn7Qxc1tn7OTUOg+glr +sHdMwImfzDCbyhP5gZXL/MP35UqnKUBAGdJmfp3BkFxk0yGLhlCOefs1/a9PhVOt +Q83+kjWnuYeP3R4jB7fuWtEu0/gPZT/P1iJF4MIhjwKBgQDqPtT+7G7KMThGjdm6 +bu77VDsW10T5uDU55G3LvXHoFTZUnleSOtWrh2mdR3KVj5PdHDR4VSuA0d65S39n +LYVul82FMgjCWKL4odgssPcLD6SsybdF9xXSXJKtQ96eJjW0o7vMu0/CHrhF6whA +EmCeDcD81Bzvj8DbkSyHpIaolwKBgQDdWBn43eVBt8FStAXx3J49pMyw83AXyqNA +3taHTGjG9BnjgsRgQeYmZG82xpD/Yu6dYyzF+rI4iODkSzF1FN+j64ElDRJbAMvS +yThbAKAb+xegh0EQm43+kYG1sDavWT4pvzh6DCltN82eHwJ5utDuneiAB66DeAqY +ttXmw+fPWwKBgHYEoBWsE4mlUMAjWc5Xc+qGnpq8bNEQISkA0Ny0nv4aKdxqRp6z +K9IXEHwgcjeuNgZR3pG9/4QQuRFMW20lfzOgIfj4o3cfZ0SzbhHeOymEgShZHRCQ +E5t/7pqDNlch0y8my0i0GtQn3BnF98soNyuKrG/1gnqkR7uYIgJZP0sTAoGAGHLt +0353H04zzXXTHkcXN4nnjjgljos0gyraGXHINQmrfmToWhWNXXpEipFeXMdJwhq9 +TFUHsJT1+mGP4fXfShTuW/BYsbKh0POnBO5JwS14C6RE/JeiFJdv82i2caHy6tuT +Wm/Td5vtW2Tjehy3jVPl5ZZzoVP2H646bFYBWfcCgYEAkWJLFzvXsF9SW9Ku6cc0 +7Yhuoolad/AWCXe5Q3+k+icgOQFnMsOkuEPIlRHPgjaOnXMq76VyO4a66vK+ucgr +R3O8/h5QZiuxE3dfqXsDrGr/6W2kmDWWXXK9r5oJQ1J4ndj65ZaGcAuw/77hf5K8 +PnN3beykcf5xxuaPNpq0cbg= -----END PRIVATE KEY----- -----BEGIN CERTIFICATE----- -MIIDWTCCAkGgAwIBAgIJAOL9UKiOOxupMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV -BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp -bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTIyMTAxNTE5Mjg0MVoYDzIwNjQxMDE0 -MTkyODQxWjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM -EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI -hvcNAQEBBQADggEPADCCAQoCggEBANg45DKrzJi2fnErfJFuisXFMWx7XokhVvjH -9mFmAabj14BEKLT1+ZzNDEl1CkfPzUiPrPTGLsBRgfv3BLxPQuHe4YUMNLUbMhNR -U8ipJSLTGKbZPpc/NcVVBgjUP86NTGUeb7f+6n5bBsuwGAE7+UObOhrdk3uS30dC -gZz/XihDPVCy/VBMNCzr2v6juaYEmv+F/+DtLQKzLLeILppqr7IJTLMNuNdx5Mjt -YmptMY2crT8ON0UP9IlL+5IM0SDQAfnd6pwJKOb4SL9QXypHk+7ixKw3BOg5bYxW -CO0cVrNz6QbevMPr7cOn63rHq4j+enK6504wKY613rORfRQ9Bo0CAwEAAaM8MDow -OAYDVR0RBDEwL4IJbG9jYWxob3N0ggpvdmVycmlkZGVuhwR/AAABhxAAAAAAAAAA -AAAAAAAAAAABMA0GCSqGSIb3DQEBCwUAA4IBAQBPNDGDdl4wsCRlDuyCHBC8o+vW -Vb14thUw9Z6UrlsQRXLONxHOXbNAj1sYQACNwIWuNz36HXu5m8Xw/ID/bOhnIg+b -Y6l/JU/kZQYB7SV1aR3ZdbCK0gjfkE0POBHuKOjUFIOPBCtJ4tIBUX94zlgJrR9v -2rqJC3TIYrR7pVQumHZsI5GZEMpM5NxfreWwxcgltgxmGdm7elcizHfz7k5+szwh -4eZ/rxK9bw1q8BIvVBWelRvUR55mIrCjzfZp5ZObSYQTZlW7PzXBe5Jk+1w31YHM -RSBA2EpPhYlGNqPidi7bg7rnQcsc6+hE0OqzTL/hWxPm9Vbp9dj3HFTik1wa +MIIDiTCCAnGgAwIBAgIURQDnIfsMPAhuq9Uq1dka01Qoc9IwDQYJKoZIhvcNAQEL +BQAwTDELMAkGA1UEBhMCRlIxDjAMBgNVBAcMBVBhcmlzMRkwFwYDVQQKDBBBeW1l +cmljIEF1Z3VzdGluMRIwEAYDVQQDDAlsb2NhbGhvc3QwIBcNMjUwNTMxMjAxMDU1 +WhgPMjA2NzA1MzEyMDEwNTVaMEwxCzAJBgNVBAYTAkZSMQ4wDAYDVQQHDAVQYXJp +czEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3RpbjESMBAGA1UEAwwJbG9jYWxob3N0 +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyojbPSRyKuSNnBo7qVZ/ +ECd6IvpUqRMHVLumsP6tR/Zlc79XRbae4U/JLoXJTX9bZQQRYgfjVm0WXLsgUavp +iCFn3qXB0R6arscqKXrA/gnXgtK6IKt768veBSTZjfNue1mYq2jZd7/YbXgak8tM +WA7A7ZfImC30+MTw3yVTjKtEekkBJY1ouo0m0d2++MIi6V6Fo78eyiWoR4Xf6ePJ +nrqvebRn8QxG+6/4dNkaWRea6ts0agTnDuJy+Xa1mAYdSOPkRafgGwosVzKlodXt +L5f8F4isGEzyIKh9gyF8Ut5vHZ+Du+MztJLUaj5CFzMigGUJHVDS54BO76TXg3EG +rQIDAQABo2EwXzA+BgNVHREENzA1gglsb2NhbGhvc3SCCm92ZXJyaWRkZW6HBH8A +AAGHBAAAAACHEAAAAAAAAAAAAAAAAAAAAAEwHQYDVR0OBBYEFB7eswhXVVmG32UR +MGtc2vewZjM0MA0GCSqGSIb3DQEBCwUAA4IBAQBt9KGnnrtn15H9wz4fWHzPTGaO +laJQE5RnqlzyQ3aDLRtZIc/OA+0L6rW7+xiiN0v1irqCD/M0YGYGomy//3J444bT +SxciJQarZPtNRaLJx17geQOwbY5NpTsfEKmvhwCnMLx9Wy6kyHx0NyD3e1MJwH47 +QdJDmKCVF2R10AKGlnsp6zYaoOvoY48MvCBOnaZEVXPypta0N3XXrASsllw5QJSb +XXPIdNbwA22necSoa7PchMXIbyDXIhygf+tXVBAKvNaSNCzQPehTmepENYJPFEh/ +NJrYPB769uRPgZxIvivo1QjNik4ywcZlvEU6LC6JPUasUcGY6FTnipLL6lD0 -----END CERTIFICATE----- diff --git a/tests/trio/server.py b/tests/trio/server.py new file mode 100644 index 000000000..d2172af21 --- /dev/null +++ b/tests/trio/server.py @@ -0,0 +1,63 @@ +import contextlib +import functools +import socket +import urllib.parse + +import trio + +from websockets.trio.server import * + + +def get_host_port(listeners): + for listener in listeners: + if listener.socket.family == socket.AF_INET: # pragma: no branch + return listener.socket.getsockname() + raise AssertionError("expected at least one IPv4 socket") + + +def get_uri(server, secure=False): + protocol = "wss" if secure else "ws" + host, port = get_host_port(server.listeners) + return f"{protocol}://{host}:{port}" + + +async def handler(ws): + path = urllib.parse.urlparse(ws.request.path).path + if path == "/": + # The default path is an eval shell. + async for expr in ws: + value = eval(expr) + await ws.send(str(value)) + elif path == "/crash": + raise RuntimeError + elif path == "/no-op": + pass + elif path == "/delay": + delay = float(await ws.recv()) + await ws.aclose() + await trio.sleep(delay) + else: + raise AssertionError(f"unexpected path: {path}") + + +kwargs = {"handler": handler, "port": 0, "host": "localhost"} + + +@contextlib.asynccontextmanager +async def run_server(**overrides): + merged_kwargs = {**kwargs, **overrides} + async with trio.open_nursery() as nursery: + server = await nursery.start(functools.partial(serve, **merged_kwargs)) + try: + yield server + finally: + # Run all tasks to guarantee that any exceptions are raised. + # Otherwise, canceling the nursery could hide errors. + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + +class EvalShellMixin: + async def assertEval(self, client, expr, value): + await client.send(expr) + self.assertEqual(await client.recv(), value) diff --git a/tests/trio/test_client.py b/tests/trio/test_client.py new file mode 100644 index 000000000..7448b5fd1 --- /dev/null +++ b/tests/trio/test_client.py @@ -0,0 +1,927 @@ +import contextlib +import http +import logging +import os +import socket +import ssl +import sys +import unittest +from unittest.mock import patch + +import trio + +from websockets.client import backoff +from websockets.exceptions import ( + InvalidHandshake, + InvalidMessage, + InvalidProxy, + InvalidProxyMessage, + InvalidStatus, + InvalidURI, + ProxyError, + SecurityError, +) +from websockets.extensions.permessage_deflate import PerMessageDeflate +from websockets.trio.client import * + +from ..proxy import ProxyMixin +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT +from .server import get_host_port, get_uri, run_server +from .utils import IsolatedTrioTestCase + + +@contextlib.asynccontextmanager +async def short_backoff_delay(): + defaults = backoff.__defaults__ + backoff.__defaults__ = ( + defaults[0] * MS, + defaults[1] * MS, + defaults[2] * MS, + defaults[3], + ) + try: + yield + finally: + backoff.__defaults__ = defaults + + +@contextlib.asynccontextmanager +async def few_redirects(): + from websockets.trio import client + + max_redirects = client.MAX_REDIRECTS + client.MAX_REDIRECTS = 2 + try: + yield + finally: + client.MAX_REDIRECTS = max_redirects + + +class ClientTests(IsolatedTrioTestCase): + async def test_connection(self): + """Client connects to server.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_explicit_host_port(self): + """Client connects using an explicit host / port.""" + async with run_server() as server: + host, port = get_host_port(server.listeners) + async with connect("ws://overridden/", host=host, port=port) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_existing_stream(self): + """Client connects using a pre-existing stream.""" + async with run_server() as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + # Use a non-existing domain to ensure we connect via stream. + async with connect("ws://invalid/", stream=stream) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with run_server() as server: + async with connect(get_uri(server), compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + async def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + async with run_server() as server: + async with connect( + get_uri(server), additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + async def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with run_server() as server: + async with connect(get_uri(server), user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with run_server() as server: + async with connect(get_uri(server), user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + async def test_legacy_user_agent(self): + """Client can override User-Agent header with additional_headers.""" + async with run_server() as server: + async with connect( + get_uri(server), additional_headers={"User-Agent": "Smith"} + ) as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + async with run_server() as server: + async with connect(get_uri(server), ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + await trio.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with run_server() as server: + async with connect(get_uri(server), ping_interval=None) as client: + await trio.sleep(2 * MS) + self.assertEqual(client.latency, 0) + + async def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server() as server: + async with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + + async def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + async with run_server() as server: + async with connect( + get_uri(server), create_connection=create_connection + ) as client: + self.assertTrue(client.create_connection_ran) + + @short_backoff_delay() + async def test_reconnect(self): + """Client reconnects to server.""" + iterations = 0 + successful = 0 + + async def process_request(connection, request): + nonlocal iterations + iterations += 1 + # Retriable errors + if iterations == 1: + await trio.sleep(3 * MS) + elif iterations == 2: + await connection.stream.aclose() + elif iterations == 3: + return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") + # Fatal error + elif iterations == 6: + return connection.respond(http.HTTPStatus.PAYMENT_REQUIRED, "💸") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async for client in connect(get_uri(server), open_timeout=3 * MS): + self.assertEqual(client.protocol.state.name, "OPEN") + successful += 1 + + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 402", + ) + self.assertEqual(iterations, 6) + self.assertEqual(successful, 2) + + @short_backoff_delay() + async def test_reconnect_with_custom_process_exception(self): + """Client runs process_exception to tell if errors are retryable or fatal.""" + iteration = 0 + + def process_request(connection, request): + nonlocal iteration + iteration += 1 + if iteration == 1: + return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") + return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") + + def process_exception(exc): + if isinstance(exc, InvalidStatus): + if 500 <= exc.response.status_code < 600: + return None + if exc.response.status_code == 418: + return Exception("🫖 💔 ☕️") + self.fail("unexpected exception") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(Exception) as raised: + async for _ in connect( + get_uri(server), process_exception=process_exception + ): + self.fail("did not raise") + + self.assertEqual(iteration, 2) + self.assertEqual( + str(raised.exception), + "🫖 💔 ☕️", + ) + + @short_backoff_delay() + async def test_reconnect_with_custom_process_exception_raising_exception(self): + """Client supports raising an exception in process_exception.""" + + def process_request(connection, request): + return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") + + def process_exception(exc): + if isinstance(exc, InvalidStatus) and exc.response.status_code == 418: + raise Exception("🫖 💔 ☕️") + self.fail("unexpected exception") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(Exception) as raised: + async for _ in connect( + get_uri(server), process_exception=process_exception + ): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "🫖 💔 ☕️", + ) + + async def test_redirect(self): + """Client follows redirect.""" + + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with run_server(process_request=redirect) as server: + async with connect(get_uri(server) + "/redirect") as client: + self.assertEqual(client.protocol.uri.path, "/") + + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server) + return response + + async with run_server(process_request=redirect) as server: + async with run_server() as other_server: + async with connect(get_uri(server)): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) + + @few_redirects() + async def test_redirect_limit(self): + """Client stops following redirects after limit is reached.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = request.path + return response + + async with run_server(process_request=redirect) as server: + with self.assertRaises(SecurityError) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "more than 2 redirects", + ) + + async def test_redirect_with_explicit_host_port(self): + """Client follows redirect with an explicit host / port.""" + + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with run_server(process_request=redirect) as server: + host, port = get_host_port(server.listeners) + async with connect( + "ws://overridden/redirect", host=host, port=port + ) as client: + self.assertEqual(client.protocol.uri.path, "/") + + async def test_cross_origin_redirect_with_explicit_host_port(self): + """Client doesn't follow cross-origin redirect with an explicit host / port.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "ws://other/" + return response + + async with run_server(process_request=redirect) as server: + host, port = get_host_port(server.listeners) + with self.assertRaises(ValueError) as raised: + async with connect("ws://overridden/", host=host, port=port): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow cross-origin redirect to ws://other/ " + "with an explicit host or port", + ) + + async def test_redirect_with_existing_stream(self): + """Client doesn't follow redirect when using a pre-existing stream.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with run_server(process_request=redirect) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + with self.assertRaises(ValueError) as raised: + # Use a non-existing domain to ensure we connect via sock. + async with connect("ws://invalid/redirect", stream=stream): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow redirect to ws://invalid/ with a preexisting stream", + ) + + async def test_invalid_uri(self): + """Client receives an invalid URI.""" + with self.assertRaises(InvalidURI): + async with connect("http://localhost"): # invalid scheme + self.fail("did not raise") + + async def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + async with connect("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + async def test_handshake_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server(process_response=remove_accept_header) as server: + with self.assertRaises(InvalidHandshake) as raised: + async with connect(get_uri(server) + "/no-op", close_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) + + async def test_timeout_during_handshake(self): + """Client times out before receiving handshake response from server.""" + # Replace the WebSocket server with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with self.assertRaises(TimeoutError) as raised: + async with connect(f"ws://{host}:{port}", open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during opening handshake", + ) + + async def test_connection_closed_during_handshake(self): + """Client reads EOF before receiving handshake response from server.""" + + async def close_connection(self, request): + await self.stream.aclose() + + async with run_server(process_request=close_connection) as server: + with self.assertRaises(InvalidMessage) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, EOFError) + self.assertEqual( + str(raised.exception.__cause__), + "connection closed while reading HTTP status line", + ) + + async def test_http_response(self): + """Client reads HTTP response.""" + + def http_response(connection, request): + return connection.respond(http.HTTPStatus.OK, "👌") + + async with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + async def test_http_response_without_content_length(self): + """Client reads HTTP response without a Content-Length header.""" + + def http_response(connection, request): + response = connection.respond(http.HTTPStatus.OK, "👌") + del response.headers["Content-Length"] + return response + + async with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + async def test_junk_handshake(self): + """Client closes the connection when receiving non-HTTP response from server.""" + + async def junk(stream): + # Wait for the client to send the handshake request. + await trio.testing.wait_all_tasks_blocked() + await stream.send_all(b"220 smtp.invalid ESMTP Postfix\r\n") + # Wait for the client to close the connection. + await stream.receive_some() + await stream.aclose() + + async with trio.open_nursery() as nursery: + try: + listeners = await nursery.start(trio.serve_tcp, junk, 0) + host, port = get_host_port(listeners) + with self.assertRaises(InvalidMessage) as raised: + async with connect(f"ws://{host}:{port}"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, ValueError) + self.assertEqual( + str(raised.exception.__cause__), + "unsupported protocol; expected HTTP/1.1: " + "220 smtp.invalid ESMTP Postfix", + ) + finally: + nursery.cancel_scope.cancel() + + +class SecureClientTests(IsolatedTrioTestCase): + async def test_connection(self): + """Client connects to server securely.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname_implicitly(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + host, port = get_host_port(server.listeners) + async with connect( + "wss://overridden/", host=host, port=port, ssl=CLIENT_CONTEXT + ) as client: + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_set_server_hostname_explicitly(self): + """Client sets server_hostname to the value provided in argument.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), + ssl=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_reject_invalid_server_certificate(self): + """Client rejects certificate where server certificate isn't trusted.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(trio.BrokenResourceError) as raised: + # The test certificate is self-signed. + async with connect(get_uri(server, secure=True)): + self.fail("did not raise") + self.assertIsInstance( + raised.exception.__cause__, + ssl.SSLCertVerificationError, + ) + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception.__cause__).replace("-", " "), + ) + + async def test_reject_invalid_server_hostname(self): + """Client rejects certificate where server hostname doesn't match.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(trio.BrokenResourceError) as raised: + # This hostname isn't included in the test certificate. + async with connect( + get_uri(server, secure=True), + ssl=CLIENT_CONTEXT, + server_hostname="invalid", + ): + self.fail("did not raise") + self.assertIsInstance( + raised.exception.__cause__, + ssl.SSLCertVerificationError, + ) + self.assertIn( + "certificate verify failed: Hostname mismatch", + str(raised.exception.__cause__), + ) + + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server, secure=True) + return response + + async with run_server(ssl=SERVER_CONTEXT, process_request=redirect) as server: + async with run_server(ssl=SERVER_CONTEXT) as other_server: + async with connect(get_uri(server, secure=True), ssl=CLIENT_CONTEXT): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) + + async def test_redirect_to_insecure_uri(self): + """Client doesn't follow redirect from secure URI to non-secure URI.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = insecure_uri + return response + + async with run_server(ssl=SERVER_CONTEXT, process_request=redirect) as server: + with self.assertRaises(SecurityError) as raised: + secure_uri = get_uri(server, secure=True) + insecure_uri = secure_uri.replace("wss://", "ws://") + async with connect(secure_uri, ssl=CLIENT_CONTEXT): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + f"cannot follow redirect to non-secure URI {insecure_uri}", + ) + + +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class SocksProxyClientTests(ProxyMixin, IsolatedTrioTestCase): + proxy_mode = "socks5@51080" + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_secure_socks_proxy(self): + """Client connects to server securely through a SOCKS5 proxy.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) + async def test_authenticated_socks_proxy(self): + """Client connects to server through an authenticated SOCKS5 proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_authenticated_socks_proxy_error(self): + """Client fails to authenticate to the SOCKS5 proxy.""" + from python_socks import ProxyError as SocksProxyError + + try: + self.proxy_options.update(proxyauth="any") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "failed to connect to SOCKS proxy", + ) + self.assertIsInstance(raised.exception.__cause__, SocksProxyError) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port + async def test_socks_proxy_connection_failure(self): + """Client fails to connect to the SOCKS5 proxy.""" + from python_socks import ProxyConnectionError as SocksProxyConnectionError + + with self.assertRaises(OSError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertIsInstance(raised.exception, SocksProxyConnectionError) + self.assertNumFlows(0) + + async def test_socks_proxy_connection_timeout(self): + """Client times out while connecting to the SOCKS5 proxy.""" + # Replace the proxy with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during opening handshake", + ) + self.assertNumFlows(0) + + async def test_explicit_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy set explicitly.""" + async with run_server() as server: + async with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:51080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_ignore_proxy_with_existing_stream(self): + """Cli ent connects using a pre-existing stream.""" + async with run_server() as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + # Use a non-existing domain to ensure we connect via stream. + async with connect("ws://invalid/", stream=stream) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(0) + + +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class HTTPProxyClientTests(ProxyMixin, IsolatedTrioTestCase): + proxy_mode = "regular@58080" + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy(self): + """Client connects to server through an HTTP proxy.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_secure_http_proxy(self): + """Client connects to server securely through an HTTP proxy.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) + async def test_authenticated_http_proxy(self): + """Client connects to server through an authenticated HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_authenticated_http_proxy_error(self): + """Client fails to authenticate to the HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="any") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "proxy rejected connection: HTTP 407", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_protocol_error(self): + """Client receives invalid data when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(break_http_connect=True) + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(break_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_connection_error(self): + """Client receives no response when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(close_http_connect=True) + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(close_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:48080"}) # bad port + async def test_http_proxy_connection_failure(self): + """Client fails to connect to the HTTP proxy.""" + with self.assertRaises(OSError): + async with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertNumFlows(0) + + async def test_http_proxy_connection_timeout(self): + """Client times out while connecting to the HTTP proxy.""" + # Replace the proxy with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during opening handshake", + ) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy(self): + """Client connects to server through an HTTPS proxy.""" + async with run_server() as server: + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_secure_https_proxy(self): + """Client connects to server securely through an HTTPS proxy.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_server_hostname(self): + """Client sets server_hostname to the value of proxy_server_hostname.""" + async with run_server() as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.server_hostname, "overridden") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy_invalid_proxy_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with self.assertRaises(trio.BrokenResourceError) as raised: + # The proxy certificate isn't trusted. + async with connect("wss://example.com/"): + self.fail("did not raise") + self.assertIsInstance(raised.exception.__cause__, ssl.SSLCertVerificationError) + self.assertIn( + "certificate verify failed: unable to get local issuer certificate", + str(raised.exception.__cause__), + ) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy_invalid_server_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(trio.BrokenResourceError) as raised: + # The test certificate is self-signed. + async with connect( + get_uri(server, secure=True), proxy_ssl=self.proxy_context + ): + self.fail("did not raise") + self.assertIsInstance(raised.exception.__cause__, ssl.SSLCertVerificationError) + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception.__cause__).replace("-", " "), + ) + self.assertNumFlows(1) + + +class ClientUsageErrorsTests(IsolatedTrioTestCase): + async def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" + with self.assertRaises(ValueError) as raised: + async with connect("ws://localhost/", ssl=CLIENT_CONTEXT): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) + + async def test_proxy_ssl_without_https_proxy(self): + """Client rejects proxy_ssl when proxy isn't HTTPS.""" + with self.assertRaises(ValueError) as raised: + async with connect( + "ws://localhost/", + proxy="http://localhost:8080", + proxy_ssl=True, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "proxy_ssl argument is incompatible with an http:// proxy", + ) + + async def test_unsupported_proxy(self): + """Client rejects unsupported proxy.""" + with self.assertRaises(InvalidProxy) as raised: + async with connect("ws://example.com/", proxy="other://localhost:51080"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", + ) + + async def test_invalid_subprotocol(self): + """Client rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + async with connect("ws://localhost/", subprotocols="chat"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Client rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + async with connect("ws://localhost/", compression=False): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) + + async def test_reentrancy(self): + """Client isn't reentrant.""" + async with run_server() as server: + connecter = connect(get_uri(server)) + async with connecter: + with self.assertRaises(RuntimeError) as raised: + async with connecter: + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connect() isn't reentrant", + ) diff --git a/tests/trio/test_server.py b/tests/trio/test_server.py new file mode 100644 index 000000000..12dcafc7d --- /dev/null +++ b/tests/trio/test_server.py @@ -0,0 +1,831 @@ +import dataclasses +import hmac +import http +import logging + +import trio + +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidStatus, + NegotiationError, +) +from websockets.http11 import Request, Response +from websockets.trio.client import connect +from websockets.trio.server import * + +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, +) +from .server import ( + EvalShellMixin, + get_host_port, + get_uri, + handler, + run_server, +) +from .utils import IsolatedTrioTestCase + + +class ServerTests(EvalShellMixin, IsolatedTrioTestCase): + async def test_connection(self): + """Server receives connection from client and the handshake succeeds.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_connection_handler_returns(self): + """Connection handler returns.""" + async with run_server() as server: + async with connect(get_uri(server) + "/no-op") as client: + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1000 (OK); then sent 1000 (OK)", + ) + + async def test_connection_handler_raises_exception(self): + """Connection handler raises an exception.""" + async with run_server() as server: + async with connect(get_uri(server) + "/crash") as client: + with self.assertRaises(ConnectionClosedError) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1011 (internal error); then sent 1011 (internal error)", + ) + + async def test_existing_listeners(self): + """Server receives connection using pre-existing listeners.""" + listeners = await trio.open_tcp_listeners(0, host="localhost") + host, port = get_host_port(listeners) + async with run_server(port=None, host=None, listeners=listeners): + async with connect(f"ws://{host}:{port}/") as client: # type: ignore + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_select_subprotocol(self): + """Server selects a subprotocol with the select_subprotocol callable.""" + + def select_subprotocol(ws, subprotocols): + ws.select_subprotocol_ran = True + assert "chat" in subprotocols + return "chat" + + async with run_server( + subprotocols=["chat"], select_subprotocol=select_subprotocol + ) as server: + async with connect(get_uri(server), subprotocols=["chat"]) as client: + await self.assertEval(client, "ws.select_subprotocol_ran", "True") + await self.assertEval(client, "ws.subprotocol", "chat") + + async def test_select_subprotocol_rejects_handshake(self): + """Server rejects handshake if select_subprotocol raises NegotiationError.""" + + def select_subprotocol(ws, subprotocols): + raise NegotiationError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_select_subprotocol_raises_exception(self): + """Server returns an error if select_subprotocol raises an exception.""" + + def select_subprotocol(ws, subprotocols): + raise RuntimeError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_compression_is_enabled(self): + """Server enables compression by default.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + async def test_disable_compression(self): + """Server disables compression.""" + async with run_server(compression=None) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.protocol.extensions", "[]") + + async def test_process_request_returns_none(self): + """Server runs process_request and continues the handshake.""" + + def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_async_process_request_returns_none(self): + """Server runs async process_request and continues the handshake.""" + + async def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_returns_response(self): + """Server aborts handshake if process_request returns a response.""" + + def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async def handler(ws): + self.fail("handler must not run") + + with self.assertNoLogs("websockets", logging.ERROR): + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_async_process_request_returns_response(self): + """Server aborts handshake if async process_request returns a response.""" + + async def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async def handler(ws): + self.fail("handler must not run") + + with self.assertNoLogs("websockets", logging.ERROR): + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_process_request_raises_exception(self): + """Server returns an error if process_request raises an exception.""" + + def process_request(ws, request): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + async def test_async_process_request_raises_exception(self): + """Server returns an error if async process_request raises an exception.""" + + async def process_request(ws, request): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + async def test_process_response_returns_none(self): + """Server runs process_response but keeps the handshake response.""" + + def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_async_process_response_returns_none(self): + """Server runs async process_response but keeps the handshake response.""" + + async def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_process_response_modifies_response(self): + """Server runs process_response and modifies the handshake response.""" + + def process_response(ws, request, response): + response.headers["X-ProcessResponse"] = "OK" + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_async_process_response_modifies_response(self): + """Server runs async process_response and modifies the handshake response.""" + + async def process_response(ws, request, response): + response.headers["X-ProcessResponse"] = "OK" + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_process_response_replaces_response(self): + """Server runs process_response and replaces the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_async_process_response_replaces_response(self): + """Server runs async process_response and replaces the handshake response.""" + + async def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_process_response_raises_exception(self): + """Server returns an error if process_response raises an exception.""" + + def process_response(ws, request, response): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + async def test_async_process_response_raises_exception(self): + """Server returns an error if async process_response raises an exception.""" + + async def process_response(ws, request, response): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + async def test_override_server(self): + """Server can override Server header with server_header.""" + async with run_server(server_header="Neo") as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.response.headers['Server']", "Neo") + + async def test_remove_server(self): + """Server can remove Server header with server_header.""" + async with run_server(server_header=None) as server: + async with connect(get_uri(server)) as client: + await self.assertEval( + client, "'Server' in ws.response.headers", "False" + ) + + async def test_keepalive_is_enabled(self): + """Server enables keepalive and measures latency.""" + async with run_server(ping_interval=MS) as server: + async with connect(get_uri(server)) as client: + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + await trio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertGreater(latency, 0) + + async def test_disable_keepalive(self): + """Server disables keepalive.""" + async with run_server(ping_interval=None) as server: + async with connect(get_uri(server)) as client: + await trio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + + async def test_logger(self): + """Server accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server(logger=logger) as server: + self.assertEqual(server.logger.name, logger.name) + + async def test_custom_connection_factory(self): + """Server runs ServerConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + server = ServerConnection(*args, **kwargs) + server.create_connection_ran = True + return server + + async with run_server(create_connection=create_connection) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.create_connection_ran", "True") + + async def test_connections(self): + """Server provides a connections property.""" + async with run_server() as server: + self.assertEqual(server.connections, set()) + async with connect(get_uri(server)) as client: + self.assertEqual(len(server.connections), 1) + ws_id = str(next(iter(server.connections)).id) + await self.assertEval(client, "ws.id", ws_id) + self.assertEqual(server.connections, set()) + + async def test_handshake_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + async with run_server(process_request=remove_key_header) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_timeout_during_handshake(self): + """Server times out before receiving handshake request from client.""" + async with run_server(open_timeout=MS) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + try: + self.assertEqual(await stream.receive_some(4096), b"") + finally: + await stream.aclose() + + async def test_connection_closed_during_handshake(self): + """Server reads EOF before receiving handshake request from client.""" + async with run_server() as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + await stream.aclose() + + async def test_junk_handshake(self): + """Server closes the connection when receiving non-HTTP request from client.""" + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server() as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + await stream.send_all(b"HELO relay.invalid\r\n") + try: + # Wait for the server to close the connection. + self.assertEqual(await stream.receive_some(4096), b"") + finally: + await stream.aclose() + + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["did not receive a valid HTTP request"], + ) + self.assertEqual( + [str(record.exc_info[1].__cause__) for record in logs.records], + ["invalid HTTP request line: HELO relay.invalid"], + ) + + async def test_close_server_rejects_connecting_connections(self): + """Server rejects connecting connections with HTTP 503 when closing.""" + + async def process_request(ws, _request): + while not ws.server.closing: + await trio.sleep(0) # pragma: no cover + + async with run_server(process_request=process_request) as server: + + async def close_server(server): + await trio.sleep(MS) + await server.aclose() + + self.nursery.start_soon(close_server, server) + + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 503", + ) + + async def test_close_server_closes_open_connections(self): + """Server closes open connections with close code 1001 when closing.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await server.aclose() + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1001 (going away); then sent 1001 (going away)", + ) + + async def test_close_server_closes_open_connections_with_code_and_reason(self): + """Server closes open connections with custom code and reason when closing.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await server.aclose(code=1012, reason="restarting") + with self.assertRaises(ConnectionClosedError) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1012 (service restart) restarting; " + "then sent 1012 (service restart) restarting", + ) + + async def test_close_server_keeps_connections_open(self): + """Server waits for client to close open connections when closing.""" + + async with run_server() as server: + server_closed = trio.Event() + + async def close_server(): + await server.aclose(close_connections=False) + server_closed.set() + + async with connect(get_uri(server)) as client: + self.nursery.start_soon(close_server) + + # Server cannot receive new connections. + with self.assertRaises(OSError): + async with connect(get_uri(server)): + self.fail("did not raise") + + # The server waits for the client to close the connection. + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(MS): + await server_closed.wait() + + # Once the client closes the connection, the server terminates. + await client.aclose() + with trio.fail_after(MS): + await server_closed.wait() + + async def test_close_server_keeps_handlers_running(self): + """Server waits for connection handlers to terminate.""" + async with run_server() as server: + server_closed = trio.Event() + + async def close_server(): + await server.aclose(close_connections=False) + server_closed.set() + + async with connect(get_uri(server) + "/delay") as client: + # Delay termination of connection handler. + await client.send(str(3 * MS)) + + self.nursery.start_soon(close_server) + + # The server waits for the connection handler to terminate. + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(2 * MS): + await server_closed.wait() + + # Set a large timeout here, else the test becomes flaky. + with trio.fail_after(5 * MS): + await server_closed.wait() + + +SSL_OBJECT = "ws.stream._ssl_object" + + +class SecureServerTests(EvalShellMixin, IsolatedTrioTestCase): + async def test_connection(self): + """Server receives secure connection from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, f"{SSL_OBJECT}.version()[:3]", "TLS") + + async def test_timeout_during_tls_handshake(self): + """Server times out before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + try: + self.assertEqual(await stream.receive_some(4096), b"") + finally: + await stream.aclose() + + async def test_connection_closed_during_tls_handshake(self): + """Server reads EOF before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + await stream.aclose() + + +class ServerUsageErrorsTests(IsolatedTrioTestCase): + async def test_missing_port(self): + """Server requires port.""" + with self.assertRaises(ValueError) as raised: + await serve(handler, None) + self.assertEqual( + str(raised.exception), + "port is required when listeners is not provided", + ) + + async def test_port_and_listeners(self): + """Server rejects port when listeners is provided.""" + listeners = await trio.open_tcp_listeners(0) + try: + with self.assertRaises(ValueError) as raised: + await serve(handler, port=0, listeners=listeners) + self.assertEqual( + str(raised.exception), + "port is incompatible with listeners", + ) + finally: + for listener in listeners: + await listener.aclose() + + async def test_host_and_listeners(self): + """Server rejects host when listeners is provided.""" + listeners = await trio.open_tcp_listeners(0) + try: + with self.assertRaises(ValueError) as raised: + await serve(handler, host="localhost", listeners=listeners) + self.assertEqual( + str(raised.exception), + "host is incompatible with listeners", + ) + finally: + for listener in listeners: + await listener.aclose() + + async def test_backlog_and_listeners(self): + """Server rejects backlog when listeners is provided.""" + listeners = await trio.open_tcp_listeners(0) + try: + with self.assertRaises(ValueError) as raised: + await serve(handler, backlog=65535, listeners=listeners) + self.assertEqual( + str(raised.exception), + "backlog is incompatible with listeners", + ) + finally: + for listener in listeners: + await listener.aclose() + + async def test_invalid_subprotocol(self): + """Server rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await serve(handler, subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Server rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await serve(handler, compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) + + +class BasicAuthTests(EvalShellMixin, IsolatedTrioTestCase): + async def test_valid_authorization(self): + """basic_auth authenticates client with HTTP Basic Authentication.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_missing_authorization(self): + """basic_auth rejects client without credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_unsupported_authorization(self): + """basic_auth rejects client with unsupported credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Negotiate ..."}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_unknown_username(self): + """basic_auth rejects client with unknown username.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_incorrect_password(self): + """basic_auth rejects client with incorrect password.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "changeme")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_list_of_credentials(self): + """basic_auth accepts a list of hard coded credentials.""" + async with run_server( + process_request=basic_auth( + credentials=[ + ("hello", "iloveyou"), + ("bye", "youloveme"), + ] + ), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ) as client: + await self.assertEval(client, "ws.username", "bye") + + async def test_check_credentials_function(self): + """basic_auth accepts a check_credentials function.""" + + def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_check_credentials_coroutine(self): + """basic_auth accepts a check_credentials coroutine.""" + + async def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_without_credentials_or_check_credentials(self): + """basic_auth requires either credentials or check_credentials.""" + with self.assertRaises(ValueError) as raised: + basic_auth() + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_with_credentials_and_check_credentials(self): + """basic_auth requires only one of credentials and check_credentials.""" + with self.assertRaises(ValueError) as raised: + basic_auth( + credentials=("hello", "iloveyou"), + check_credentials=lambda: False, # pragma: no cover + ) + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_bad_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=42) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: 42", + ) + + async def test_bad_list_of_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=[42]) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: [42]", + ) From 9a52ba2c8d5f14eb4063c415d0561f9a04138ef2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Aug 2025 15:52:14 +0200 Subject: [PATCH 2/3] Avoid a warning when building the docs. --- docs/conf.py | 2 ++ src/websockets/trio/client.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/docs/conf.py b/docs/conf.py index 0b1f64edc..3dcc04169 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -50,6 +50,8 @@ ("py:meth", "protocol.WebSocketCommonProtocol.connection_lost"), ("py:meth", "protocol.WebSocketCommonProtocol.read_message"), ("py:meth", "protocol.WebSocketCommonProtocol.write_frame"), + # Caused by https://github.com/sphinx-doc/sphinx/issues/13838 + ("py:class", "ssl_module.SSLContext"), ] # Add any Sphinx extension module names here, as strings. They can be diff --git a/src/websockets/trio/client.py b/src/websockets/trio/client.py index fa4b36b73..4af78af35 100644 --- a/src/websockets/trio/client.py +++ b/src/websockets/trio/client.py @@ -233,6 +233,9 @@ class connect: """ + # Arguments of type SSLContext don't render correctly in the documentation + # because of https://github.com/sphinx-doc/sphinx/issues/13838. + def __init__( self, uri: str, From f67346d9ae55c4331d0197e585bd570ba716c0e7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Oct 2025 08:00:37 +0200 Subject: [PATCH 3/3] Add request for feedback on the trio API. --- docs/reference/trio/client.rst | 6 ++++++ docs/reference/trio/server.rst | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/docs/reference/trio/client.rst b/docs/reference/trio/client.rst index cf5643c55..0769b01d0 100644 --- a/docs/reference/trio/client.rst +++ b/docs/reference/trio/client.rst @@ -1,6 +1,12 @@ Client (:mod:`trio`) ======================= +.. admonition:: The :mod:`trio` API is experimental. + :class: caution + + Please provide feedback in GitHub issues about the API, especially if you + believe there's a more intuitive or convenient way to connect to a server. + .. automodule:: websockets.trio.client Opening a connection diff --git a/docs/reference/trio/server.rst b/docs/reference/trio/server.rst index e3d92ed45..17feb899d 100644 --- a/docs/reference/trio/server.rst +++ b/docs/reference/trio/server.rst @@ -1,6 +1,12 @@ Server (:mod:`trio`) ======================= +.. admonition:: The :mod:`trio` API is experimental. + :class: caution + + Please provide feedback in GitHub issues about the API, especially if you + believe there's a more intuitive or convenient way to run a server. + .. automodule:: websockets.trio.server Creating a server