2222from langchain .tools import BaseTool , StructuredTool
2323
2424from codeinterpreterapi .agents import OpenAIFunctionsAgent
25- from codeinterpreterapi .chains import get_file_modifications , remove_download_link
25+ from codeinterpreterapi .chains import (
26+ aget_file_modifications ,
27+ aremove_download_link ,
28+ get_file_modifications ,
29+ remove_download_link ,
30+ )
2631from codeinterpreterapi .config import settings
32+ from codeinterpreterapi .parser import CodeAgentOutputParser , CodeChatAgentOutputParser
2733from codeinterpreterapi .prompts import code_interpreter_system_message
2834from codeinterpreterapi .schema import (
2935 CodeInput ,
3036 CodeInterpreterResponse ,
3137 File ,
38+ SessionStatus ,
3239 UserRequest ,
3340)
34- from codeinterpreterapi .utils import (
35- CodeAgentOutputParser ,
36- CodeCallbackHandler ,
37- CodeChatAgentOutputParser ,
38- )
3941
4042
4143class CodeInterpreterSession :
@@ -45,19 +47,20 @@ def __init__(
4547 additional_tools : list [BaseTool ] = [],
4648 ** kwargs ,
4749 ) -> None :
48- self .codebox = CodeBox ()
50+ self .codebox = CodeBox (** kwargs )
4951 self .verbose = kwargs .get ("verbose" , settings .VERBOSE )
5052 self .tools : list [BaseTool ] = self ._tools (additional_tools )
5153 self .llm : BaseLanguageModel = llm or self ._choose_llm (** kwargs )
5254 self .agent_executor : AgentExecutor = self ._agent_executor ()
5355 self .input_files : list [File ] = []
5456 self .output_files : list [File ] = []
57+ self .code_log : list [tuple [str , str ]] = []
5558
56- def start (self ) -> None :
57- self .codebox .start ()
59+ def start (self ) -> SessionStatus :
60+ return SessionStatus . from_codebox_status ( self .codebox .start () )
5861
59- async def astart (self ) -> None :
60- await self .codebox .astart ()
62+ async def astart (self ) -> SessionStatus :
63+ return SessionStatus . from_codebox_status ( await self .codebox .astart () )
6164
6265 def _tools (self , additional_tools : list [BaseTool ]) -> list [BaseTool ]:
6366 return additional_tools + [
@@ -128,7 +131,6 @@ def _choose_agent(self) -> BaseSingleActionAgent:
128131 def _agent_executor (self ) -> AgentExecutor :
129132 return AgentExecutor .from_agent_and_tools (
130133 agent = self ._choose_agent (),
131- callbacks = [CodeCallbackHandler (self )],
132134 max_iterations = 9 ,
133135 tools = self .tools ,
134136 verbose = self .verbose ,
@@ -137,18 +139,67 @@ def _agent_executor(self) -> AgentExecutor:
137139 ),
138140 )
139141
140- async def show_code (self , code : str ) -> None :
142+ def show_code (self , code : str ) -> None :
143+ if self .verbose :
144+ print (code )
145+
146+ async def ashow_code (self , code : str ) -> None :
141147 """Callback function to show code to the user."""
142148 if self .verbose :
143149 print (code )
144150
145151 def _run_handler (self , code : str ):
146- raise NotImplementedError ("Use arun_handler for now." )
152+ """Run code in container and send the output to the user"""
153+ self .show_code (code )
154+ output : CodeBoxOutput = self .codebox .run (code )
155+ self .code_log .append ((code , output .content ))
156+
157+ if not isinstance (output .content , str ):
158+ raise TypeError ("Expected output.content to be a string." )
159+
160+ if output .type == "image/png" :
161+ filename = f"image-{ uuid .uuid4 ()} .png"
162+ file_buffer = BytesIO (base64 .b64decode (output .content ))
163+ file_buffer .name = filename
164+ self .output_files .append (File (name = filename , content = file_buffer .read ()))
165+ return f"Image { filename } got send to the user."
166+
167+ elif output .type == "error" :
168+ if "ModuleNotFoundError" in output .content :
169+ if package := re .search (
170+ r"ModuleNotFoundError: No module named '(.*)'" , output .content
171+ ):
172+ self .codebox .install (package .group (1 ))
173+ return (
174+ f"{ package .group (1 )} was missing but "
175+ "got installed now. Please try again."
176+ )
177+ else :
178+ # TODO: preanalyze error to optimize next code generation
179+ pass
180+ if self .verbose :
181+ print ("Error:" , output .content )
182+
183+ elif modifications := get_file_modifications (code , self .llm ):
184+ for filename in modifications :
185+ if filename in [file .name for file in self .input_files ]:
186+ continue
187+ fileb = self .codebox .download (filename )
188+ if not fileb .content :
189+ continue
190+ file_buffer = BytesIO (fileb .content )
191+ file_buffer .name = filename
192+ self .output_files .append (
193+ File (name = filename , content = file_buffer .read ())
194+ )
195+
196+ return output .content
147197
148198 async def _arun_handler (self , code : str ):
149199 """Run code in container and send the output to the user"""
150- print ( "Running code in container..." , code )
200+ await self . ashow_code ( code )
151201 output : CodeBoxOutput = await self .codebox .arun (code )
202+ self .code_log .append ((code , output .content ))
152203
153204 if not isinstance (output .content , str ):
154205 raise TypeError ("Expected output.content to be a string." )
@@ -176,7 +227,7 @@ async def _arun_handler(self, code: str):
176227 if self .verbose :
177228 print ("Error:" , output .content )
178229
179- elif modifications := await get_file_modifications (code , self .llm ):
230+ elif modifications := await aget_file_modifications (code , self .llm ):
180231 for filename in modifications :
181232 if filename in [file .name for file in self .input_files ]:
182233 continue
@@ -191,7 +242,22 @@ async def _arun_handler(self, code: str):
191242
192243 return output .content
193244
194- async def _input_handler (self , request : UserRequest ):
245+ def _input_handler (self , request : UserRequest ) -> None :
246+ """Callback function to handle user input."""
247+ if not request .files :
248+ return
249+ if not request .content :
250+ request .content = (
251+ "I uploaded, just text me back and confirm that you got the file(s)."
252+ )
253+ request .content += "\n **The user uploaded the following files: **\n "
254+ for file in request .files :
255+ self .input_files .append (file )
256+ request .content += f"[Attachment: { file .name } ]\n "
257+ self .codebox .upload (file .name , file .content )
258+ request .content += "**File(s) are now available in the cwd. **\n "
259+
260+ async def _ainput_handler (self , request : UserRequest ):
195261 # TODO: variables as context to the agent
196262 # TODO: current files as context to the agent
197263 if not request .files :
@@ -207,7 +273,7 @@ async def _input_handler(self, request: UserRequest):
207273 await self .codebox .aupload (file .name , file .content )
208274 request .content += "**File(s) are now available in the cwd. **\n "
209275
210- async def _output_handler (self , final_response : str ) -> CodeInterpreterResponse :
276+ def _output_handler (self , final_response : str ) -> CodeInterpreterResponse :
211277 """Embed images in the response"""
212278 for file in self .output_files :
213279 if str (file .name ) in final_response :
@@ -216,25 +282,98 @@ async def _output_handler(self, final_response: str) -> CodeInterpreterResponse:
216282
217283 if self .output_files and re .search (r"\n\[.*\]\(.*\)" , final_response ):
218284 try :
219- final_response = await remove_download_link (final_response , self .llm )
285+ final_response = remove_download_link (final_response , self .llm )
220286 except Exception as e :
221287 if self .verbose :
222288 print ("Error while removing download links:" , e )
223289
224- return CodeInterpreterResponse (content = final_response , files = self .output_files )
290+ output_files = self .output_files
291+ code_log = self .code_log
292+ self .output_files = []
293+ self .code_log = []
294+
295+ return CodeInterpreterResponse (
296+ content = final_response , files = output_files , code_log = code_log
297+ )
298+
299+ async def _aoutput_handler (self , final_response : str ) -> CodeInterpreterResponse :
300+ """Embed images in the response"""
301+ for file in self .output_files :
302+ if str (file .name ) in final_response :
303+ # rm  from the response
304+ final_response = re .sub (r"\n\n!\[.*\]\(.*\)" , "" , final_response )
305+
306+ if self .output_files and re .search (r"\n\[.*\]\(.*\)" , final_response ):
307+ try :
308+ final_response = await aremove_download_link (final_response , self .llm )
309+ except Exception as e :
310+ if self .verbose :
311+ print ("Error while removing download links:" , e )
312+
313+ output_files = self .output_files
314+ code_log = self .code_log
315+ self .output_files = []
316+ self .code_log = []
317+
318+ return CodeInterpreterResponse (
319+ content = final_response , files = output_files , code_log = code_log
320+ )
321+
322+ def generate_response_sync (
323+ self ,
324+ user_msg : str ,
325+ files : list [File ] = [],
326+ detailed_error : bool = False ,
327+ ) -> CodeInterpreterResponse :
328+ """Generate a Code Interpreter response based on the user's input."""
329+ user_request = UserRequest (content = user_msg , files = files )
330+ try :
331+ self ._input_handler (user_request )
332+ response = self .agent_executor .run (input = user_request .content )
333+ return self ._output_handler (response )
334+ except Exception as e :
335+ if self .verbose :
336+ traceback .print_exc ()
337+ if detailed_error :
338+ return CodeInterpreterResponse (
339+ content = "Error in CodeInterpreterSession: "
340+ f"{ e .__class__ .__name__ } - { e } "
341+ )
342+ else :
343+ return CodeInterpreterResponse (
344+ content = "Sorry, something went while generating your response."
345+ "Please try again or restart the session."
346+ )
225347
226348 async def generate_response (
227349 self ,
228350 user_msg : str ,
229351 files : list [File ] = [],
230352 detailed_error : bool = False ,
353+ ) -> CodeInterpreterResponse :
354+ print (
355+ "DEPRECATION WARNING: Use agenerate_response for async generation.\n "
356+ "This function will be converted to sync in the future.\n "
357+ "You can use generate_response_sync for now." ,
358+ )
359+ return await self .agenerate_response (
360+ user_msg = user_msg ,
361+ files = files ,
362+ detailed_error = detailed_error ,
363+ )
364+
365+ async def agenerate_response (
366+ self ,
367+ user_msg : str ,
368+ files : list [File ] = [],
369+ detailed_error : bool = False ,
231370 ) -> CodeInterpreterResponse :
232371 """Generate a Code Interpreter response based on the user's input."""
233372 user_request = UserRequest (content = user_msg , files = files )
234373 try :
235- await self ._input_handler (user_request )
374+ await self ._ainput_handler (user_request )
236375 response = await self .agent_executor .arun (input = user_request .content )
237- return await self ._output_handler (response )
376+ return await self ._aoutput_handler (response )
238377 except Exception as e :
239378 if self .verbose :
240379 traceback .print_exc ()
@@ -249,11 +388,24 @@ async def generate_response(
249388 "Please try again or restart the session."
250389 )
251390
252- async def is_running (self ) -> bool :
391+ def is_running (self ) -> bool :
392+ return self .codebox .status () == "running"
393+
394+ async def ais_running (self ) -> bool :
253395 return await self .codebox .astatus () == "running"
254396
255- async def astop (self ) -> None :
256- await self .codebox .astop ()
397+ def stop (self ) -> SessionStatus :
398+ return SessionStatus .from_codebox_status (self .codebox .stop ())
399+
400+ async def astop (self ) -> SessionStatus :
401+ return SessionStatus .from_codebox_status (await self .codebox .astop ())
402+
403+ def __enter__ (self ) -> "CodeInterpreterSession" :
404+ self .start ()
405+ return self
406+
407+ def __exit__ (self , exc_type , exc_value , traceback ) -> None :
408+ self .stop ()
257409
258410 async def __aenter__ (self ) -> "CodeInterpreterSession" :
259411 await self .astart ()
0 commit comments