@@ -229,8 +229,8 @@ def __init__(
         n_batch: int = 512,
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
-        rope_freq_base: float = 10000.0,
-        rope_freq_scale: float = 1.0,
+        rope_freq_base: float = 0.0,
+        rope_freq_scale: float = 0.0,
         mul_mat_q: bool = True,
         f16_kv: bool = True,
         logits_all: bool = False,
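The new 0.0 defaults lean on llama.cpp's convention that a RoPE frequency of 0 means "use the value stored in the model" (llama.h documents rope_freq_base and rope_freq_scale as "0 = from model"). A minimal usage sketch of the effect; the model path is hypothetical:

    from llama_cpp import Llama

    # Defaults of 0.0 defer to the RoPE settings baked into the GGUF model.
    llm = Llama(model_path="./models/llama-2-7b.Q4_K_M.gguf")  # hypothetical path

    # An explicit non-zero value still overrides the model, e.g. linear
    # RoPE scaling to stretch the usable context window:
    llm_long = Llama(
        model_path="./models/llama-2-7b.Q4_K_M.gguf",  # hypothetical path
        rope_freq_base=10000.0,
        rope_freq_scale=0.5,  # positions compressed 2x => roughly 2x context
    )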
@@ -282,7 +282,6 @@ def __init__(
         Returns:
             A Llama instance.
         """
-
         self.verbose = verbose

         self.numa = numa
@@ -320,16 +319,19 @@ def __init__(
320319 self .n_threads_batch = n_threads_batch or max (
321320 multiprocessing .cpu_count () // 2 , 1
322321 )
323-
324322 # Context Params
325323 self .context_params = llama_cpp .llama_context_default_params ()
326324 self .context_params .seed = seed
327325 self .context_params .n_ctx = n_ctx
328326 self .context_params .n_batch = self .n_batch
329327 self .context_params .n_threads = self .n_threads
330328 self .context_params .n_threads_batch = self .n_threads_batch
331- self .context_params .rope_freq_base = rope_freq_base
332- self .context_params .rope_freq_scale = rope_freq_scale
329+ self .context_params .rope_freq_base = (
330+ rope_freq_base if rope_freq_base != 0.0 else 0
331+ )
332+ self .context_params .rope_freq_scale = (
333+ rope_freq_scale if rope_freq_scale != 0.0 else 0
334+ )
333335 self .context_params .mul_mat_q = mul_mat_q
334336 self .context_params .f16_kv = f16_kv
335337 self .context_params .logits_all = logits_all
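The conditional forward here is effectively an identity (when rope_freq_base is 0.0, both branches yield zero); its value is in spelling out that 0 is the sentinel the C API reads as "from model". The actual fallback to the model's values happens inside llama.cpp, not in Python. A standalone sketch of that sentinel convention, using a hypothetical helper name that is not part of the library:

    def resolve_rope_param(requested: float, model_value: float) -> float:
        """Mimics llama.cpp's rule: 0.0 means 'defer to the model'."""
        return requested if requested != 0.0 else model_value

    assert resolve_rope_param(0.0, 10000.0) == 10000.0  # defer to model
    assert resolve_rope_param(1e6, 10000.0) == 1e6      # explicit override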
@@ -338,7 +340,6 @@ def __init__(
         # Sampling Params
         self.last_n_tokens_size = last_n_tokens_size

-
         self.cache: Optional[BaseLlamaCache] = None

         self.lora_base = lora_base