
Commit 8f6e04c

Merge branch 'main' into fix/JsonPlusRedisSerializer
2 parents: 3c902a9 + 393d5f5

File tree

9 files changed: +472 -211 lines changed


langgraph/checkpoint/redis/aio.py

Lines changed: 19 additions & 12 deletions
@@ -1724,7 +1724,7 @@ async def _abatch_load_pending_sends(
             return_fields=[
                 "checkpoint_id",
                 "type",
-                "blob",
+                "$.blob",
                 "task_path",
                 "task_id",
                 "idx",
@@ -1745,20 +1745,27 @@ async def _abatch_load_pending_sends(
         # Sort and format results for each parent checkpoint
         for parent_checkpoint_id in parent_checkpoint_ids:
            batch_key = (thread_id, checkpoint_ns, parent_checkpoint_id)
-            docs = writes_by_checkpoint.get(parent_checkpoint_id, [])
-
-            # Sort for deterministic order
-            sorted_docs = sorted(
-                docs,
-                key=lambda d: (
-                    getattr(d, "task_path", ""),
-                    getattr(d, "task_id", ""),
-                    getattr(d, "idx", 0),
+            writes = writes_by_checkpoint.get(parent_checkpoint_id, [])
+
+            # Sort results by task_path, task_id, idx
+            sorted_writes = sorted(
+                writes,
+                key=lambda x: (
+                    getattr(x, "task_path", ""),
+                    getattr(x, "task_id", ""),
+                    getattr(x, "idx", 0),
                 ),
             )
 
-            # Convert to expected format
-            results_map[batch_key] = [(d.type, d.blob) for d in sorted_docs]
+            # Extract type and blob pairs
+            # Handle both direct attribute access and JSON path access
+            results_map[batch_key] = [
+                (
+                    getattr(doc, "type", ""),
+                    getattr(doc, "$.blob", getattr(doc, "blob", b"")),
+                )
+                for doc in sorted_writes
+            ]
 
         return results_map
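The functional change here is reading the payload through the "$.blob" return field and falling back to "blob" when that attribute is absent. When querying a JSON index, RediSearch results can expose the returned field under its JSON-path name, which is not a valid Python identifier, so access has to go through getattr. A minimal sketch of that access pattern, using a hypothetical `doc` result object:

# Sketch only: `doc` stands in for a redis-py search result document.
# Depending on how the field was requested, the payload may appear either
# as `doc.blob` or under the JSON-path attribute name "$.blob".
def extract_type_and_blob(doc):
    doc_type = getattr(doc, "type", "")
    # getattr accepts attribute names like "$.blob" that are not valid
    # Python identifiers, so try the JSON-path spelling first, then "blob".
    blob = getattr(doc, "$.blob", getattr(doc, "blob", b""))
    return doc_type, blob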

langgraph/checkpoint/redis/jsonplus_redis.py

Lines changed: 22 additions & 15 deletions
@@ -62,23 +62,22 @@ def loads(self, data: bytes) -> Any:
         return super().loads_typed(("json", data))
 
     def _revive_if_needed(self, obj: Any) -> Any:
-        """Recursively apply reviver to handle LangChain serialized objects.
+        """Recursively apply reviver to handle LangChain and LangGraph serialized objects.
 
         This method is crucial for preventing MESSAGE_COERCION_FAILURE by ensuring
         that LangChain message objects stored in their serialized format are properly
         reconstructed. Without this, messages would remain as dictionaries with
         'lc', 'type', and 'constructor' fields, causing errors when the application
         expects actual message objects with 'role' and 'content' attributes.
 
-        This also handles Interrupt objects that may be stored as plain dictionaries
-        with 'value' and 'id' keys, reconstructing them as proper Interrupt instances
-        to prevent AttributeError when accessing the 'id' attribute.
+        It also handles LangGraph Interrupt objects which serialize to {"value": ..., "resumable": ..., "ns": ..., "when": ...}
+        and must be reconstructed to prevent AttributeError when accessing Interrupt attributes.
 
         Args:
             obj: The object to potentially revive, which may be a dict, list, or primitive.
 
         Returns:
-            The revived object with LangChain objects properly reconstructed.
+            The revived object with LangChain/LangGraph objects properly reconstructed.
         """
         if isinstance(obj, dict):
             # Check if this is a LangChain serialized object
@@ -88,23 +87,31 @@ def _revive_if_needed(self, obj: Any) -> Any:
                 # the actual LangChain object (e.g., HumanMessage, AIMessage)
                 return self._reviver(obj)
 
-            # Check if this looks like an Interrupt object stored as a plain dict
-            # Interrupt objects have 'value' and 'id' keys, and possibly nothing else
-            # We need to be careful not to accidentally convert other dicts
+            # Check if this is a serialized Interrupt object
+            # Interrupt objects serialize to {"value": ..., "resumable": ..., "ns": ..., "when": ...}
+            # This must be done before recursively processing to avoid losing the structure
            if (
                 "value" in obj
-                and "id" in obj
-                and len(obj) == 2
-                and isinstance(obj.get("id"), str)
+                and "resumable" in obj
+                and "when" in obj
+                and len(obj) == 4
+                and isinstance(obj.get("resumable"), bool)
             ):
                 # Try to reconstruct as an Interrupt object
                 try:
                     from langgraph.types import Interrupt
 
-                    return Interrupt(value=obj["value"], id=obj["id"])  # type: ignore[call-arg]
-                except (ImportError, TypeError, ValueError):
-                    # If we can't import or construct Interrupt, fall through
-                    pass
+                    return Interrupt(
+                        value=self._revive_if_needed(obj["value"]),
+                        resumable=obj["resumable"],
+                        ns=obj["ns"],
+                        when=obj["when"],
+                    )
+                except (ImportError, TypeError, ValueError) as e:
+                    # If we can't import or construct Interrupt, log and fall through
+                    logger.debug(
+                        "Failed to deserialize Interrupt object: %s", e, exc_info=True
+                    )
 
             # Recursively process nested dicts
             return {k: self._revive_if_needed(v) for k, v in obj.items()}
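As a rough illustration of the detection logic above: the reviver treats a dict as a serialized Interrupt only when it has exactly the four keys value/resumable/ns/when with a boolean resumable, and everything else falls through to normal recursive processing. A minimal, self-contained sketch of that shape check (the sample payloads are made up; the real code additionally revives the value and constructs an Interrupt from langgraph.types):

def looks_like_serialized_interrupt(obj):
    # Exactly the four expected keys, with a boolean "resumable", is taken
    # as the signature of a serialized Interrupt; anything else is left alone.
    return (
        isinstance(obj, dict)
        and obj.keys() == {"value", "resumable", "ns", "when"}
        and isinstance(obj["resumable"], bool)
    )

# Hypothetical payloads for illustration only.
assert looks_like_serialized_interrupt(
    {"value": {"question": "continue?"}, "resumable": True, "ns": ["node:1"], "when": "during"}
)
assert not looks_like_serialized_interrupt({"value": 1, "id": "abc"})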

langgraph/store/redis/aio.py

Lines changed: 1 addition & 1 deletion
@@ -365,7 +365,7 @@ async def __aexit__(
     ) -> None:
         """Async context manager exit."""
         # Cancel the background task created by AsyncBatchedBaseStore
-        if hasattr(self, "_task") and not self._task.done():
+        if hasattr(self, "_task") and self._task is not None and not self._task.done():
             self._task.cancel()
             try:
                 await self._task
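The only change is the added `self._task is not None` guard, which protects against the attribute existing but never having been assigned a running task. A minimal sketch of the resulting shutdown pattern, using a hypothetical class rather than the store's actual implementation:

import asyncio
from typing import Optional

class Example:
    def __init__(self) -> None:
        # The background task may never be created, so it starts as None.
        self._task: Optional[asyncio.Task] = None

    async def __aexit__(self, exc_type, exc, tb) -> None:
        # Cancel only if the task exists, was actually started, and is still running.
        if hasattr(self, "_task") and self._task is not None and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass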
