|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | import os |
2 | 4 | import sys |
3 | 5 | import uuid |
|
40 | 42 | ) |
41 | 43 |
|
42 | 44 |
|
class LlamaState:
    """Container for a snapshot of Llama evaluation state.

    Holds the evaluated token ids, their scores, the valid token count,
    and an opaque serialized state blob plus its size in bytes.
    Presumably used by the model's save/load-state machinery — the
    callers are not visible in this chunk.
    """

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
    ):
        # Token ids evaluated so far.
        self.input_ids = input_ids
        # Score (logit) array associated with the evaluated tokens.
        self.scores = scores
        # Number of valid entries in input_ids / scores.
        self.n_tokens = n_tokens
        # Opaque serialized llama.cpp state bytes.
        self.llama_state = llama_state
        # Size in bytes of llama_state.
        self.llama_state_size = llama_state_size
58 | | - |
# A logits processor maps (input_ids, scores) -> adjusted scores array.
LogitsProcessor = Callable[
    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
]
62 | | - |
63 | | - |
class LogitsProcessorList(List[LogitsProcessor]):
    """A list of LogitsProcessor callables applied in sequence.

    Calling the list threads the scores array through each processor:
    the output of one processor becomes the input of the next.
    """

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        result = scores
        for apply_processor in self:
            result = apply_processor(input_ids, result)
        return result
71 | | - |
72 | | - |
# A stopping criterion: returns True when generation should stop
# for the given (input_ids, logits).
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
74 | | - |
75 | | - |
class StoppingCriteriaList(List[StoppingCriteria]):
    """A list of StoppingCriteria callables.

    Calling the list returns True if any criterion signals that
    generation should stop for the given (input_ids, logits).
    """

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        # Generator expression instead of a list comprehension: any()
        # short-circuits on the first True, so remaining criteria are
        # not evaluated and no throwaway list is built.
        return any(
            stopping_criteria(input_ids, logits) for stopping_criteria in self
        )
81 | | - |
82 | | - |
83 | 45 | class Llama: |
84 | 46 | """High-level Python wrapper for a llama.cpp model.""" |
85 | 47 |
|
@@ -1733,3 +1695,43 @@ def decode(self, tokens: List[int]) -> str: |
    @classmethod
    def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
        """Build a tokenizer from the ggml model file at *path*.

        Constructs a Llama with vocab_only=True — presumably loading
        only the vocabulary rather than full weights (confirm against
        Llama.__init__) — and wraps it in a LlamaTokenizer.
        """
        return cls(Llama(model_path=path, vocab_only=True))
| 1698 | + |
| 1699 | + |
class LlamaState:
    """Container for a snapshot of Llama evaluation state.

    Holds the evaluated token ids, their scores, the valid token count,
    and an opaque serialized state blob plus its size in bytes.
    Presumably used by the model's save/load-state machinery — the
    callers are not visible in this chunk.
    """

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
    ):
        # Token ids evaluated so far.
        self.input_ids = input_ids
        # Score (logit) array associated with the evaluated tokens.
        self.scores = scores
        # Number of valid entries in input_ids / scores.
        self.n_tokens = n_tokens
        # Opaque serialized llama.cpp state bytes.
        self.llama_state = llama_state
        # Size in bytes of llama_state.
        self.llama_state_size = llama_state_size
| 1714 | + |
| 1715 | + |
# A logits processor maps (input_ids, scores) -> adjusted scores array.
LogitsProcessor = Callable[
    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
]
| 1719 | + |
| 1720 | + |
class LogitsProcessorList(List[LogitsProcessor]):
    """A list of LogitsProcessor callables applied in sequence.

    Calling the list threads the scores array through each processor:
    the output of one processor becomes the input of the next.
    """

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        result = scores
        for apply_processor in self:
            result = apply_processor(input_ids, result)
        return result
| 1728 | + |
| 1729 | + |
# A stopping criterion: returns True when generation should stop
# for the given (input_ids, logits).
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
| 1731 | + |
| 1732 | + |
class StoppingCriteriaList(List[StoppingCriteria]):
    """A list of StoppingCriteria callables.

    Calling the list returns True if any criterion signals that
    generation should stop for the given (input_ids, logits).
    """

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        # Generator expression instead of a list comprehension: any()
        # short-circuits on the first True, so remaining criteria are
        # not evaluated and no throwaway list is built.
        return any(
            stopping_criteria(input_ids, logits) for stopping_criteria in self
        )
0 commit comments