@@ -24,6 +24,11 @@ class LLaMAInteract:
     def __init__(self, params: GptParams) -> None:
         # input args
         self.params = params
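+        # fall back to empty strings so downstream code can assume plain str values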
+        if self.params.path_session is None:
+            self.params.path_session = ""
+        if self.params.antiprompt is None:
+            self.params.antiprompt = ""
 
         if (self.params.perplexity):
             raise NotImplementedError("""************
@@ -66,7 +71,10 @@ def __init__(self, params: GptParams) -> None:
         self.lparams.use_mlock = self.params.use_mlock
         self.lparams.use_mmap = self.params.use_mmap
 
-        self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
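+        # load the weights once, then build an inference context on top of the model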
+        self.model = llama_cpp.llama_load_model_from_file(
+            self.params.model.encode("utf8"), self.lparams)
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
         if (not self.ctx):
             raise RuntimeError(f"error: failed to load model '{self.params.model}'")
 
@@ -181,12 +189,13 @@ def __init__(self, params: GptParams) -> None:
 number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)
 
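+        # debug dump: print every prompt token id next to its decoded text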
         for i in range(len(self.embd_inp)):
-            print(f"{self.embd_inp[i]} -> '{llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i])}'", file=sys.stderr)
+            print(f"{self.embd_inp[i]} -> '{self.token_to_str(self.embd_inp[i])}'", file=sys.stderr)
 
         if (self.params.n_keep > 0):
             print("static prompt based on n_keep: '")
             for i in range(self.params.n_keep):
-                print(llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i]), file=sys.stderr)
+                print(self.token_to_str(self.embd_inp[i]), file=sys.stderr)
             print("'", file=sys.stderr)
         print(file=sys.stderr)
 
@@ -339,7 +348,8 @@ def generate(self):
             candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
 
             # Apply penalties
-            nl_logit = logits[llama_cpp.llama_token_nl()]
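+            # remember the newline logit before repetition penalties modify the distribution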
+            nl_logit = logits[llama_cpp.llama_token_nl(self.ctx)]
             last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
 
             _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
@@ -380,7 +390,7 @@ def generate(self):
             self.last_n_tokens.append(id)
 
             # replace end of text token with newline token when in interactive mode
-            if (id == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct):
+            if (id == llama_cpp.llama_token_eos(self.ctx) and self.params.interactive and not self.params.instruct):
                 id = self.llama_token_newline[0]
             self.embd.append(id)
             if (self.use_antiprompt()):
@@ -437,7 +447,7 @@ def generate(self):
                 break
 
             # end of text token
-            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
+            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos(self.ctx):
                 if (not self.params.instruct):
                     for i in self.llama_token_eot:
                         yield i
@@ -464,10 +474,20 @@ def exit(self):
         llama_cpp.llama_free(self.ctx)
         self.set_color(util.CONSOLE_COLOR_DEFAULT)
 
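+    # decode a single token id into the raw bytes of its text piece; a fixed
+    # 32-byte buffer is assumed to be large enough for any one vocab piece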
+    def token_to_str(self, token_id: int) -> bytes:
+        size = 32
+        buffer = (ctypes.c_char * size)()
+        n = llama_cpp.llama_token_to_piece_with_model(
+            self.model, llama_cpp.llama_token(token_id), buffer, size)
+        assert n <= size
+        return bytes(buffer[:n])
+
     # return past text
     def past(self):
         for id in self.last_n_tokens[-self.n_past:]:
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf8", errors="ignore")
+            yield self.token_to_str(id).decode("utf8", errors="ignore")
 
     # write input
     def input(self, prompt: str):
@@ -481,7 +501,7 @@ def input(self, prompt: str):
     def output(self):
         self.remaining_tokens = self.params.n_predict
         for id in self.generate():
-            cur_char = llama_cpp.llama_token_to_str(self.ctx, id)
+            cur_char = self.token_to_str(id)
 
             # Add remainder of missing bytes
             if None in self.multibyte_fix: