@@ -216,14 +216,15 @@ def __init__(
         embedding: bool = False,
         n_threads: Optional[int] = None,
         n_batch: int = 512,
-        n_gqa: Optional[int] = None,  # must be 8 for llama2 70b
         last_n_tokens_size: int = 64,
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         low_vram: bool = False,
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
+        n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
+        rms_eps_norm: Optional[float] = None,  # (TEMPORARY)
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -261,8 +262,6 @@ def __init__(
 
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
-        if n_gqa is not None:
-            self.params.n_gqa = n_gqa
         self.params.n_gpu_layers = n_gpu_layers
         self.params.seed = seed
         self.params.f16_kv = f16_kv
@@ -285,6 +284,12 @@ def __init__(
         self.params.rope_freq_base = rope_freq_base
         self.params.rope_freq_scale = rope_freq_scale
 
+        if n_gqa is not None:
+            self.params.n_gqa = n_gqa
+
+        if rms_eps_norm is not None:
+            self.params.rms_eps_norm = rms_eps_norm
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
 
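Taken together, the `__init__` hunks above add two pass-through parameters, `n_gqa` and `rms_eps_norm`, which are forwarded into the llama.cpp context params only when explicitly set. A minimal usage sketch (the model path and epsilon value below are illustrative, not part of this commit):

from llama_cpp import Llama

# n_gqa must be 8 for LLaMA 2 70B per the comment in the diff;
# rms_eps_norm overrides the RMSNorm epsilon. Both are flagged
# (TEMPORARY), and leaving them as None keeps llama.cpp's defaults.
llm = Llama(
    model_path="./models/llama-2-70b.ggmlv3.q4_0.bin",  # hypothetical path
    n_gqa=8,
    rms_eps_norm=1e-5,  # illustrative value
)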
@@ -1526,6 +1531,10 @@ def __getstate__(self):
             lora_base=self.lora_base,
             lora_path=self.lora_path,
             tensor_split=self.tensor_split,
+            ### TEMPORARY ###
+            n_gqa=self.params.n_gqa,
+            rms_eps_norm=self.params.rms_eps_norm,
+            ### TEMPORARY ###
             ### DEPRECATED ###
             n_parts=self.n_parts,
             ### DEPRECATED ###
@@ -1535,7 +1544,6 @@ def __setstate__(self, state):
         self.__init__(
             model_path=state["model_path"],
             n_ctx=state["n_ctx"],
-            n_parts=state["n_parts"],
             n_gpu_layers=state["n_gpu_layers"],
             seed=state["seed"],
             f16_kv=state["f16_kv"],
@@ -1551,7 +1559,14 @@ def __setstate__(self, state):
             lora_base=state["lora_base"],
             lora_path=state["lora_path"],
             tensor_split=state["tensor_split"],
+            ### TEMPORARY ###
+            n_gqa=state["n_gqa"],
+            rms_eps_norm=state["rms_eps_norm"],
+            ### TEMPORARY ###
             verbose=state["verbose"],
+            ### DEPRECATED ###
+            n_parts=state["n_parts"],
+            ### DEPRECATED ###
         )
 
     def save_state(self) -> LlamaState:
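Since `__getstate__` now snapshots both temporary fields from `self.params` and `__setstate__` feeds them back into `__init__`, a pickle round trip should preserve them. A quick sketch, assuming the `llm` instance from above (and note that `__setstate__` re-runs `__init__`, so the model file must still be present to reload):

import pickle

# Serializing captures n_gqa and rms_eps_norm via the new state entries;
# deserializing re-runs __init__ with the same values.
restored = pickle.loads(pickle.dumps(llm))
assert restored.params.n_gqa == llm.params.n_gqa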