Skip to content

Commit 11d3208

Browse files
3coinsdlqqq
andauthored
Fix agent errors (#24)
* Upgraded jupyterlab_chat * Fixed open_file tool * Added run, select cell tools * Updated tools, prompt to nudge agent to work with chat and active notebook * Removed middleware * add jupyterlab_notebook_awareness to deps --------- Co-authored-by: David L. Qiu <david@qiu.dev>
1 parent a50dd31 commit 11d3208

File tree

5 files changed

+580
-160
lines changed

5 files changed

+580
-160
lines changed
Lines changed: 24 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,12 @@
11
import os
2-
from typing import Any, Callable
2+
from typing import Any
33

44
import aiosqlite
55
from jupyter_ai_persona_manager import BasePersona, PersonaDefaults
66
from jupyter_core.paths import jupyter_data_dir
77
from jupyterlab_chat.models import Message
88
from langchain.agents import create_agent
9-
from langchain.agents.middleware import AgentMiddleware
10-
from langchain.agents.middleware.file_search import FilesystemFileSearchMiddleware
11-
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
12-
from langchain.messages import ToolMessage
13-
from langchain.tools.tool_node import ToolCallRequest
14-
from langchain_core.messages import ToolMessage
159
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
16-
from langgraph.types import Command
1710

1811
from .chat_models import ChatLiteLLM
1912
from .prompt_template import (
@@ -26,76 +19,11 @@
2619

2720
MEMORY_STORE_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "memory.sqlite")
2821

29-
JUPYTERNAUT_AVATAR_PATH = str(os.path.abspath(
30-
os.path.join(os.path.dirname(__file__), "../static", "jupyternaut.svg")
31-
))
32-
33-
34-
def format_tool_args_compact(args_dict, threshold=25):
35-
"""
36-
Create a more compact string representation of tool call args.
37-
Each key-value pair is on its own line for better readability.
38-
39-
Args:
40-
args_dict (dict): Dictionary of tool arguments
41-
threshold (int): Maximum number of lines before truncation (default: 25)
42-
43-
Returns:
44-
str: Formatted string representation of arguments
45-
"""
46-
if not args_dict:
47-
return "{}"
48-
49-
formatted_pairs = []
50-
51-
for key, value in args_dict.items():
52-
value_str = str(value)
53-
lines = value_str.split('\n')
54-
55-
if len(lines) <= threshold:
56-
if len(lines) == 1 and len(value_str) > 80:
57-
# Single long line - truncate
58-
truncated = value_str[:77] + "..."
59-
formatted_pairs.append(f" {key}: {truncated}")
60-
else:
61-
# Add indentation for multi-line values
62-
if len(lines) > 1:
63-
indented_value = '\n '.join([''] + lines)
64-
formatted_pairs.append(f" {key}:{indented_value}")
65-
else:
66-
formatted_pairs.append(f" {key}: {value_str}")
67-
else:
68-
# Truncate and add summary
69-
truncated_lines = lines[:threshold]
70-
remaining_lines = len(lines) - threshold
71-
indented_value = '\n '.join([''] + truncated_lines)
72-
formatted_pairs.append(f" {key}:{indented_value}\n [+{remaining_lines} more lines]")
73-
74-
return "{\n" + ",\n".join(formatted_pairs) + "\n}"
75-
76-
77-
class ToolMonitoringMiddleware(AgentMiddleware):
78-
def __init__(self, *, persona: BasePersona):
79-
self.stream_message = persona.stream_message
80-
self.log = persona.log
81-
82-
async def awrap_tool_call(
83-
self,
84-
request: ToolCallRequest,
85-
handler: Callable[[ToolCallRequest], ToolMessage | Command],
86-
) -> ToolMessage | Command:
87-
args = format_tool_args_compact(request.tool_call['args'])
88-
self.log.info(f"{request.tool_call['name']}({args})")
89-
90-
try:
91-
result = await handler(request)
92-
self.log.info(f"{request.tool_call['name']} Done!")
93-
return result
94-
except Exception as e:
95-
self.log.info(f"{request.tool_call['name']} failed: {e}")
96-
return ToolMessage(
97-
tool_call_id=request.tool_call["id"], status="error", content=f"{e}"
98-
)
22+
JUPYTERNAUT_AVATAR_PATH = str(
23+
os.path.abspath(
24+
os.path.join(os.path.dirname(__file__), "../static", "jupyternaut.svg")
25+
)
26+
)
9927

10028

10129
class JupyternautPersona(BasePersona):
@@ -115,6 +43,10 @@ def defaults(self):
11543
system_prompt="...",
11644
)
11745

46+
@property
47+
def yroom_manager(self):
48+
return self.parent.serverapp.web_app.settings["yroom_manager"]
49+
11850
async def get_memory_store(self):
11951
if not hasattr(self, "_memory_store"):
12052
conn = await aiosqlite.connect(MEMORY_STORE_PATH, check_same_thread=False)
@@ -130,7 +62,7 @@ def get_tools(self):
13062
async def get_agent(self, model_id: str, model_args, system_prompt: str):
13163
model = ChatLiteLLM(**model_args, model=model_id, streaming=True)
13264
memory_store = await self.get_memory_store()
133-
65+
13466
return create_agent(
13567
model,
13668
system_prompt=system_prompt,
@@ -158,18 +90,20 @@ async def process_message(self, message: Message) -> None:
15890
model_id=model_id, model_args=model_args, system_prompt=system_prompt
15991
)
16092

93+
context = {
94+
"thread_id": self.ychat.get_id(),
95+
"username": message.sender
96+
}
97+
16198
async def create_aiter():
16299
async for token, metadata in agent.astream(
163100
{"messages": [{"role": "user", "content": message.body}]},
164-
{"configurable": {"thread_id": self.ychat.get_id()}},
101+
{"configurable": context},
165102
stream_mode="messages",
166103
):
167104
node = metadata["langgraph_node"]
168105
content_blocks = token.content_blocks
169-
if (
170-
node == "model"
171-
and content_blocks
172-
):
106+
if node == "model" and content_blocks:
173107
if token.text:
174108
yield token.text
175109

@@ -182,15 +116,19 @@ def get_system_prompt(
182116
"""
183117
Returns the system prompt, including attachments as a string.
184118
"""
119+
120+
context = self.process_attachments(message) or ""
121+
context = f"User's username is '{message.sender}'\n\n" + context
122+
185123
system_msg_args = JupyternautSystemPromptArgs(
186124
model_id=model_id,
187125
persona_name=self.name,
188-
context=self.process_attachments(message),
126+
context=context,
189127
).model_dump()
190128

191129
return JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args)
192130

193131
def shutdown(self):
194-
if hasattr(self,"_memory_store"):
132+
if hasattr(self, "_memory_store"):
195133
self.parent.event_loop.create_task(self._memory_store.conn.close())
196134
super().shutdown()

0 commit comments

Comments
 (0)