22
33#include " llama.h"
44#include " llama-io.h"
5- #include " llama-graph .h"
5+ #include " llama-memory .h"
66
77#include " ggml-cpp.h"
88
@@ -13,6 +13,17 @@ struct llama_cparams;
1313struct llama_hparams ;
1414struct llama_ubatch ;
1515
16+ struct llama_kv_cache : public llama_memory_i {
17+ using llama_memory_i::llama_memory_i;
18+
19+ virtual int32_t get_n_tokens () const = 0;
20+ virtual uint32_t get_used_cells () const = 0; // TODO: remove, this is too-specific to the unified cache
21+
22+ virtual bool get_can_shift () const = 0;
23+
24+ bool get_can_edit () const override { return get_can_shift (); }
25+ };
26+
1627struct llama_kv_cell {
1728 llama_pos pos = -1 ;
1829 llama_pos delta = 0 ;
@@ -45,36 +56,10 @@ struct llama_kv_cache_slot_info {
4556 operator bool () const { return found; }
4657};
4758
48- struct llama_kv_cache {
49- public:
50- virtual int32_t n_tokens () const = 0;
51- virtual uint32_t used_cells () const = 0; // TODO: remove
52-
53- virtual void clear () = 0;
54- virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
55- virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
56- virtual void seq_keep (llama_seq_id seq_id) = 0;
57- virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
58- virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
59-
60- virtual llama_pos seq_pos_max (llama_seq_id seq_id) = 0;
61-
62- virtual void defrag () = 0;
63- virtual bool get_can_shift () const = 0;
64- };
65-
66-
67- // C++ alias
68- class llama_kv_cache_i : public llama_kv_cache {
69- public:
70- using llama_kv_cache::llama_kv_cache;
71- };
72-
73-
7459// ring-buffer of cached KV data
7560// TODO: pimpl
7661// TODO: add notion of max sequences
77- class llama_kv_cache_unified : public llama_kv_cache_i {
62+ class llama_kv_cache_unified : public llama_kv_cache {
7863public:
7964 llama_kv_cache_unified (const llama_hparams & hparams);
8065 virtual ~llama_kv_cache_unified () = default ;
@@ -88,15 +73,16 @@ class llama_kv_cache_unified : public llama_kv_cache_i {
8873 uint32_t kv_size,
8974 bool offload);
9075
91- int32_t n_tokens () const override ;
92- uint32_t used_cells () const override ;
76+ int32_t get_n_tokens () const override ;
77+ uint32_t get_used_cells () const override ;
9378
9479 size_t total_size () const ;
9580
9681 // TODO: better data structures to reduce the cost of this operation
9782 llama_pos pos_max () const ;
9883
9984 void clear () override ;
85+ void defrag () override ;
10086
10187 bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override ;
10288 void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override ;
@@ -106,7 +92,6 @@ class llama_kv_cache_unified : public llama_kv_cache_i {
10692
10793 llama_pos seq_pos_max (llama_seq_id seq_id) override ;
10894
109- void defrag () override ;
11095 bool get_can_shift () const override ;
11196
11297 // find an empty slot of size "n_tokens" in the cache
0 commit comments