Skip to content

Commit a8f1233

Browse files
committed
add n_seq_max and kv_unified options; fix batch embedding
1 parent c37132b commit a8f1233

File tree

1 file changed: +12 additions, -0 deletions

1 file changed: +12 additions, -0 deletions

llama_cpp/llama.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def __init__(
9090
yarn_orig_ctx: int = 0,
9191
logits_all: bool = False,
9292
embedding: bool = False,
93+
n_seq_max: Optional[int] = None,
94+
kv_unified: Optional[bool] = None,
9395
offload_kqv: bool = True,
9496
flash_attn: bool = False,
9597
op_offload: Optional[bool] = None,
@@ -172,6 +174,8 @@ def __init__(
172174
yarn_orig_ctx: YaRN original context size
173175
logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
174176
embedding: Embedding mode only.
177+
n_seq_max: Maximum number of sequences in KV cache
178+
kv_unified: Use unified KV cache across sequences
175179
offload_kqv: Offload K, Q, V to GPU.
176180
flash_attn: Use flash attention.
177181
op_offload: offload host tensor operations to device
@@ -343,6 +347,14 @@ def __init__(
343347
self.context_params.offload_kqv = offload_kqv
344348
self.context_params.flash_attn = flash_attn
345349

350+
# this allows for batch embedding many sequences
351+
if n_seq_max is not None:
352+
self.context_params.n_seq_max = n_seq_max
353+
if kv_unified is not None:
354+
self.context_params.kv_unified = kv_unified
355+
elif embedding and n_seq_max is None:
356+
self.context_params.kv_unified = True
357+
346358
if op_offload is not None:
347359
self.context_params.op_offload = op_offload
348360

0 commit comments

Comments (0)