@@ -141,7 +141,9 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         if _key is None:
             raise KeyError("Key not found")
         value: "LlamaState" = self.cache.pop(_key)  # type: ignore
-        self.cache.push(_key, side="front")  # type: ignore
+        # NOTE: This puts an integer as key in cache, which breaks
+        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
+        # self.cache.push(_key, side="front")  # type: ignore
         return value

     def __contains__(self, key: Sequence[int]) -> bool:
@@ -168,7 +170,7 @@ def __init__(
         eval_logits: Deque[List[float]],
         input_ids: npt.NDArray[np.intc],
         scores: npt.NDArray[np.single],
-        llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
+        llama_state: bytes,
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
@@ -1512,7 +1514,7 @@ def save_state(self) -> LlamaState:
             eval_logits=self.eval_logits.copy(),
             scores=self._scores.copy(),
             input_ids=self._input_ids.copy(),
-            llama_state=llama_state_compact,
+            llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
         )

@@ -1523,7 +1525,10 @@ def load_state(self, state: LlamaState) -> None:
         self._scores = state.scores.copy()
         self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
-        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+        LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
+        llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
+
+        if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

     def n_ctx(self) -> int:
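
For reference, the change stores the saved state as an immutable `bytes` copy in `save_state` and rebuilds a ctypes array with `from_buffer_copy` in `load_state` before handing it back to `llama_set_state_data`. Below is a minimal sketch of that round trip using plain `ctypes` only; the names `StateArrayType`, `state_array`, and the buffer size of 8 are illustrative and not part of the diff or the llama_cpp bindings.

```python
import ctypes

n_bytes = 8  # illustrative size; the real code uses llama_get_state_size()
StateArrayType = ctypes.c_uint8 * n_bytes

# "save" side: snapshot the ctypes buffer into an independent bytes object,
# which is safe to keep in a Python-level state object or pickle
state_array = StateArrayType(*range(n_bytes))
saved: bytes = bytes(state_array)

# "load" side: rebuild a fresh ctypes array from the saved bytes before
# passing it back to a C function that expects a uint8_t* buffer
restored = StateArrayType.from_buffer_copy(saved)
assert bytes(restored) == saved
```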