File tree Expand file tree Collapse file tree 1 file changed +19
-9
lines changed Expand file tree Collapse file tree 1 file changed +19
-9
lines changed Original file line number Diff line number Diff line change @@ -390,18 +390,28 @@ def generate(
390390 """
391391 assert self .ctx is not None
392392
393- if (
394- reset
395- and len (self .eval_tokens ) > 0
396- and tuple (self .eval_tokens ) == tuple (tokens [: len (self .eval_tokens )])
397- ):
398- if self .verbose :
399- print ("Llama.generate: cache hit" , file = sys .stderr )
400- reset = False
401- tokens = tokens [len (self .eval_tokens ) :]
393+ if reset and len (self .eval_tokens ) > 0 :
394+ longest_prefix = 0
395+ for a , b in zip (self .eval_tokens , tokens [:- 1 ]):
396+ if a == b :
397+ longest_prefix += 1
398+ else :
399+ break
400+ if longest_prefix > 0 :
401+ if self .verbose :
402+ print ("Llama.generate: prefix-match hit" , file = sys .stderr )
403+ reset = False
404+ tokens = tokens [longest_prefix :]
405+ for _ in range (len (self .eval_tokens ) - longest_prefix ):
406+ self .eval_tokens .pop ()
407+ try :
408+ self .eval_logits .pop ()
409+ except IndexError :
410+ pass
402411
403412 if reset :
404413 self .reset ()
414+
405415 while True :
406416 self .eval (tokens )
407417 token = self .sample (
You can’t perform that action at this time.
0 commit comments