From db7fcdaaf7adaf76585f5350f3a52db6960b8397 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 15:57:37 +0900 Subject: [PATCH 01/15] feat: pass dumps_default, ext_hook --- src/socketio/msgpack_packet.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/socketio/msgpack_packet.py b/src/socketio/msgpack_packet.py index 27462634..6e86a72a 100644 --- a/src/socketio/msgpack_packet.py +++ b/src/socketio/msgpack_packet.py @@ -5,13 +5,33 @@ class MsgPackPacket(packet.Packet): uses_binary_events = False + def __init__( + self, + packet_type=packet.EVENT, + data=None, + namespace=None, + id=None, + binary=None, + encoded_packet=None, + dumps_default=None, + ext_hook=None, + ): + super().__init__( + packet_type, data, namespace, id, binary, encoded_packet + ) + self.dumps_default = dumps_default + self.ext_hook = ext_hook + def encode(self): """Encode the packet for transmission.""" - return msgpack.dumps(self._to_dict()) + return msgpack.dumps(self._to_dict(), default=self.dumps_default) def decode(self, encoded_packet): """Decode a transmitted package.""" - decoded = msgpack.loads(encoded_packet) + if self.ext_hook is None: + decoded = msgpack.loads(encoded_packet) + else: + decoded = msgpack.loads(encoded_packet, ext_hook=self.ext_hook) self.packet_type = decoded['type'] self.data = decoded.get('data') self.id = decoded.get('id') From abb2c8c4bf267d0f0aea438ed9390f6475096392 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 16:20:49 +0900 Subject: [PATCH 02/15] test: msgpack packet tests --- src/socketio/msgpack_packet.py | 4 +- tests/common/test_msgpack_packet.py | 109 +++++++++++++++++++++++++++- 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/src/socketio/msgpack_packet.py b/src/socketio/msgpack_packet.py index 6e86a72a..846db2f3 100644 --- a/src/socketio/msgpack_packet.py +++ b/src/socketio/msgpack_packet.py @@ -16,11 +16,11 @@ def __init__( dumps_default=None, ext_hook=None, ): + self.dumps_default = dumps_default + self.ext_hook = ext_hook super().__init__( packet_type, data, namespace, id, binary, encoded_packet ) - self.dumps_default = dumps_default - self.ext_hook = ext_hook def encode(self): """Encode the packet for transmission.""" diff --git a/tests/common/test_msgpack_packet.py b/tests/common/test_msgpack_packet.py index e0197a27..1079e018 100644 --- a/tests/common/test_msgpack_packet.py +++ b/tests/common/test_msgpack_packet.py @@ -1,3 +1,8 @@ +from datetime import datetime, timedelta, timezone + +import pytest +import msgpack + from socketio import msgpack_packet from socketio import packet @@ -5,7 +10,8 @@ class TestMsgPackPacket: def test_encode_decode(self): p = msgpack_packet.MsgPackPacket( - packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo') + packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo' + ) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.data == p2.data @@ -14,7 +20,8 @@ def test_encode_decode(self): def test_encode_decode_with_id(self): p = msgpack_packet.MsgPackPacket( - packet.EVENT, data=['ev', 42], id=123, namespace='/foo') + packet.EVENT, data=['ev', 42], id=123, namespace='/foo' + ) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.data == p2.data @@ -32,3 +39,101 @@ def test_encode_binary_ack_packet(self): assert p.packet_type == packet.ACK p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p2.data == {'foo': b'bar'} + + def test_encode_with_dumps_default(self): + def default(obj): + if isinstance(obj, datetime): + return obj.isoformat() + raise TypeError('Unknown type') + + data = { + 'current': datetime.now(tz=timezone(timedelta(0))), + 'key': 'value', + } + p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) + p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.namespace == p2.namespace + assert p.data != p2.data + + assert isinstance(p2.data, dict) + assert 'current' in p2.data + assert isinstance(p2.data['current'], str) + assert default(data['current']) == p2.data['current'] + + data.pop('current') + p2_data_without_current = p2.data.copy() + p2_data_without_current.pop('current') + assert data == p2_data_without_current + + def test_encode_without_dumps_default(self): + data = { + 'current': datetime.now(tz=timezone(timedelta(0))), + 'key': 'value', + } + p_without_default = msgpack_packet.MsgPackPacket(data=data) + with pytest.raises( + TypeError, match="can not serialize 'datetime.datetime' object" + ): + p_without_default.encode() + + def test_encode_decode_with_ext_hook(self): + class Custom: + def __init__(self, value): + self.value = value + + def __eq__(self, value: object) -> bool: + return isinstance(value, Custom) and self.value == value.value + + def default(obj): + if isinstance(obj, Custom): + return msgpack.ExtType(1, obj.value) + raise TypeError('Unknown type') + + def ext_hook(code, data): + if code == 1: + return Custom(data) + raise TypeError('Unknown ext type') + + data = {'custom': Custom(b'custom_data'), 'key': 'value'} + p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) + p2 = msgpack_packet.MsgPackPacket( + encoded_packet=p.encode(), ext_hook=ext_hook + ) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.data == p2.data + assert p.namespace == p2.namespace + + def test_encode_decode_without_ext_hook(self): + class Custom: + def __init__(self, value): + self.value = value + + def __eq__(self, value: object) -> bool: + return isinstance(value, Custom) and self.value == value.value + + def default(obj): + if isinstance(obj, Custom): + return msgpack.ExtType(1, obj.value) + raise TypeError('Unknown type') + + data = {'custom': Custom(b'custom_data'), 'key': 'value'} + p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) + p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.namespace == p2.namespace + assert p.data != p2.data + + assert isinstance(p2.data, dict) + assert 'custom' in p2.data + assert isinstance(p2.data['custom'], msgpack.ExtType) + assert p2.data['custom'].code == 1 + assert p2.data['custom'].data == b'custom_data' + + data.pop('custom') + p2_data_without_custom = p2.data.copy() + p2_data_without_custom.pop('custom') + assert data == p2_data_without_custom From 5d9e3a7e6eda266496783b27684045a8f7a98c12 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 16:32:26 +0900 Subject: [PATCH 03/15] fix: pypy tests --- tests/common/test_msgpack_packet.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/common/test_msgpack_packet.py b/tests/common/test_msgpack_packet.py index 1079e018..8a3befd5 100644 --- a/tests/common/test_msgpack_packet.py +++ b/tests/common/test_msgpack_packet.py @@ -73,9 +73,7 @@ def test_encode_without_dumps_default(self): 'key': 'value', } p_without_default = msgpack_packet.MsgPackPacket(data=data) - with pytest.raises( - TypeError, match="can not serialize 'datetime.datetime' object" - ): + with pytest.raises(TypeError): p_without_default.encode() def test_encode_decode_with_ext_hook(self): From c86a3fdc1ec23ffd05b693f3a7613dd961af562d Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:04:03 +0900 Subject: [PATCH 04/15] feat: serializer_args, _create_packet --- src/socketio/base_client.py | 7 ++++++- src/socketio/base_server.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/socketio/base_client.py b/src/socketio/base_client.py index 0232dca7..2bcaafcf 100644 --- a/src/socketio/base_client.py +++ b/src/socketio/base_client.py @@ -38,7 +38,8 @@ class BaseClient: def __init__(self, reconnection=True, reconnection_attempts=0, reconnection_delay=1, reconnection_delay_max=5, randomization_factor=0.5, logger=False, serializer='default', - json=None, handle_sigint=True, **kwargs): + json=None, handle_sigint=True, serializer_args=None, + **kwargs): global original_signal_handler if handle_sigint and original_signal_handler is None and \ threading.current_thread() == threading.main_thread(): @@ -63,6 +64,7 @@ def __init__(self, reconnection=True, reconnection_attempts=0, self.packet_class = msgpack_packet.MsgPackPacket else: self.packet_class = serializer + self.packet_class_args = serializer_args or {} if json is not None: self.packet_class.json = json engineio_options['json'] = json @@ -283,6 +285,9 @@ def _generate_ack_id(self, namespace, callback): self.callbacks[namespace][id] = callback return id + def _create_packet(self, *args, **kwargs): + return self.packet_class(*args, **kwargs, **self.packet_class_args) + def _handle_eio_connect(self): # pragma: no cover raise NotImplementedError() diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index d134eba1..873e969f 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -15,7 +15,7 @@ class BaseServer: def __init__(self, client_manager=None, logger=False, serializer='default', json=None, async_handlers=True, always_connect=False, - namespaces=None, **kwargs): + namespaces=None, serializer_args=None, **kwargs): engineio_options = kwargs engineio_logger = engineio_options.pop('engineio_logger', None) if engineio_logger is not None: @@ -27,6 +27,7 @@ def __init__(self, client_manager=None, logger=False, serializer='default', self.packet_class = msgpack_packet.MsgPackPacket else: self.packet_class = serializer + self.packet_class_args = serializer_args or {} if json is not None: self.packet_class.json = json engineio_options['json'] = json @@ -252,6 +253,10 @@ def _get_namespace_handler(self, namespace, args): handler = self.namespace_handlers['*'] args = (namespace, *args) return handler, args + + def _create_packet(self, *args, **kwargs): + return self.packet_class(*args, **kwargs, + **self.packet_class_args) def _handle_eio_connect(self): # pragma: no cover raise NotImplementedError() From 005953da1d11ff4f81e59d54466fa389f0e8924d Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:07:49 +0900 Subject: [PATCH 05/15] fix: apply _create_packet --- src/socketio/async_client.py | 10 +++++----- src/socketio/async_server.py | 16 ++++++++-------- src/socketio/client.py | 10 +++++----- src/socketio/server.py | 16 ++++++++-------- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index 678743a2..c1ec14f0 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -243,7 +243,7 @@ async def emit(self, event, data=None, namespace=None, callback=None): data = [data] else: data = [] - await self._send_packet(self.packet_class( + await self._send_packet(self._create_packet( packet.EVENT, namespace=namespace, data=[event] + data, id=id)) async def send(self, data, namespace=None, callback=None): @@ -325,7 +325,7 @@ async def disconnect(self): # here we just request the disconnection # later in _handle_eio_disconnect we invoke the disconnect handler for n in self.namespaces: - await self._send_packet(self.packet_class(packet.DISCONNECT, + await self._send_packet(self._create_packet(packet.DISCONNECT, namespace=n)) await self.eio.disconnect() @@ -422,7 +422,7 @@ async def _handle_event(self, namespace, id, data): data = list(r) else: data = [r] - await self._send_packet(self.packet_class( + await self._send_packet(self._create_packet( packet.ACK, namespace=namespace, id=id, data=data)) async def _handle_ack(self, namespace, id, data): @@ -555,7 +555,7 @@ async def _handle_eio_connect(self): self.sid = self.eio.sid real_auth = await self._get_real_value(self.connection_auth) or {} for n in self.connection_namespaces: - await self._send_packet(self.packet_class( + await self._send_packet(self._create_packet( packet.CONNECT, data=real_auth, namespace=n)) async def _handle_eio_message(self, data): @@ -569,7 +569,7 @@ async def _handle_eio_message(self, data): else: await self._handle_ack(pkt.namespace, pkt.id, pkt.data) else: - pkt = self.packet_class(encoded_packet=data) + pkt = self._create_packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: await self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index 6c9e3ca3..5b896bf5 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -425,7 +425,7 @@ async def disconnect(self, sid, namespace=None, ignore_queue=False): if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.DISCONNECT, namespace=namespace)) await self._trigger_event('disconnect', namespace, sid, self.reason.SERVER_DISCONNECT) @@ -538,13 +538,13 @@ async def _handle_connect(self, eio_sid, namespace, data): or self.namespaces == '*' or namespace in self.namespaces: sid = await self.manager.connect(eio_sid, namespace) if sid is None: - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.CONNECT_ERROR, data='Unable to connect', namespace=namespace)) return if self.always_connect: - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: @@ -568,15 +568,15 @@ async def _handle_connect(self, eio_sid, namespace, data): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.DISCONNECT, data=fail_reason, namespace=namespace)) else: - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) await self.manager.disconnect(sid, namespace, ignore_queue=True) elif not self.always_connect: - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.CONNECT, {'sid': sid}, namespace=namespace)) async def _handle_disconnect(self, eio_sid, namespace, reason=None): @@ -622,7 +622,7 @@ async def _handle_event_internal(self, server, sid, eio_sid, data, data = list(r) else: data = [r] - await server._send_packet(eio_sid, self.packet_class( + await server._send_packet(eio_sid, self._create_packet( packet.ACK, namespace=namespace, id=id, data=data)) async def _handle_ack(self, eio_sid, namespace, id, data): @@ -686,7 +686,7 @@ async def _handle_eio_message(self, eio_sid, data): await self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - pkt = self.packet_class(encoded_packet=data) + pkt = self._create_packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: await self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/client.py b/src/socketio/client.py index 5282e0a1..296e4dc4 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -234,7 +234,7 @@ def emit(self, event, data=None, namespace=None, callback=None): data = [data] else: data = [] - self._send_packet(self.packet_class(packet.EVENT, namespace=namespace, + self._send_packet(self._create_packet(packet.EVENT, namespace=namespace, data=[event] + data, id=id)) def send(self, data, namespace=None, callback=None): @@ -307,7 +307,7 @@ def disconnect(self): # here we just request the disconnection # later in _handle_eio_disconnect we invoke the disconnect handler for n in self.namespaces: - self._send_packet(self.packet_class( + self._send_packet(self._create_packet( packet.DISCONNECT, namespace=n)) self.eio.disconnect() @@ -402,7 +402,7 @@ def _handle_event(self, namespace, id, data): data = list(r) else: data = [r] - self._send_packet(self.packet_class( + self._send_packet(self._create_packet( packet.ACK, namespace=namespace, id=id, data=data)) def _handle_ack(self, namespace, id, data): @@ -506,7 +506,7 @@ def _handle_eio_connect(self): self.sid = self.eio.sid real_auth = self._get_real_value(self.connection_auth) or {} for n in self.connection_namespaces: - self._send_packet(self.packet_class( + self._send_packet(self._create_packet( packet.CONNECT, data=real_auth, namespace=n)) def _handle_eio_message(self, data): @@ -520,7 +520,7 @@ def _handle_eio_message(self, data): else: self._handle_ack(pkt.namespace, pkt.id, pkt.data) else: - pkt = self.packet_class(encoded_packet=data) + pkt = self._create_packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/server.py b/src/socketio/server.py index f3257081..21d6afeb 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -401,7 +401,7 @@ def disconnect(self, sid, namespace=None, ignore_queue=False): if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.DISCONNECT, namespace=namespace)) self._trigger_event('disconnect', namespace, sid, self.reason.SERVER_DISCONNECT) @@ -520,13 +520,13 @@ def _handle_connect(self, eio_sid, namespace, data): or self.namespaces == '*' or namespace in self.namespaces: sid = self.manager.connect(eio_sid, namespace) if sid is None: - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.CONNECT_ERROR, data='Unable to connect', namespace=namespace)) return if self.always_connect: - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: @@ -550,15 +550,15 @@ def _handle_connect(self, eio_sid, namespace, data): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.DISCONNECT, data=fail_reason, namespace=namespace)) else: - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) self.manager.disconnect(sid, namespace, ignore_queue=True) elif not self.always_connect: - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.CONNECT, {'sid': sid}, namespace=namespace)) def _handle_disconnect(self, eio_sid, namespace, reason=None): @@ -601,7 +601,7 @@ def _handle_event_internal(self, server, sid, eio_sid, data, namespace, data = list(r) else: data = [r] - server._send_packet(eio_sid, self.packet_class( + server._send_packet(eio_sid, self._create_packet( packet.ACK, namespace=namespace, id=id, data=data)) def _handle_ack(self, eio_sid, namespace, id, data): @@ -650,7 +650,7 @@ def _handle_eio_message(self, eio_sid, data): else: self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - pkt = self.packet_class(encoded_packet=data) + pkt = self._create_packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: From 81e96142f9d92f1203985a07d3a5a12b63609972 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:10:48 +0900 Subject: [PATCH 06/15] docs: add serializer_args --- src/socketio/async_client.py | 3 +++ src/socketio/async_server.py | 3 +++ src/socketio/client.py | 3 +++ src/socketio/server.py | 3 +++ 4 files changed, 12 insertions(+) diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index c1ec14f0..fc7ce3fa 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -45,6 +45,9 @@ class AsyncClient(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. + :param serializer_args: A mapping of additional parameters to pass to + the serializer. The content of this dictionary + depends on the selected serialization method. The Engine.IO configuration supports the following settings: diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index 5b896bf5..9bdfca4f 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -50,6 +50,9 @@ class AsyncServer(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. + :param serializer_args: A mapping of additional parameters to pass to + the serializer. The content of this dictionary + depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: diff --git a/src/socketio/client.py b/src/socketio/client.py index 296e4dc4..7be92ccb 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -48,6 +48,9 @@ class Client(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. + :param serializer_args: A mapping of additional parameters to pass to + the serializer. The content of this dictionary + depends on the selected serialization method. The Engine.IO configuration supports the following settings: diff --git a/src/socketio/server.py b/src/socketio/server.py index 21d6afeb..1658fa5a 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -53,6 +53,9 @@ class Server(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. + :param serializer_args: A mapping of additional parameters to pass to + the serializer. The content of this dictionary + depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: From 5bb11b6964e99b4874fa41dff4c0032435987f8b Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:12:04 +0900 Subject: [PATCH 07/15] fix: lint error --- src/socketio/async_client.py | 2 +- src/socketio/async_server.py | 2 +- src/socketio/base_server.py | 2 +- src/socketio/client.py | 7 ++++--- src/socketio/server.py | 2 +- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index fc7ce3fa..c19c8459 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -45,7 +45,7 @@ class AsyncClient(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. - :param serializer_args: A mapping of additional parameters to pass to + :param serializer_args: A mapping of additional parameters to pass to the serializer. The content of this dictionary depends on the selected serialization method. diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index 9bdfca4f..fa22393e 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -50,7 +50,7 @@ class AsyncServer(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. - :param serializer_args: A mapping of additional parameters to pass to + :param serializer_args: A mapping of additional parameters to pass to the serializer. The content of this dictionary depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index 873e969f..488ffe1d 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -253,7 +253,7 @@ def _get_namespace_handler(self, namespace, args): handler = self.namespace_handlers['*'] args = (namespace, *args) return handler, args - + def _create_packet(self, *args, **kwargs): return self.packet_class(*args, **kwargs, **self.packet_class_args) diff --git a/src/socketio/client.py b/src/socketio/client.py index 7be92ccb..29c1f25c 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -48,7 +48,7 @@ class Client(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. - :param serializer_args: A mapping of additional parameters to pass to + :param serializer_args: A mapping of additional parameters to pass to the serializer. The content of this dictionary depends on the selected serialization method. @@ -237,8 +237,9 @@ def emit(self, event, data=None, namespace=None, callback=None): data = [data] else: data = [] - self._send_packet(self._create_packet(packet.EVENT, namespace=namespace, - data=[event] + data, id=id)) + self._send_packet( + self._create_packet(packet.EVENT, namespace=namespace, + data=[event] + data, id=id)) def send(self, data, namespace=None, callback=None): """Send a message to the server. diff --git a/src/socketio/server.py b/src/socketio/server.py index 1658fa5a..7312506c 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -53,7 +53,7 @@ class Server(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. - :param serializer_args: A mapping of additional parameters to pass to + :param serializer_args: A mapping of additional parameters to pass to the serializer. The content of this dictionary depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. From 6089dd0b9e74e9d7fd00002a0934c7a9216ed369 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:25:31 +0900 Subject: [PATCH 08/15] test: serializer_args tests --- tests/async/test_client.py | 29 +++++++++++++++++++++++++++++ tests/async/test_server.py | 29 +++++++++++++++++++++++++++++ tests/common/test_client.py | 29 +++++++++++++++++++++++++++++ tests/common/test_server.py | 29 +++++++++++++++++++++++++++++ 4 files changed, 116 insertions(+) diff --git a/tests/async/test_client.py b/tests/async/test_client.py index b4b0c6c5..6b25c7ba 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -1,5 +1,6 @@ import asyncio from unittest import mock +from datetime import datetime, timezone, timedelta import pytest @@ -1242,3 +1243,31 @@ async def test_eio_disconnect_no_reconnect(self): assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() + + def test_serializer_args(self): + args = {"foo": "bar"} + c = async_client.AsyncClient(serializer_args=args) + assert c.packet_class_args == args + + def test_serializer_args_with_msgpack(self): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} + c = async_client.AsyncClient(serializer='msgpack', serializer_args=args) + p = c._create_packet(data=data) + p2 = c._create_packet(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"] + + def test_invalid_serializer_args(self): + args = {"invalid_arg": 123} + c = async_client.AsyncClient(serializer='msgpack', serializer_args=args) + with pytest.raises(TypeError): + c._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file diff --git a/tests/async/test_server.py b/tests/async/test_server.py index 575f2097..6bc75b9c 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -1,6 +1,7 @@ import asyncio import logging from unittest import mock +from datetime import datetime, timezone, timedelta from engineio import json from engineio import packet as eio_packet @@ -1089,3 +1090,31 @@ async def test_sleep(self, eio): s = async_server.AsyncServer() await s.sleep(1.23) s.eio.sleep.assert_awaited_once_with(1.23) + + def test_serializer_args(self, eio): + args = {"foo": "bar"} + s = async_server.AsyncServer(serializer_args=args) + assert s.packet_class_args == args + + def test_serializer_args_with_msgpack(self, eio): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} + s = async_server.AsyncServer(serializer='msgpack', serializer_args=args) + p = s._create_packet(data=data) + p2 = s._create_packet(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"] + + def test_invalid_serializer_args(self, eio): + args = {"invalid_arg": 123} + s = async_server.AsyncServer(serializer='msgpack', serializer_args=args) + with pytest.raises(TypeError): + s._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file diff --git a/tests/common/test_client.py b/tests/common/test_client.py index cbda3f1f..fd90512c 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -1,6 +1,7 @@ import logging import time from unittest import mock +from datetime import datetime, timezone, timedelta from engineio import exceptions as engineio_exceptions from engineio import json @@ -1386,3 +1387,31 @@ def test_eio_disconnect_no_reconnect(self): assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() + + def test_serializer_args(self): + args = {"foo": "bar"} + c = client.Client(serializer_args=args) + assert c.packet_class_args == args + + def test_serializer_args_with_msgpack(self): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} + c = client.Client(serializer='msgpack', serializer_args=args) + p = c._create_packet(data=data) + p2 = c._create_packet(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"] + + def test_invalid_serializer_args(self): + args = {"invalid_arg": 123} + c = client.Client(serializer='msgpack', serializer_args=args) + with pytest.raises(TypeError): + c._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file diff --git a/tests/common/test_server.py b/tests/common/test_server.py index bdbbfe07..a8df8c8a 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -1,5 +1,6 @@ import logging from unittest import mock +from datetime import datetime, timezone, timedelta from engineio import json from engineio import packet as eio_packet @@ -1032,3 +1033,31 @@ def test_sleep(self, eio): s = server.Server() s.sleep(1.23) s.eio.sleep.assert_called_once_with(1.23) + + def test_serializer_args(self, eio): + args = {"foo": "bar"} + s = server.Server(serializer_args=args) + assert s.packet_class_args == args + + def test_serializer_args_with_msgpack(self, eio): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} + s = server.Server(serializer='msgpack', serializer_args=args) + p = s._create_packet(data=data) + p2 = s._create_packet(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"] + + def test_invalid_serializer_args(self, eio): + args = {"invalid_arg": 123} + s = server.Server(serializer='msgpack', serializer_args=args) + with pytest.raises(TypeError): + s._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file From 6c899d74259776c0cad724c0eedc220c817e8dc1 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:26:53 +0900 Subject: [PATCH 09/15] fix: lint errors --- tests/async/test_client.py | 12 +++++++----- tests/async/test_server.py | 12 +++++++----- tests/common/test_client.py | 6 +++--- tests/common/test_server.py | 6 +++--- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/async/test_client.py b/tests/async/test_client.py index 6b25c7ba..58e2ac75 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -1248,7 +1248,7 @@ def test_serializer_args(self): args = {"foo": "bar"} c = async_client.AsyncClient(serializer_args=args) assert c.packet_class_args == args - + def test_serializer_args_with_msgpack(self): def default(o): if isinstance(o, datetime): @@ -1256,7 +1256,8 @@ def default(o): raise TypeError("Unknown type") args = {"dumps_default": default} data = {"current": datetime.now(timezone(timedelta(0)))} - c = async_client.AsyncClient(serializer='msgpack', serializer_args=args) + c = async_client.AsyncClient(serializer='msgpack', + serializer_args=args) p = c._create_packet(data=data) p2 = c._create_packet(encoded_packet=p.encode()) @@ -1265,9 +1266,10 @@ def default(o): assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - + def test_invalid_serializer_args(self): args = {"invalid_arg": 123} - c = async_client.AsyncClient(serializer='msgpack', serializer_args=args) + c = async_client.AsyncClient(serializer='msgpack', + serializer_args=args) with pytest.raises(TypeError): - c._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file + c._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/async/test_server.py b/tests/async/test_server.py index 6bc75b9c..793192f2 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -1095,7 +1095,7 @@ def test_serializer_args(self, eio): args = {"foo": "bar"} s = async_server.AsyncServer(serializer_args=args) assert s.packet_class_args == args - + def test_serializer_args_with_msgpack(self, eio): def default(o): if isinstance(o, datetime): @@ -1103,7 +1103,8 @@ def default(o): raise TypeError("Unknown type") args = {"dumps_default": default} data = {"current": datetime.now(timezone(timedelta(0)))} - s = async_server.AsyncServer(serializer='msgpack', serializer_args=args) + s = async_server.AsyncServer(serializer='msgpack', + serializer_args=args) p = s._create_packet(data=data) p2 = s._create_packet(encoded_packet=p.encode()) @@ -1112,9 +1113,10 @@ def default(o): assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - + def test_invalid_serializer_args(self, eio): args = {"invalid_arg": 123} - s = async_server.AsyncServer(serializer='msgpack', serializer_args=args) + s = async_server.AsyncServer(serializer='msgpack', + serializer_args=args) with pytest.raises(TypeError): - s._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file + s._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/common/test_client.py b/tests/common/test_client.py index fd90512c..90ab5dfd 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -1392,7 +1392,7 @@ def test_serializer_args(self): args = {"foo": "bar"} c = client.Client(serializer_args=args) assert c.packet_class_args == args - + def test_serializer_args_with_msgpack(self): def default(o): if isinstance(o, datetime): @@ -1409,9 +1409,9 @@ def default(o): assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - + def test_invalid_serializer_args(self): args = {"invalid_arg": 123} c = client.Client(serializer='msgpack', serializer_args=args) with pytest.raises(TypeError): - c._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file + c._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/common/test_server.py b/tests/common/test_server.py index a8df8c8a..6bbe7c4c 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -1038,7 +1038,7 @@ def test_serializer_args(self, eio): args = {"foo": "bar"} s = server.Server(serializer_args=args) assert s.packet_class_args == args - + def test_serializer_args_with_msgpack(self, eio): def default(o): if isinstance(o, datetime): @@ -1055,9 +1055,9 @@ def default(o): assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - + def test_invalid_serializer_args(self, eio): args = {"invalid_arg": 123} s = server.Server(serializer='msgpack', serializer_args=args) with pytest.raises(TypeError): - s._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file + s._create_packet(data={"foo": "bar"}).encode() From aac7612685b16407ee4839cd3e561b8c3c5091d5 Mon Sep 17 00:00:00 2001 From: phi Date: Wed, 5 Nov 2025 19:38:31 +0900 Subject: [PATCH 10/15] Revert all This reverts commit db7fcdaaf7adaf76585f5350f3a52db6960b8397. --- src/socketio/async_client.py | 13 ++-- src/socketio/async_server.py | 19 +++-- src/socketio/base_client.py | 7 +- src/socketio/base_server.py | 7 +- src/socketio/client.py | 16 ++--- src/socketio/msgpack_packet.py | 24 +------ src/socketio/server.py | 19 +++-- tests/async/test_client.py | 31 -------- tests/async/test_server.py | 31 -------- tests/common/test_client.py | 29 -------- tests/common/test_msgpack_packet.py | 107 +--------------------------- tests/common/test_server.py | 29 -------- 12 files changed, 33 insertions(+), 299 deletions(-) diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index c19c8459..678743a2 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -45,9 +45,6 @@ class AsyncClient(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. - :param serializer_args: A mapping of additional parameters to pass to - the serializer. The content of this dictionary - depends on the selected serialization method. The Engine.IO configuration supports the following settings: @@ -246,7 +243,7 @@ async def emit(self, event, data=None, namespace=None, callback=None): data = [data] else: data = [] - await self._send_packet(self._create_packet( + await self._send_packet(self.packet_class( packet.EVENT, namespace=namespace, data=[event] + data, id=id)) async def send(self, data, namespace=None, callback=None): @@ -328,7 +325,7 @@ async def disconnect(self): # here we just request the disconnection # later in _handle_eio_disconnect we invoke the disconnect handler for n in self.namespaces: - await self._send_packet(self._create_packet(packet.DISCONNECT, + await self._send_packet(self.packet_class(packet.DISCONNECT, namespace=n)) await self.eio.disconnect() @@ -425,7 +422,7 @@ async def _handle_event(self, namespace, id, data): data = list(r) else: data = [r] - await self._send_packet(self._create_packet( + await self._send_packet(self.packet_class( packet.ACK, namespace=namespace, id=id, data=data)) async def _handle_ack(self, namespace, id, data): @@ -558,7 +555,7 @@ async def _handle_eio_connect(self): self.sid = self.eio.sid real_auth = await self._get_real_value(self.connection_auth) or {} for n in self.connection_namespaces: - await self._send_packet(self._create_packet( + await self._send_packet(self.packet_class( packet.CONNECT, data=real_auth, namespace=n)) async def _handle_eio_message(self, data): @@ -572,7 +569,7 @@ async def _handle_eio_message(self, data): else: await self._handle_ack(pkt.namespace, pkt.id, pkt.data) else: - pkt = self._create_packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: await self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index fa22393e..6c9e3ca3 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -50,9 +50,6 @@ class AsyncServer(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. - :param serializer_args: A mapping of additional parameters to pass to - the serializer. The content of this dictionary - depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: @@ -428,7 +425,7 @@ async def disconnect(self, sid, namespace=None, ignore_queue=False): if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) await self._trigger_event('disconnect', namespace, sid, self.reason.SERVER_DISCONNECT) @@ -541,13 +538,13 @@ async def _handle_connect(self, eio_sid, namespace, data): or self.namespaces == '*' or namespace in self.namespaces: sid = await self.manager.connect(eio_sid, namespace) if sid is None: - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data='Unable to connect', namespace=namespace)) return if self.always_connect: - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: @@ -571,15 +568,15 @@ async def _handle_connect(self, eio_sid, namespace, data): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, data=fail_reason, namespace=namespace)) else: - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) await self.manager.disconnect(sid, namespace, ignore_queue=True) elif not self.always_connect: - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) async def _handle_disconnect(self, eio_sid, namespace, reason=None): @@ -625,7 +622,7 @@ async def _handle_event_internal(self, server, sid, eio_sid, data, data = list(r) else: data = [r] - await server._send_packet(eio_sid, self._create_packet( + await server._send_packet(eio_sid, self.packet_class( packet.ACK, namespace=namespace, id=id, data=data)) async def _handle_ack(self, eio_sid, namespace, id, data): @@ -689,7 +686,7 @@ async def _handle_eio_message(self, eio_sid, data): await self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - pkt = self._create_packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: await self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/base_client.py b/src/socketio/base_client.py index 2bcaafcf..0232dca7 100644 --- a/src/socketio/base_client.py +++ b/src/socketio/base_client.py @@ -38,8 +38,7 @@ class BaseClient: def __init__(self, reconnection=True, reconnection_attempts=0, reconnection_delay=1, reconnection_delay_max=5, randomization_factor=0.5, logger=False, serializer='default', - json=None, handle_sigint=True, serializer_args=None, - **kwargs): + json=None, handle_sigint=True, **kwargs): global original_signal_handler if handle_sigint and original_signal_handler is None and \ threading.current_thread() == threading.main_thread(): @@ -64,7 +63,6 @@ def __init__(self, reconnection=True, reconnection_attempts=0, self.packet_class = msgpack_packet.MsgPackPacket else: self.packet_class = serializer - self.packet_class_args = serializer_args or {} if json is not None: self.packet_class.json = json engineio_options['json'] = json @@ -285,9 +283,6 @@ def _generate_ack_id(self, namespace, callback): self.callbacks[namespace][id] = callback return id - def _create_packet(self, *args, **kwargs): - return self.packet_class(*args, **kwargs, **self.packet_class_args) - def _handle_eio_connect(self): # pragma: no cover raise NotImplementedError() diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index 488ffe1d..d134eba1 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -15,7 +15,7 @@ class BaseServer: def __init__(self, client_manager=None, logger=False, serializer='default', json=None, async_handlers=True, always_connect=False, - namespaces=None, serializer_args=None, **kwargs): + namespaces=None, **kwargs): engineio_options = kwargs engineio_logger = engineio_options.pop('engineio_logger', None) if engineio_logger is not None: @@ -27,7 +27,6 @@ def __init__(self, client_manager=None, logger=False, serializer='default', self.packet_class = msgpack_packet.MsgPackPacket else: self.packet_class = serializer - self.packet_class_args = serializer_args or {} if json is not None: self.packet_class.json = json engineio_options['json'] = json @@ -254,10 +253,6 @@ def _get_namespace_handler(self, namespace, args): args = (namespace, *args) return handler, args - def _create_packet(self, *args, **kwargs): - return self.packet_class(*args, **kwargs, - **self.packet_class_args) - def _handle_eio_connect(self): # pragma: no cover raise NotImplementedError() diff --git a/src/socketio/client.py b/src/socketio/client.py index 29c1f25c..5282e0a1 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -48,9 +48,6 @@ class Client(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. - :param serializer_args: A mapping of additional parameters to pass to - the serializer. The content of this dictionary - depends on the selected serialization method. The Engine.IO configuration supports the following settings: @@ -237,9 +234,8 @@ def emit(self, event, data=None, namespace=None, callback=None): data = [data] else: data = [] - self._send_packet( - self._create_packet(packet.EVENT, namespace=namespace, - data=[event] + data, id=id)) + self._send_packet(self.packet_class(packet.EVENT, namespace=namespace, + data=[event] + data, id=id)) def send(self, data, namespace=None, callback=None): """Send a message to the server. @@ -311,7 +307,7 @@ def disconnect(self): # here we just request the disconnection # later in _handle_eio_disconnect we invoke the disconnect handler for n in self.namespaces: - self._send_packet(self._create_packet( + self._send_packet(self.packet_class( packet.DISCONNECT, namespace=n)) self.eio.disconnect() @@ -406,7 +402,7 @@ def _handle_event(self, namespace, id, data): data = list(r) else: data = [r] - self._send_packet(self._create_packet( + self._send_packet(self.packet_class( packet.ACK, namespace=namespace, id=id, data=data)) def _handle_ack(self, namespace, id, data): @@ -510,7 +506,7 @@ def _handle_eio_connect(self): self.sid = self.eio.sid real_auth = self._get_real_value(self.connection_auth) or {} for n in self.connection_namespaces: - self._send_packet(self._create_packet( + self._send_packet(self.packet_class( packet.CONNECT, data=real_auth, namespace=n)) def _handle_eio_message(self, data): @@ -524,7 +520,7 @@ def _handle_eio_message(self, data): else: self._handle_ack(pkt.namespace, pkt.id, pkt.data) else: - pkt = self._create_packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/msgpack_packet.py b/src/socketio/msgpack_packet.py index 846db2f3..27462634 100644 --- a/src/socketio/msgpack_packet.py +++ b/src/socketio/msgpack_packet.py @@ -5,33 +5,13 @@ class MsgPackPacket(packet.Packet): uses_binary_events = False - def __init__( - self, - packet_type=packet.EVENT, - data=None, - namespace=None, - id=None, - binary=None, - encoded_packet=None, - dumps_default=None, - ext_hook=None, - ): - self.dumps_default = dumps_default - self.ext_hook = ext_hook - super().__init__( - packet_type, data, namespace, id, binary, encoded_packet - ) - def encode(self): """Encode the packet for transmission.""" - return msgpack.dumps(self._to_dict(), default=self.dumps_default) + return msgpack.dumps(self._to_dict()) def decode(self, encoded_packet): """Decode a transmitted package.""" - if self.ext_hook is None: - decoded = msgpack.loads(encoded_packet) - else: - decoded = msgpack.loads(encoded_packet, ext_hook=self.ext_hook) + decoded = msgpack.loads(encoded_packet) self.packet_type = decoded['type'] self.data = decoded.get('data') self.id = decoded.get('id') diff --git a/src/socketio/server.py b/src/socketio/server.py index 7312506c..f3257081 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -53,9 +53,6 @@ class Server(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. - :param serializer_args: A mapping of additional parameters to pass to - the serializer. The content of this dictionary - depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: @@ -404,7 +401,7 @@ def disconnect(self, sid, namespace=None, ignore_queue=False): if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) self._trigger_event('disconnect', namespace, sid, self.reason.SERVER_DISCONNECT) @@ -523,13 +520,13 @@ def _handle_connect(self, eio_sid, namespace, data): or self.namespaces == '*' or namespace in self.namespaces: sid = self.manager.connect(eio_sid, namespace) if sid is None: - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data='Unable to connect', namespace=namespace)) return if self.always_connect: - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: @@ -553,15 +550,15 @@ def _handle_connect(self, eio_sid, namespace, data): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, data=fail_reason, namespace=namespace)) else: - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) self.manager.disconnect(sid, namespace, ignore_queue=True) elif not self.always_connect: - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) def _handle_disconnect(self, eio_sid, namespace, reason=None): @@ -604,7 +601,7 @@ def _handle_event_internal(self, server, sid, eio_sid, data, namespace, data = list(r) else: data = [r] - server._send_packet(eio_sid, self._create_packet( + server._send_packet(eio_sid, self.packet_class( packet.ACK, namespace=namespace, id=id, data=data)) def _handle_ack(self, eio_sid, namespace, id, data): @@ -653,7 +650,7 @@ def _handle_eio_message(self, eio_sid, data): else: self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - pkt = self._create_packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/tests/async/test_client.py b/tests/async/test_client.py index 58e2ac75..b4b0c6c5 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -1,6 +1,5 @@ import asyncio from unittest import mock -from datetime import datetime, timezone, timedelta import pytest @@ -1243,33 +1242,3 @@ async def test_eio_disconnect_no_reconnect(self): assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() - - def test_serializer_args(self): - args = {"foo": "bar"} - c = async_client.AsyncClient(serializer_args=args) - assert c.packet_class_args == args - - def test_serializer_args_with_msgpack(self): - def default(o): - if isinstance(o, datetime): - return o.isoformat() - raise TypeError("Unknown type") - args = {"dumps_default": default} - data = {"current": datetime.now(timezone(timedelta(0)))} - c = async_client.AsyncClient(serializer='msgpack', - serializer_args=args) - p = c._create_packet(data=data) - p2 = c._create_packet(encoded_packet=p.encode()) - - assert p.data != p2.data - assert isinstance(p2.data, dict) - assert "current" in p2.data - assert isinstance(p2.data["current"], str) - assert default(data["current"]) == p2.data["current"] - - def test_invalid_serializer_args(self): - args = {"invalid_arg": 123} - c = async_client.AsyncClient(serializer='msgpack', - serializer_args=args) - with pytest.raises(TypeError): - c._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/async/test_server.py b/tests/async/test_server.py index 793192f2..575f2097 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -1,7 +1,6 @@ import asyncio import logging from unittest import mock -from datetime import datetime, timezone, timedelta from engineio import json from engineio import packet as eio_packet @@ -1090,33 +1089,3 @@ async def test_sleep(self, eio): s = async_server.AsyncServer() await s.sleep(1.23) s.eio.sleep.assert_awaited_once_with(1.23) - - def test_serializer_args(self, eio): - args = {"foo": "bar"} - s = async_server.AsyncServer(serializer_args=args) - assert s.packet_class_args == args - - def test_serializer_args_with_msgpack(self, eio): - def default(o): - if isinstance(o, datetime): - return o.isoformat() - raise TypeError("Unknown type") - args = {"dumps_default": default} - data = {"current": datetime.now(timezone(timedelta(0)))} - s = async_server.AsyncServer(serializer='msgpack', - serializer_args=args) - p = s._create_packet(data=data) - p2 = s._create_packet(encoded_packet=p.encode()) - - assert p.data != p2.data - assert isinstance(p2.data, dict) - assert "current" in p2.data - assert isinstance(p2.data["current"], str) - assert default(data["current"]) == p2.data["current"] - - def test_invalid_serializer_args(self, eio): - args = {"invalid_arg": 123} - s = async_server.AsyncServer(serializer='msgpack', - serializer_args=args) - with pytest.raises(TypeError): - s._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/common/test_client.py b/tests/common/test_client.py index 90ab5dfd..cbda3f1f 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -1,7 +1,6 @@ import logging import time from unittest import mock -from datetime import datetime, timezone, timedelta from engineio import exceptions as engineio_exceptions from engineio import json @@ -1387,31 +1386,3 @@ def test_eio_disconnect_no_reconnect(self): assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() - - def test_serializer_args(self): - args = {"foo": "bar"} - c = client.Client(serializer_args=args) - assert c.packet_class_args == args - - def test_serializer_args_with_msgpack(self): - def default(o): - if isinstance(o, datetime): - return o.isoformat() - raise TypeError("Unknown type") - args = {"dumps_default": default} - data = {"current": datetime.now(timezone(timedelta(0)))} - c = client.Client(serializer='msgpack', serializer_args=args) - p = c._create_packet(data=data) - p2 = c._create_packet(encoded_packet=p.encode()) - - assert p.data != p2.data - assert isinstance(p2.data, dict) - assert "current" in p2.data - assert isinstance(p2.data["current"], str) - assert default(data["current"]) == p2.data["current"] - - def test_invalid_serializer_args(self): - args = {"invalid_arg": 123} - c = client.Client(serializer='msgpack', serializer_args=args) - with pytest.raises(TypeError): - c._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/common/test_msgpack_packet.py b/tests/common/test_msgpack_packet.py index 8a3befd5..e0197a27 100644 --- a/tests/common/test_msgpack_packet.py +++ b/tests/common/test_msgpack_packet.py @@ -1,8 +1,3 @@ -from datetime import datetime, timedelta, timezone - -import pytest -import msgpack - from socketio import msgpack_packet from socketio import packet @@ -10,8 +5,7 @@ class TestMsgPackPacket: def test_encode_decode(self): p = msgpack_packet.MsgPackPacket( - packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo' - ) + packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo') p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.data == p2.data @@ -20,8 +14,7 @@ def test_encode_decode(self): def test_encode_decode_with_id(self): p = msgpack_packet.MsgPackPacket( - packet.EVENT, data=['ev', 42], id=123, namespace='/foo' - ) + packet.EVENT, data=['ev', 42], id=123, namespace='/foo') p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.data == p2.data @@ -39,99 +32,3 @@ def test_encode_binary_ack_packet(self): assert p.packet_type == packet.ACK p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p2.data == {'foo': b'bar'} - - def test_encode_with_dumps_default(self): - def default(obj): - if isinstance(obj, datetime): - return obj.isoformat() - raise TypeError('Unknown type') - - data = { - 'current': datetime.now(tz=timezone(timedelta(0))), - 'key': 'value', - } - p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) - p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) - assert p.packet_type == p2.packet_type - assert p.id == p2.id - assert p.namespace == p2.namespace - assert p.data != p2.data - - assert isinstance(p2.data, dict) - assert 'current' in p2.data - assert isinstance(p2.data['current'], str) - assert default(data['current']) == p2.data['current'] - - data.pop('current') - p2_data_without_current = p2.data.copy() - p2_data_without_current.pop('current') - assert data == p2_data_without_current - - def test_encode_without_dumps_default(self): - data = { - 'current': datetime.now(tz=timezone(timedelta(0))), - 'key': 'value', - } - p_without_default = msgpack_packet.MsgPackPacket(data=data) - with pytest.raises(TypeError): - p_without_default.encode() - - def test_encode_decode_with_ext_hook(self): - class Custom: - def __init__(self, value): - self.value = value - - def __eq__(self, value: object) -> bool: - return isinstance(value, Custom) and self.value == value.value - - def default(obj): - if isinstance(obj, Custom): - return msgpack.ExtType(1, obj.value) - raise TypeError('Unknown type') - - def ext_hook(code, data): - if code == 1: - return Custom(data) - raise TypeError('Unknown ext type') - - data = {'custom': Custom(b'custom_data'), 'key': 'value'} - p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) - p2 = msgpack_packet.MsgPackPacket( - encoded_packet=p.encode(), ext_hook=ext_hook - ) - assert p.packet_type == p2.packet_type - assert p.id == p2.id - assert p.data == p2.data - assert p.namespace == p2.namespace - - def test_encode_decode_without_ext_hook(self): - class Custom: - def __init__(self, value): - self.value = value - - def __eq__(self, value: object) -> bool: - return isinstance(value, Custom) and self.value == value.value - - def default(obj): - if isinstance(obj, Custom): - return msgpack.ExtType(1, obj.value) - raise TypeError('Unknown type') - - data = {'custom': Custom(b'custom_data'), 'key': 'value'} - p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) - p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) - assert p.packet_type == p2.packet_type - assert p.id == p2.id - assert p.namespace == p2.namespace - assert p.data != p2.data - - assert isinstance(p2.data, dict) - assert 'custom' in p2.data - assert isinstance(p2.data['custom'], msgpack.ExtType) - assert p2.data['custom'].code == 1 - assert p2.data['custom'].data == b'custom_data' - - data.pop('custom') - p2_data_without_custom = p2.data.copy() - p2_data_without_custom.pop('custom') - assert data == p2_data_without_custom diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 6bbe7c4c..bdbbfe07 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -1,6 +1,5 @@ import logging from unittest import mock -from datetime import datetime, timezone, timedelta from engineio import json from engineio import packet as eio_packet @@ -1033,31 +1032,3 @@ def test_sleep(self, eio): s = server.Server() s.sleep(1.23) s.eio.sleep.assert_called_once_with(1.23) - - def test_serializer_args(self, eio): - args = {"foo": "bar"} - s = server.Server(serializer_args=args) - assert s.packet_class_args == args - - def test_serializer_args_with_msgpack(self, eio): - def default(o): - if isinstance(o, datetime): - return o.isoformat() - raise TypeError("Unknown type") - args = {"dumps_default": default} - data = {"current": datetime.now(timezone(timedelta(0)))} - s = server.Server(serializer='msgpack', serializer_args=args) - p = s._create_packet(data=data) - p2 = s._create_packet(encoded_packet=p.encode()) - - assert p.data != p2.data - assert isinstance(p2.data, dict) - assert "current" in p2.data - assert isinstance(p2.data["current"], str) - assert default(data["current"]) == p2.data["current"] - - def test_invalid_serializer_args(self, eio): - args = {"invalid_arg": 123} - s = server.Server(serializer='msgpack', serializer_args=args) - with pytest.raises(TypeError): - s._create_packet(data={"foo": "bar"}).encode() From 37f94fed127bb89e0510b0ff664262eadc91ee40 Mon Sep 17 00:00:00 2001 From: phi Date: Wed, 5 Nov 2025 19:48:17 +0900 Subject: [PATCH 11/15] feat: configure --- src/socketio/packet.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/socketio/packet.py b/src/socketio/packet.py index 3deba7fb..dea5a8e6 100644 --- a/src/socketio/packet.py +++ b/src/socketio/packet.py @@ -1,4 +1,5 @@ import functools +import warnings from engineio import json as _json (CONNECT, DISCONNECT, EVENT, ACK, CONNECT_ERROR, BINARY_EVENT, BINARY_ACK) = \ @@ -21,6 +22,8 @@ class Packet: uses_binary_events = True json = _json + _configure_args = ((),()) + _subclass_registry = {} def __init__(self, packet_type=EVENT, data=None, namespace=None, id=None, binary=None, encoded_packet=None): @@ -192,3 +195,23 @@ def _to_dict(self): if self.id is not None: d['id'] = self.id return d + + @classmethod + def configure(cls, *args, **kwargs): + configure_args = (args, tuple(sorted(kwargs.items()))) + try: + args_hash = hash(configure_args) + except TypeError: + warnings.warn('Packet.configure() called with unhashable ' + 'arguments; subclass caching will not work.', + RuntimeWarning) + args_hash = None + + if args_hash in cls._subclass_registry: + return cls._subclass_registry[args_hash] + return cls._configure(*args, **kwargs) + + @classmethod + def _configure(cls, *args, **kwargs): + raise NotImplementedError('Packet._configure() must be implemented ' + 'by subclasses.') \ No newline at end of file From f4b635ee8f3b4a8d7a176424aca4a648d13fa4aa Mon Sep 17 00:00:00 2001 From: phi Date: Wed, 5 Nov 2025 19:57:29 +0900 Subject: [PATCH 12/15] fix: warning -> logging --- src/socketio/packet.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/socketio/packet.py b/src/socketio/packet.py index dea5a8e6..61b38363 100644 --- a/src/socketio/packet.py +++ b/src/socketio/packet.py @@ -1,5 +1,5 @@ import functools -import warnings +import logging from engineio import json as _json (CONNECT, DISCONNECT, EVENT, ACK, CONNECT_ERROR, BINARY_EVENT, BINARY_ACK) = \ @@ -7,6 +7,7 @@ packet_names = ['CONNECT', 'DISCONNECT', 'EVENT', 'ACK', 'CONNECT_ERROR', 'BINARY_EVENT', 'BINARY_ACK'] +logger = logging.getLogger('socketio.packet') class Packet: """Socket.IO packet.""" @@ -202,12 +203,13 @@ def configure(cls, *args, **kwargs): try: args_hash = hash(configure_args) except TypeError: - warnings.warn('Packet.configure() called with unhashable ' - 'arguments; subclass caching will not work.', - RuntimeWarning) + logger.warning("Packet.configure() called with unhashable " + "arguments; subclass caching will not work.") args_hash = None if args_hash in cls._subclass_registry: + logger.debug("Using cached Packet subclass for args %s, %s", + args, kwargs) return cls._subclass_registry[args_hash] return cls._configure(*args, **kwargs) From 3e83ae66fa9f8eba02cb8fad9747c34990d88de6 Mon Sep 17 00:00:00 2001 From: phi Date: Wed, 5 Nov 2025 19:57:50 +0900 Subject: [PATCH 13/15] feat: MsgPackPacket._configure --- src/socketio/msgpack_packet.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/socketio/msgpack_packet.py b/src/socketio/msgpack_packet.py index 27462634..85df97c4 100644 --- a/src/socketio/msgpack_packet.py +++ b/src/socketio/msgpack_packet.py @@ -7,12 +7,33 @@ class MsgPackPacket(packet.Packet): def encode(self): """Encode the packet for transmission.""" - return msgpack.dumps(self._to_dict()) + return self._encode() + + def _encode(self, **kwargs): + return _msgpack.dumps(self._to_dict(), **kwargs) def decode(self, encoded_packet): """Decode a transmitted package.""" - decoded = msgpack.loads(encoded_packet) + return self._decode(encoded_packet) + + def _decode(self, encoded_packet, **kwargs): + decoded = msgpack.loads(encoded_packet, **kwargs) self.packet_type = decoded['type'] self.data = decoded.get('data') self.id = decoded.get('id') self.namespace = decoded['nsp'] + + @classmethod + def _configure(cls, *args, **kwargs): + dumps_default = kwargs.pop('dumps_default', None) + ext_hook = kwargs.pop('ext_hook', msgpack.ExtType) + + class ConfiguredMsgPackPacket(cls): + def _encode(self, **kwargs): + kwargs.setdefault('default', dumps_default) + return super()._encode(**kwargs) + def _decode(self, encoded_packet, **kwargs): + kwargs.setdefault('ext_hook', ext_hook) + return super()._decode(encoded_packet, **kwargs) + + return ConfiguredMsgPackPacket \ No newline at end of file From 30b67dd53fb95c9c050e978fc3e9cc931db0035f Mon Sep 17 00:00:00 2001 From: phi Date: Wed, 5 Nov 2025 19:59:58 +0900 Subject: [PATCH 14/15] fix: set cache --- src/socketio/packet.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/socketio/packet.py b/src/socketio/packet.py index 61b38363..421c800b 100644 --- a/src/socketio/packet.py +++ b/src/socketio/packet.py @@ -211,7 +211,12 @@ def configure(cls, *args, **kwargs): logger.debug("Using cached Packet subclass for args %s, %s", args, kwargs) return cls._subclass_registry[args_hash] - return cls._configure(*args, **kwargs) + new = cls._configure(*args, **kwargs) + if args_hash is not None: + cls._subclass_registry[args_hash] = new + logger.debug("Caching Packet subclass for args %s, %s", + args, kwargs) + return new @classmethod def _configure(cls, *args, **kwargs): From 608a82fa2e022392aea1a05f270ec7f25378355f Mon Sep 17 00:00:00 2001 From: phi Date: Wed, 5 Nov 2025 20:03:21 +0900 Subject: [PATCH 15/15] fix: add warnings --- src/socketio/msgpack_packet.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/socketio/msgpack_packet.py b/src/socketio/msgpack_packet.py index 85df97c4..32c4debd 100644 --- a/src/socketio/msgpack_packet.py +++ b/src/socketio/msgpack_packet.py @@ -1,6 +1,9 @@ +import logging import msgpack from . import packet +logger = logging.getLogger('socketio') + class MsgPackPacket(packet.Packet): uses_binary_events = False @@ -8,14 +11,14 @@ class MsgPackPacket(packet.Packet): def encode(self): """Encode the packet for transmission.""" return self._encode() - + def _encode(self, **kwargs): return _msgpack.dumps(self._to_dict(), **kwargs) def decode(self, encoded_packet): """Decode a transmitted package.""" return self._decode(encoded_packet) - + def _decode(self, encoded_packet, **kwargs): decoded = msgpack.loads(encoded_packet, **kwargs) self.packet_type = decoded['type'] @@ -28,12 +31,26 @@ def _configure(cls, *args, **kwargs): dumps_default = kwargs.pop('dumps_default', None) ext_hook = kwargs.pop('ext_hook', msgpack.ExtType) + if args: + logger.warning( + 'Some positional arguments to MsgPackPacket.configure() are ' + 'not used: %s', + args, + ) + if kwargs: + logger.warning( + 'Some keyword arguments to MsgPackPacket.configure() are ' + 'not used: %s', + kwargs, + ) + class ConfiguredMsgPackPacket(cls): def _encode(self, **kwargs): kwargs.setdefault('default', dumps_default) return super()._encode(**kwargs) + def _decode(self, encoded_packet, **kwargs): kwargs.setdefault('ext_hook', ext_hook) return super()._decode(encoded_packet, **kwargs) - - return ConfiguredMsgPackPacket \ No newline at end of file + + return ConfiguredMsgPackPacket