
Commit 350a176

Update sampling api
1 parent 7837c3f commit 350a176

2 files changed: +113 -28 lines changed

llama_cpp/llama.py

Lines changed: 99 additions & 20 deletions
@@ -127,7 +127,9 @@ def __init__(
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
+            maxlen=n_ctx if logits_all else 1
+        )

         self.cache: Optional[LlamaCache] = None
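Note: with this change, per-token logits are retained for the whole context only when the model is constructed with logits_all=True; otherwise the deque holds just the most recent row, shrinking memory from O(n_ctx * n_vocab) floats to O(n_vocab). A minimal sketch of the deque behavior this relies on (values invented):

    from collections import deque

    logits_all = False
    # With maxlen=1 the deque silently evicts older rows, so only the logits
    # of the most recently evaluated token survive.
    eval_logits = deque(maxlen=512 if logits_all else 1)
    for row in ([0.1, 0.2], [0.3, 0.4], [0.5, 0.6]):  # stand-in per-token rows
        eval_logits.append(row)
    print(list(eval_logits))  # [[0.5, 0.6]]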

@@ -236,17 +238,90 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
+            # Save tokens
             self.eval_tokens.extend(batch)
-            if self.params.logits_all:
-                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-                cols = int(n_vocab)
-                rows = n_tokens
-                logits_view = llama_cpp.llama_get_logits(self.ctx)
-                logits = [
-                    [logits_view[i * cols + j] for j in range(cols)]
-                    for i in range(rows)
-                ]
-                self.eval_logits.extend(logits)
+            # Save logits
+            rows = n_tokens if self.params.logits_all else 1
+            n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+            cols = int(n_vocab)
+            logits_view = llama_cpp.llama_get_logits(self.ctx)
+            logits: List[List[llama_cpp.c_float]] = [
+                [logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
+            ]
+            self.eval_logits.extend(logits)
+
+    def _sample_top_p_top_k(
+        self,
+        last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
+        last_n_tokens_size: llama_cpp.c_int,
+        top_k: llama_cpp.c_int,
+        top_p: llama_cpp.c_float,
+        temp: llama_cpp.c_float,
+        repeat_penalty: llama_cpp.c_float,
+    ):
+        assert self.ctx is not None
+        assert len(self.eval_logits) > 0
+        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
+        logits = self.eval_logits[-1]
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=logits[i],
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        llama_cpp.llama_sample_repetition_penalty(
+            ctx=self.ctx,
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            candidates=llama_cpp.ctypes.pointer(candidates),
+            penalty=repeat_penalty,
+        )
+        if temp == 0.0:
+            return llama_cpp.llama_sample_token_greedy(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )
+        else:
+            llama_cpp.llama_sample_top_k(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                k=top_k,
+            )
+            llama_cpp.llama_sample_tail_free(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                z=llama_cpp.c_float(1.0),
+            )
+            llama_cpp.llama_sample_typical(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=llama_cpp.c_float(1.0),
+            )
+            llama_cpp.llama_sample_top_p(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=top_p,
+            )
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                temp=temp,
+            )
+            return llama_cpp.llama_sample_token(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )

     def sample(
         self,
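Note: the new `_sample_top_p_top_k` builds a llama_token_data_array of candidates, applies the repetition penalty, and then either samples greedily (temp == 0.0) or runs the chain top-k, tail-free, typical, top-p, temperature, sample. Tail-free and locally typical sampling are invoked with neutral parameters (z=1.0, p=1.0), which llama.cpp treats as disabled, so the effective behavior matches the old top-p/top-k sampler. One subtlety in the greedy branch: ctypes scalars do not compare equal to plain Python numbers, so `temp == 0.0` only fires when temp arrives as a plain float rather than a wrapped c_float:

    import ctypes

    # ctypes wrappers compare as objects, not by value:
    print(ctypes.c_float(0.0) == 0.0)        # False
    print(ctypes.c_float(0.0).value == 0.0)  # True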
@@ -270,8 +345,7 @@ def sample(
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return llama_cpp.llama_sample_top_p_top_k(
-            ctx=self.ctx,
+        return self._sample_top_p_top_k(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
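Note: a usage sketch of the path this change reroutes. The model path is hypothetical, and the ctypes wrapping assumes sample()'s parameters keep their existing names, mirroring what the higher-level completion API passes in:

    import ctypes
    from llama_cpp import Llama

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # hypothetical path
    llm.eval(llm.tokenize(b"Q: Name the planets in the solar system. A: "))
    token = llm.sample(
        top_k=ctypes.c_int(40),
        top_p=ctypes.c_float(0.95),
        temp=ctypes.c_float(0.80),
        repeat_penalty=ctypes.c_float(1.1),
    )
    print(llm.detokenize([token]))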
@@ -470,15 +544,15 @@ def _create_completion(
             all_text = self.detokenize(completion_tokens)

             # Contains multi-byte UTF8
-            for k,char in enumerate(all_text[-3:]):
+            for k, char in enumerate(all_text[-3:]):
                 k = 3 - k
-                for num,pattern in [(2, 192), (3, 224), (4, 240)]:
+                for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                     # Bitwise AND check
-                    if (num > k and pattern & char == pattern):
+                    if num > k and pattern & char == pattern:
                         multibyte_fix = num - k

             # Stop incomplete bytes from passing
-            if (multibyte_fix > 0):
+            if multibyte_fix > 0:
                 multibyte_fix -= 1
                 continue
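Note: 192 (0b11000000), 224 (0b11100000), and 240 (0b11110000) are the UTF-8 lead-byte patterns of 2-, 3-, and 4-byte sequences. The loop asks whether any of the last three bytes starts a sequence that extends past the end of the buffer; if so, streaming output is held back until the missing continuation bytes arrive. A standalone sketch of the same check (`pending_utf8_bytes` is a hypothetical name, and like the code above it assumes at least three bytes of context):

    def pending_utf8_bytes(buf: bytes) -> int:
        """Continuation bytes still missing from a trailing partial character."""
        missing = 0
        for k, char in enumerate(buf[-3:]):
            k = 3 - k  # distance of this byte from the end of the buffer
            for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                if num > k and pattern & char == pattern:
                    missing = num - k
        return missing

    assert pending_utf8_bytes(b"abc" + "é".encode()[:1]) == 1  # split 2-byte char
    assert pending_utf8_bytes(b"ab" + "€".encode()[:2]) == 1   # 3-byte char, one short
    assert pending_utf8_bytes(b"abc") == 0                     # complete ASCII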

@@ -531,7 +605,9 @@ def _create_completion(
             "model": self.model_path,
             "choices": [
                 {
-                    "text": text[returned_characters:].decode("utf-8", errors="ignore"),
+                    "text": text[returned_characters:].decode(
+                        "utf-8", errors="ignore"
+                    ),
                     "index": 0,
                     "logprobs": None,
                     "finish_reason": finish_reason,
@@ -558,7 +634,8 @@ def _create_completion(

         all_tokens = prompt_tokens + completion_tokens
         all_token_strs = [
-            self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
+            self.detokenize([token]).decode("utf-8", errors="ignore")
+            for token in all_tokens
         ]
         all_logprobs = [
             [Llama.logit_to_logprob(logit) for logit in row]
@@ -577,7 +654,9 @@ def _create_completion(
             )
             token_logprobs.append(sorted_logprobs[int(token)][0])
             top_logprob = {
-                self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
+                self.detokenize([llama_cpp.llama_token(i)]).decode(
+                    "utf-8", errors="ignore"
+                ): logprob
                 for logprob, i in sorted_logprobs[:logprobs]
             }
             top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
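Note: the comprehension above assembles OpenAI-style logprobs entries; each completion position ends up with a mapping from decoded token strings to their logprobs, roughly of this shape (tokens and values invented):

    top_logprob = {
        " Paris": -0.11,   # candidate token -> logprob
        " London": -3.42,
        " Berlin": -4.05,
    }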

llama_cpp/llama_cpp.py

Lines changed: 14 additions & 8 deletions
@@ -495,36 +495,40 @@ def llama_sample_softmax(ctx: llama_context_p, candidates):


 # @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
+def llama_sample_top_k(
+    ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)


 _lib.llama_sample_top_k.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_int,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_top_k.restype = None


 # @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
+def llama_sample_top_p(
+    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)


 _lib.llama_sample_top_p.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_top_p.restype = None


 # @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
 def llama_sample_tail_free(
-    ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
+    ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
 ):
     return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)

@@ -533,21 +537,23 @@ def llama_sample_tail_free(
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_tail_free.restype = None


 # @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
+def llama_sample_typical(
+    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_typical(ctx, candidates, p, min_keep)


 _lib.llama_sample_typical.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_typical.restype = None
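Note: the min_keep parameters now match the upstream llama.cpp C signatures, which take size_t, and gain a Python-side default of c_size_t(1) so existing callers can omit them. The width matters because ctypes marshals arguments according to argtypes; on most 64-bit platforms c_int is half the width of size_t:

    import ctypes

    # Typically prints "4 8" on 64-bit platforms: a c_int declaration for a
    # size_t parameter describes only half the argument's width to the FFI.
    print(ctypes.sizeof(ctypes.c_int), ctypes.sizeof(ctypes.c_size_t))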
