11import os
2- from typing import Any , Callable
2+ from typing import Any
33
44import aiosqlite
55from jupyter_ai_persona_manager import BasePersona , PersonaDefaults
66from jupyter_core .paths import jupyter_data_dir
77from jupyterlab_chat .models import Message
88from langchain .agents import create_agent
9- from langchain .agents .middleware import AgentMiddleware
10- from langchain .agents .middleware .file_search import FilesystemFileSearchMiddleware
11- from langchain .agents .middleware .shell_tool import ShellToolMiddleware
12- from langchain .messages import ToolMessage
13- from langchain .tools .tool_node import ToolCallRequest
14- from langchain_core .messages import ToolMessage
159from langgraph .checkpoint .sqlite .aio import AsyncSqliteSaver
16- from langgraph .types import Command
1710
1811from .chat_models import ChatLiteLLM
1912from .prompt_template import (
2619
2720MEMORY_STORE_PATH = os .path .join (jupyter_data_dir (), "jupyter_ai" , "memory.sqlite" )
2821
29- JUPYTERNAUT_AVATAR_PATH = str (os .path .abspath (
30- os .path .join (os .path .dirname (__file__ ), "../static" , "jupyternaut.svg" )
31- ))
32-
33-
34- def format_tool_args_compact (args_dict , threshold = 25 ):
35- """
36- Create a more compact string representation of tool call args.
37- Each key-value pair is on its own line for better readability.
38-
39- Args:
40- args_dict (dict): Dictionary of tool arguments
41- threshold (int): Maximum number of lines before truncation (default: 25)
42-
43- Returns:
44- str: Formatted string representation of arguments
45- """
46- if not args_dict :
47- return "{}"
48-
49- formatted_pairs = []
50-
51- for key , value in args_dict .items ():
52- value_str = str (value )
53- lines = value_str .split ('\n ' )
54-
55- if len (lines ) <= threshold :
56- if len (lines ) == 1 and len (value_str ) > 80 :
57- # Single long line - truncate
58- truncated = value_str [:77 ] + "..."
59- formatted_pairs .append (f" { key } : { truncated } " )
60- else :
61- # Add indentation for multi-line values
62- if len (lines ) > 1 :
63- indented_value = '\n ' .join (['' ] + lines )
64- formatted_pairs .append (f" { key } :{ indented_value } " )
65- else :
66- formatted_pairs .append (f" { key } : { value_str } " )
67- else :
68- # Truncate and add summary
69- truncated_lines = lines [:threshold ]
70- remaining_lines = len (lines ) - threshold
71- indented_value = '\n ' .join (['' ] + truncated_lines )
72- formatted_pairs .append (f" { key } :{ indented_value } \n [+{ remaining_lines } more lines]" )
73-
74- return "{\n " + ",\n " .join (formatted_pairs ) + "\n }"
75-
76-
77- class ToolMonitoringMiddleware (AgentMiddleware ):
78- def __init__ (self , * , persona : BasePersona ):
79- self .stream_message = persona .stream_message
80- self .log = persona .log
81-
82- async def awrap_tool_call (
83- self ,
84- request : ToolCallRequest ,
85- handler : Callable [[ToolCallRequest ], ToolMessage | Command ],
86- ) -> ToolMessage | Command :
87- args = format_tool_args_compact (request .tool_call ['args' ])
88- self .log .info (f"{ request .tool_call ['name' ]} ({ args } )" )
89-
90- try :
91- result = await handler (request )
92- self .log .info (f"{ request .tool_call ['name' ]} Done!" )
93- return result
94- except Exception as e :
95- self .log .info (f"{ request .tool_call ['name' ]} failed: { e } " )
96- return ToolMessage (
97- tool_call_id = request .tool_call ["id" ], status = "error" , content = f"{ e } "
98- )
22+ JUPYTERNAUT_AVATAR_PATH = str (
23+ os .path .abspath (
24+ os .path .join (os .path .dirname (__file__ ), "../static" , "jupyternaut.svg" )
25+ )
26+ )
9927
10028
10129class JupyternautPersona (BasePersona ):
@@ -115,6 +43,10 @@ def defaults(self):
11543 system_prompt = "..." ,
11644 )
11745
46+ @property
47+ def yroom_manager (self ):
48+ return self .parent .serverapp .web_app .settings ["yroom_manager" ]
49+
11850 async def get_memory_store (self ):
11951 if not hasattr (self , "_memory_store" ):
12052 conn = await aiosqlite .connect (MEMORY_STORE_PATH , check_same_thread = False )
@@ -130,7 +62,7 @@ def get_tools(self):
13062 async def get_agent (self , model_id : str , model_args , system_prompt : str ):
13163 model = ChatLiteLLM (** model_args , model = model_id , streaming = True )
13264 memory_store = await self .get_memory_store ()
133-
65+
13466 return create_agent (
13567 model ,
13668 system_prompt = system_prompt ,
@@ -158,18 +90,20 @@ async def process_message(self, message: Message) -> None:
15890 model_id = model_id , model_args = model_args , system_prompt = system_prompt
15991 )
16092
93+ context = {
94+ "thread_id" : self .ychat .get_id (),
95+ "username" : message .sender
96+ }
97+
16198 async def create_aiter ():
16299 async for token , metadata in agent .astream (
163100 {"messages" : [{"role" : "user" , "content" : message .body }]},
164- {"configurable" : { "thread_id" : self . ychat . get_id ()} },
101+ {"configurable" : context },
165102 stream_mode = "messages" ,
166103 ):
167104 node = metadata ["langgraph_node" ]
168105 content_blocks = token .content_blocks
169- if (
170- node == "model"
171- and content_blocks
172- ):
106+ if node == "model" and content_blocks :
173107 if token .text :
174108 yield token .text
175109
@@ -182,15 +116,19 @@ def get_system_prompt(
182116 """
183117 Returns the system prompt, including attachments as a string.
184118 """
119+
120+ context = self .process_attachments (message ) or ""
121+ context = f"User's username is '{ message .sender } '\n \n " + context
122+
185123 system_msg_args = JupyternautSystemPromptArgs (
186124 model_id = model_id ,
187125 persona_name = self .name ,
188- context = self . process_attachments ( message ) ,
126+ context = context ,
189127 ).model_dump ()
190128
191129 return JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE .render (** system_msg_args )
192130
193131 def shutdown (self ):
194- if hasattr (self ,"_memory_store" ):
132+ if hasattr (self , "_memory_store" ):
195133 self .parent .event_loop .create_task (self ._memory_store .conn .close ())
196134 super ().shutdown ()
0 commit comments