@@ -90,6 +90,8 @@ def __init__(
         yarn_orig_ctx: int = 0,
         logits_all: bool = False,
         embedding: bool = False,
+        n_seq_max: Optional[int] = None,
+        kv_unified: Optional[bool] = None,
         offload_kqv: bool = True,
         flash_attn: bool = False,
         op_offload: Optional[bool] = None,
@@ -172,6 +174,8 @@ def __init__(
             yarn_orig_ctx: YaRN original context size
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
+            n_seq_max: Maximum number of sequences in KV cache
+            kv_unified: Use unified KV cache across sequences
             offload_kqv: Offload K, Q, V to GPU.
             flash_attn: Use flash attention.
             op_offload: offload host tensor operations to device
@@ -343,6 +347,14 @@ def __init__(
         self.context_params.offload_kqv = offload_kqv
         self.context_params.flash_attn = flash_attn

+        # this allows for batch embedding many sequences
+        if n_seq_max is not None:
+            self.context_params.n_seq_max = n_seq_max
+        if kv_unified is not None:
+            self.context_params.kv_unified = kv_unified
+        elif embedding and n_seq_max is None:
+            self.context_params.kv_unified = True
+
         if op_offload is not None:
             self.context_params.op_offload = op_offload
0 commit comments