diff --git a/src/socketio/msgpack_packet.py b/src/socketio/msgpack_packet.py index 27462634..9622dd26 100644 --- a/src/socketio/msgpack_packet.py +++ b/src/socketio/msgpack_packet.py @@ -4,14 +4,28 @@ class MsgPackPacket(packet.Packet): uses_binary_events = False + dumps_default = None + ext_hook = msgpack.ExtType + + @classmethod + def configure(cls, dumps_default=None, ext_hook=msgpack.ExtType): + class CustomMsgPackPacket(MsgPackPacket): + dumps_default = None + ext_hook = None + + CustomMsgPackPacket.dumps_default = dumps_default + CustomMsgPackPacket.ext_hook = ext_hook + return CustomMsgPackPacket def encode(self): """Encode the packet for transmission.""" - return msgpack.dumps(self._to_dict()) + return msgpack.dumps(self._to_dict(), + default=self.__class__.dumps_default) def decode(self, encoded_packet): """Decode a transmitted package.""" - decoded = msgpack.loads(encoded_packet) + decoded = msgpack.loads(encoded_packet, + ext_hook=self.__class__.ext_hook) self.packet_type = decoded['type'] self.data = decoded.get('data') self.id = decoded.get('id') diff --git a/tests/async/test_client.py b/tests/async/test_client.py index b4b0c6c5..7a7bfa7c 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -1,5 +1,6 @@ import asyncio from unittest import mock +from datetime import datetime, timezone, timedelta import pytest @@ -8,6 +9,7 @@ from engineio import exceptions as engineio_exceptions from socketio import exceptions from socketio import packet +from socketio.msgpack_packet import MsgPackPacket class TestAsyncClient: @@ -1242,3 +1244,21 @@ async def test_eio_disconnect_no_reconnect(self): assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() + + def test_serializer_args_with_msgpack(self): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + + data = {"current": datetime.now(timezone(timedelta(0)))} + c = async_client.AsyncClient( + serializer=MsgPackPacket.configure(dumps_default=default)) + p = c.packet_class(data=data) + p2 = c.packet_class(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"] diff --git a/tests/async/test_server.py b/tests/async/test_server.py index 575f2097..10d7ba14 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -1,6 +1,7 @@ import asyncio import logging from unittest import mock +from datetime import datetime, timezone, timedelta from engineio import json from engineio import packet as eio_packet @@ -11,6 +12,7 @@ from socketio import exceptions from socketio import namespace from socketio import packet +from socketio.msgpack_packet import MsgPackPacket @mock.patch('socketio.server.engineio.AsyncServer', **{ @@ -1089,3 +1091,21 @@ async def test_sleep(self, eio): s = async_server.AsyncServer() await s.sleep(1.23) s.eio.sleep.assert_awaited_once_with(1.23) + + def test_serializer_args_with_msgpack(self, eio): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + + data = {"current": datetime.now(timezone(timedelta(0)))} + s = async_server.AsyncServer( + serializer=MsgPackPacket.configure(dumps_default=default)) + p = s.packet_class(data=data) + p2 = s.packet_class(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"] diff --git a/tests/common/test_client.py b/tests/common/test_client.py index cbda3f1f..d386a9c3 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -1,6 +1,7 @@ import logging import time from unittest import mock +from datetime import datetime, timezone, timedelta from engineio import exceptions as engineio_exceptions from engineio import json @@ -13,6 +14,7 @@ from socketio import msgpack_packet from socketio import namespace from socketio import packet +from socketio.msgpack_packet import MsgPackPacket class TestClient: @@ -1386,3 +1388,21 @@ def test_eio_disconnect_no_reconnect(self): assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() + + def test_serializer_args_with_msgpack(self): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + + data = {"current": datetime.now(timezone(timedelta(0)))} + c = client.Client( + serializer=MsgPackPacket.configure(dumps_default=default)) + p = c.packet_class(data=data) + p2 = c.packet_class(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"] diff --git a/tests/common/test_msgpack_packet.py b/tests/common/test_msgpack_packet.py index e0197a27..0fad0292 100644 --- a/tests/common/test_msgpack_packet.py +++ b/tests/common/test_msgpack_packet.py @@ -1,3 +1,8 @@ +from datetime import datetime, timedelta, timezone + +import pytest +import msgpack + from socketio import msgpack_packet from socketio import packet @@ -32,3 +37,102 @@ def test_encode_binary_ack_packet(self): assert p.packet_type == packet.ACK p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p2.data == {'foo': b'bar'} + + def test_encode_with_dumps_default(self): + def default(obj): + if isinstance(obj, datetime): + return obj.isoformat() + raise TypeError('Unknown type') + + data = { + 'current': datetime.now(tz=timezone(timedelta(0))), + 'key': 'value', + } + p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)( + data=data) + p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.namespace == p2.namespace + assert p.data != p2.data + + assert isinstance(p2.data, dict) + assert 'current' in p2.data + assert isinstance(p2.data['current'], str) + assert default(data['current']) == p2.data['current'] + + data.pop('current') + p2_data_without_current = p2.data.copy() + p2_data_without_current.pop('current') + assert data == p2_data_without_current + + def test_encode_without_dumps_default(self): + data = { + 'current': datetime.now(tz=timezone(timedelta(0))), + 'key': 'value', + } + p_without_default = msgpack_packet.MsgPackPacket(data=data) + with pytest.raises(TypeError): + p_without_default.encode() + + def test_encode_decode_with_ext_hook(self): + class Custom: + def __init__(self, value): + self.value = value + + def __eq__(self, value: object) -> bool: + return isinstance(value, Custom) and self.value == value.value + + def default(obj): + if isinstance(obj, Custom): + return msgpack.ExtType(1, obj.value) + raise TypeError('Unknown type') + + def ext_hook(code, data): + if code == 1: + return Custom(data) + raise TypeError('Unknown ext type') + + data = {'custom': Custom(b'custom_data'), 'key': 'value'} + p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)( + data=data) + p2 = msgpack_packet.MsgPackPacket.configure(ext_hook=ext_hook)( + encoded_packet=p.encode() + ) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.data == p2.data + assert p.namespace == p2.namespace + + def test_encode_decode_without_ext_hook(self): + class Custom: + def __init__(self, value): + self.value = value + + def __eq__(self, value: object) -> bool: + return isinstance(value, Custom) and self.value == value.value + + def default(obj): + if isinstance(obj, Custom): + return msgpack.ExtType(1, obj.value) + raise TypeError('Unknown type') + + data = {'custom': Custom(b'custom_data'), 'key': 'value'} + p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)( + data=data) + p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.namespace == p2.namespace + assert p.data != p2.data + + assert isinstance(p2.data, dict) + assert 'custom' in p2.data + assert isinstance(p2.data['custom'], msgpack.ExtType) + assert p2.data['custom'].code == 1 + assert p2.data['custom'].data == b'custom_data' + + data.pop('custom') + p2_data_without_custom = p2.data.copy() + p2_data_without_custom.pop('custom') + assert data == p2_data_without_custom diff --git a/tests/common/test_server.py b/tests/common/test_server.py index bdbbfe07..4c2c8071 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -1,5 +1,6 @@ import logging from unittest import mock +from datetime import datetime, timezone, timedelta from engineio import json from engineio import packet as eio_packet @@ -10,6 +11,7 @@ from socketio import namespace from socketio import packet from socketio import server +from socketio.msgpack_packet import MsgPackPacket @mock.patch('socketio.server.engineio.Server', **{ @@ -1032,3 +1034,21 @@ def test_sleep(self, eio): s = server.Server() s.sleep(1.23) s.eio.sleep.assert_called_once_with(1.23) + + def test_serializer_args_with_msgpack(self, eio): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + + data = {"current": datetime.now(timezone(timedelta(0)))} + s = server.Server( + serializer=MsgPackPacket.configure(dumps_default=default)) + p = s.packet_class(data=data) + p2 = s.packet_class(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"]