Skip to content

Commit a50dd31

Browse files
authored
Improve tools (#21)
* Trimmed tools * Removed redundant notebool tools
1 parent 8191d30 commit a50dd31

File tree

3 files changed

+179
-270
lines changed

3 files changed

+179
-270
lines changed

jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from .toolkits.notebook import toolkit as nb_toolkit
2424
from .toolkits.jupyterlab import toolkit as jlab_toolkit
25+
from .toolkits.code_execution import toolkit as exec_toolkit
2526

2627
MEMORY_STORE_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "memory.sqlite")
2728

@@ -123,29 +124,18 @@ async def get_memory_store(self):
123124
def get_tools(self):
124125
tools = nb_toolkit
125126
tools += jlab_toolkit
126-
return nb_toolkit
127+
tools += exec_toolkit
128+
return tools
127129

128130
async def get_agent(self, model_id: str, model_args, system_prompt: str):
129131
model = ChatLiteLLM(**model_args, model=model_id, streaming=True)
130132
memory_store = await self.get_memory_store()
131-
132-
if not hasattr(self, "search_tool"):
133-
self.search_tool = FilesystemFileSearchMiddleware(
134-
root_path=self.parent.root_dir
135-
)
136-
if not hasattr(self, "shell_tool"):
137-
self.shell_tool = ShellToolMiddleware(workspace_root=self.parent.root_dir)
138-
if not hasattr(self, "tool_call_handler"):
139-
self.tool_call_handler = ToolMonitoringMiddleware(
140-
persona=self
141-
)
142-
133+
143134
return create_agent(
144135
model,
145136
system_prompt=system_prompt,
146137
checkpointer=memory_store,
147-
tools=self.get_tools(), # notebook and jlab tools
148-
middleware=[self.shell_tool, self.tool_call_handler],
138+
tools=self.get_tools(),
149139
)
150140

151141
async def process_message(self, message: Message) -> None:

jupyter_ai_jupyternaut/jupyternaut/toolkits/jupyterlab.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,21 @@ async def open_file(file_path: str):
55
"""
66
Opens a file in JupyterLab main area
77
"""
8-
await execute_command("docmanager:open", {"path": file_path})
8+
return await execute_command("docmanager:open", {"path": file_path})
9+
910

1011
async def run_all_cells():
1112
"""
1213
Runs all cells in the currently active Jupyter notebook
1314
"""
1415
return await execute_command("notebook:run-all-cells")
1516

16-
toolkit = [open_file, run_all_cells]
17+
18+
async def restart_kernel():
19+
"""
20+
Restarts the notebook kernel, useful when new packages are installed
21+
"""
22+
return await execute_command("notebook:restart-kernel")
23+
24+
25+
toolkit = [open_file, run_all_cells, restart_kernel]

0 commit comments

Comments
 (0)