@@ -141,7 +141,9 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
141141 if _key is None :
142142 raise KeyError ("Key not found" )
143143 value : "LlamaState" = self .cache .pop (_key ) # type: ignore
144- self .cache .push (_key , side = "front" ) # type: ignore
144+ # NOTE: This puts an integer as a key in the cache, which breaks
145+ # Llama.longest_token_prefix(k, key) above, since k is then not a tuple of ints/tokens.
146+ # self.cache.push(_key, side="front") # type: ignore
145147 return value
146148
147149 def __contains__ (self , key : Sequence [int ]) -> bool :
@@ -168,7 +170,7 @@ def __init__(
168170 eval_logits : Deque [List [float ]],
169171 input_ids : npt .NDArray [np .intc ],
170172 scores : npt .NDArray [np .single ],
171- llama_state , # type: llama_cpp.Array[llama_cpp.c_uint8]
173+ llama_state : bytes ,
172174 llama_state_size : int ,
173175 ):
174176 self .eval_tokens = eval_tokens
@@ -1509,7 +1511,7 @@ def save_state(self) -> LlamaState:
15091511 eval_logits = self .eval_logits .copy (),
15101512 scores = self ._scores .copy (),
15111513 input_ids = self ._input_ids .copy (),
1512- llama_state = llama_state_compact ,
1514+ llama_state = bytes ( llama_state_compact ) ,
15131515 llama_state_size = n_bytes ,
15141516 )
15151517
@@ -1520,7 +1522,10 @@ def load_state(self, state: LlamaState) -> None:
15201522 self ._scores = state .scores .copy ()
15211523 self ._input_ids = state .input_ids .copy ()
15221524 state_size = state .llama_state_size
1523- if llama_cpp .llama_set_state_data (self .ctx , state .llama_state ) != state_size :
1525+ LLamaStateArrayType = (llama_cpp .c_uint8 * state_size )
1526+ llama_state = LLamaStateArrayType .from_buffer_copy (state .llama_state )
1527+
1528+ if llama_cpp .llama_set_state_data (self .ctx , llama_state ) != state_size :
15241529 raise RuntimeError ("Failed to set llama state data" )
15251530
15261531 def n_ctx (self ) -> int :
0 commit comments