diff --git a/langgraph/checkpoint/redis/base.py b/langgraph/checkpoint/redis/base.py index 29fbd9c..022dcec 100644 --- a/langgraph/checkpoint/redis/base.py +++ b/langgraph/checkpoint/redis/base.py @@ -394,6 +394,25 @@ def _recursive_deserialize(self, obj: Any) -> Any: # Decode base64-encoded bytes return self._decode_blob(obj["__bytes__"]) + # Check if this is a Send object marker (issue #94) + if ( + obj.get("__send__") is True + and "node" in obj + and "arg" in obj + and len(obj) == 3 + ): + try: + from langgraph.types import Send + + return Send( + node=obj["node"], + arg=self._recursive_deserialize(obj["arg"]), + ) + except (ImportError, TypeError, ValueError) as e: + logger.debug( + "Failed to deserialize Send object: %s", e, exc_info=True + ) + # Check if this is a LangChain serialized object if obj.get("lc") in (1, 2) and obj.get("type") == "constructor": try: diff --git a/langgraph/checkpoint/redis/jsonplus_redis.py b/langgraph/checkpoint/redis/jsonplus_redis.py index a203b74..b61b50c 100644 --- a/langgraph/checkpoint/redis/jsonplus_redis.py +++ b/langgraph/checkpoint/redis/jsonplus_redis.py @@ -48,6 +48,16 @@ def _default_handler(self, obj: Any) -> Any: "id": obj.id, } + # Handle Send objects with a type marker (issue #94) + from langgraph.types import Send + + if isinstance(obj, Send): + return { + "__send__": True, + "node": obj.node, + "arg": obj.arg, + } + # Try to encode using parent's constructor args encoder # This creates the {"lc": 2, "type": "constructor", ...} format try: @@ -69,14 +79,14 @@ def _default_handler(self, obj: Any) -> Any: raise TypeError(f"Object of type {type(obj)} is not JSON serializable") def _preprocess_interrupts(self, obj: Any) -> Any: - """Recursively add type markers to Interrupt objects before serialization. + """Recursively add type markers to Interrupt and Send objects before serialization. This prevents false positives where user data with {value, id} fields could be incorrectly deserialized as Interrupt objects. Also handles dataclass instances to preserve type information during serialization. """ - from langgraph.types import Interrupt + from langgraph.types import Interrupt, Send if isinstance(obj, Interrupt): # Add type marker to distinguish from plain dicts @@ -85,6 +95,13 @@ def _preprocess_interrupts(self, obj: Any) -> Any: "value": self._preprocess_interrupts(obj.value), "id": obj.id, } + elif isinstance(obj, Send): + # Add type marker to distinguish from plain dicts (issue #94) + return { + "__send__": True, + "node": obj.node, + "arg": self._preprocess_interrupts(obj.arg), + } elif isinstance(obj, set): # Handle sets by converting to list for JSON serialization # Will be reconstructed back to set on deserialization @@ -277,6 +294,28 @@ def _revive_if_needed(self, obj: Any) -> Any: "Failed to deserialize Interrupt object: %s", e, exc_info=True ) + # Check if this is a serialized Send object with type marker (issue #94) + # Send objects serialize to {"__send__": True, "node": ..., "arg": ...} + if ( + obj.get("__send__") is True + and "node" in obj + and "arg" in obj + and len(obj) == 3 + ): + # Try to reconstruct as a Send object + try: + from langgraph.types import Send + + return Send( + node=obj["node"], + arg=self._revive_if_needed(obj["arg"]), + ) + except (ImportError, TypeError, ValueError) as e: + # If we can't import or construct Send, log and fall through + logger.debug( + "Failed to deserialize Send object: %s", e, exc_info=True + ) + # Recursively process nested dicts return {k: self._revive_if_needed(v) for k, v in obj.items()} elif isinstance(obj, list): diff --git a/pyproject.toml b/pyproject.toml index ee45de2..7e7308c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langgraph-checkpoint-redis" -version = "0.2.0" +version = "0.2.1" description = "Redis implementation of the LangGraph agent checkpoint saver and store." authors = ["Redis Inc. ", "Brian Sam-Bodden "] license = "MIT" diff --git a/tests/test_send_serialization.py b/tests/test_send_serialization.py new file mode 100644 index 0000000..38bcec7 --- /dev/null +++ b/tests/test_send_serialization.py @@ -0,0 +1,195 @@ +"""Test Send object serialization fix for issue #94. + +This test validates that langgraph.types.Send objects are properly serialized +and deserialized by the JsonPlusRedisSerializer. + +Before the fix, Send objects were not serialized correctly, which led to issues +with handling Interrupts - namely the user's response would not be treated as +the response to the Interrupt. + +The issue occurs in `prepare_single_task` in pregel._algo.py where: +``` +if not isinstance(packet, Send): + logger.warning( + f"Ignoring invalid packet type {type(packet)} in pending sends" + ) + return. <<<<< task not added +``` + +The fix adds custom serialization/deserialization for Send objects similar to +how Interrupt objects are handled. +""" + +import pytest +from langgraph.types import Send + +from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer + + +def test_send_object_serialization(): + """Test that Send objects are properly serialized and deserialized. + + Before the fix, Send objects would not serialize correctly, causing + isinstance(packet, Send) checks to fail after deserialization. + """ + serializer = JsonPlusRedisSerializer() + + # Create a Send object + send_obj = Send(node="my_node", arg={"key": "value"}) + + # Serialize + type_str, blob = serializer.dumps_typed(send_obj) + assert type_str == "json" + assert isinstance(blob, bytes) + + # Deserialize + deserialized = serializer.loads_typed((type_str, blob)) + + # Critical check: the deserialized object must be an instance of Send + assert isinstance(deserialized, Send), ( + f"Expected Send instance, got {type(deserialized)}. " + "This will cause isinstance(packet, Send) checks to fail!" + ) + assert deserialized.node == "my_node" + assert deserialized.arg == {"key": "value"} + assert deserialized == send_obj + + +def test_send_object_in_pending_sends_list(): + """Test that Send objects in pending_sends lists are properly handled. + + This simulates the scenario where Send objects are stored in checkpoint + pending_sends and must be correctly deserialized for interrupt handling. + """ + serializer = JsonPlusRedisSerializer() + + # Create multiple Send objects as they would appear in pending_sends + pending_sends = [ + Send(node="node1", arg={"data": "first"}), + Send(node="node2", arg={"data": "second"}), + Send(node="node3", arg={"data": "third"}), + ] + + # Serialize the list + type_str, blob = serializer.dumps_typed(pending_sends) + + # Deserialize + deserialized = serializer.loads_typed((type_str, blob)) + + # Verify all items are still Send instances + assert isinstance(deserialized, list) + assert len(deserialized) == 3 + + for i, send_obj in enumerate(deserialized): + assert isinstance( + send_obj, Send + ), f"Item {i} is not a Send instance: {type(send_obj)}" + + assert deserialized[0].node == "node1" + assert deserialized[1].node == "node2" + assert deserialized[2].node == "node3" + + +def test_send_object_with_complex_args(): + """Test Send objects with complex nested arguments.""" + serializer = JsonPlusRedisSerializer() + + # Create Send with complex nested arg + complex_arg = { + "messages": ["msg1", "msg2"], + "metadata": { + "step": 1, + "config": { + "model": "gpt-4", + "temperature": 0.7, + }, + }, + "nested_list": [ + {"a": 1, "b": 2}, + {"c": 3, "d": 4}, + ], + } + + send_obj = Send(node="processor", arg=complex_arg) + + # Serialize and deserialize + type_str, blob = serializer.dumps_typed(send_obj) + deserialized = serializer.loads_typed((type_str, blob)) + + # Verify type and structure + assert isinstance(deserialized, Send) + assert deserialized.node == "processor" + assert deserialized.arg == complex_arg + assert deserialized.arg["metadata"]["config"]["model"] == "gpt-4" + + +def test_send_object_in_checkpoint_structure(): + """Test Send objects embedded in checkpoint-like structures. + + This simulates how Send objects appear in actual checkpoint data. + """ + serializer = JsonPlusRedisSerializer() + + # Simulate checkpoint structure with pending_sends + checkpoint_data = { + "v": 1, + "id": "checkpoint_1", + "pending_sends": [ + Send(node="task1", arg={"task_data": "A"}), + Send(node="task2", arg={"task_data": "B"}), + ], + "channel_values": {"messages": ["msg1", "msg2"]}, + } + + # Serialize and deserialize + type_str, blob = serializer.dumps_typed(checkpoint_data) + deserialized = serializer.loads_typed((type_str, blob)) + + # Verify Send objects are preserved correctly + assert "pending_sends" in deserialized + assert len(deserialized["pending_sends"]) == 2 + + for send_obj in deserialized["pending_sends"]: + assert isinstance( + send_obj, Send + ), f"pending_sends contains non-Send object: {type(send_obj)}" + + +def test_send_object_equality_after_roundtrip(): + """Test that Send objects maintain equality after serialization roundtrip.""" + serializer = JsonPlusRedisSerializer() + + send1 = Send(node="test_node", arg={"value": 42}) + + # Serialize and deserialize + type_str, blob = serializer.dumps_typed(send1) + send2 = serializer.loads_typed((type_str, blob)) + + # Send objects should be equal + assert send1 == send2 + + # Test hash equality with hashable args + send_hashable1 = Send(node="test", arg="hashable_string") + type_str, blob = serializer.dumps_typed(send_hashable1) + send_hashable2 = serializer.loads_typed((type_str, blob)) + assert hash(send_hashable1) == hash(send_hashable2) + + +def test_send_object_with_none_arg(): + """Test Send object with None as argument.""" + serializer = JsonPlusRedisSerializer() + + send_obj = Send(node="null_handler", arg=None) + + # Serialize and deserialize + type_str, blob = serializer.dumps_typed(send_obj) + deserialized = serializer.loads_typed((type_str, blob)) + + assert isinstance(deserialized, Send) + assert deserialized.node == "null_handler" + assert deserialized.arg is None + + +if __name__ == "__main__": + # Run tests + pytest.main([__file__, "-v"])