11"""A pass that rewrites KV cache creation functions in IRModule."""
22
33import json
4- from typing import Any , Dict
4+ from typing import Any , Dict , Literal , Tuple
55
66import tvm
77from tvm import IRModule , relax
88from tvm .relax .frontend .nn .llm import kv_cache
99from tvm .relax .frontend .nn .llm .kv_cache import RopeMode
1010
1111
-def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
+def extract_creation_args(func: relax.Function) -> Tuple[Literal["mha", "mla"], Dict[str, Any]]:
1313 """Extract the KV cache creation args from the given generic creation func."""
1414 assert isinstance (func .body , relax .SeqExpr )
1515 assert len (func .body .blocks ) == 1
1616 assert isinstance (func .body .blocks [0 ], relax .DataflowBlock )
1717 assert isinstance (func .body .blocks [0 ].bindings [0 ], relax .VarBinding )
1818 assert isinstance (func .body .blocks [0 ].bindings [0 ].value , relax .Call )
1919 assert func .body .blocks [0 ].bindings [0 ].value .op == tvm .ir .Op .get ("relax.call_pure_packed" )
-    args = func.body.blocks[0].bindings[0].value.args
-    assert isinstance(args[0], relax.ExternFunc)
-    assert args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
-
-    assert len(args) == 15
-    assert isinstance(args[1], relax.ShapeExpr)
-    assert len(args[1].values) == 5
-    assert isinstance(args[2], relax.ShapeExpr)
-    for i in range(3, 14):
-        if i in [10, 11]:
-            continue
-        assert isinstance(args[i], relax.PrimValue)
-        assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
-    assert isinstance(args[10], relax.StringImm)
-    assert isinstance(args[11], (relax.Constant, relax.PrimValue))
-    assert isinstance(args[14], relax.DataTypeImm)
-
-    return {
-        "max_batch_size": args[1].values[0],
-        "max_total_seq_len": args[1].values[1],
-        "prefill_chunk_size": args[1].values[2],
-        "page_size": args[1].values[3],
-        "support_sliding_window": args[1].values[4],
-        "layer_partition": args[2],
-        "num_hidden_layers": args[3].value.value,
-        "num_attention_heads": args[4].value.value,
-        "num_key_value_heads": args[5].value.value,
-        "head_dim": args[6].value.value,
-        "rope_mode": args[7].value.value,
-        "rope_scale": args[8].value.value,
-        "rope_theta": args[9].value.value,
-        "rope_scaling": json.loads(args[10].value),
-        "rope_ext_factors": args[11],
-        "rotary_dim": args[12].value.value,
-        "enable_disaggregation": bool(args[13].value.value),
-        "dtype": args[14].value,
-    }
+    call_args = func.body.blocks[0].bindings[0].value.args
+    assert isinstance(call_args[0], relax.ExternFunc)
+    assert call_args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
+    assert isinstance(call_args[1], relax.StringImm)
+
+    args = call_args[1:]
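+    # call_args[1] (= args[0]) is a StringImm tagging the cache kind; the branches below
+    # decode the remaining creation args for the "mha" and "mla" layouts respectively.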
+    if args[0].value == "mha":
+        assert len(args) == 15
+        assert isinstance(args[1], relax.ShapeExpr)
+        assert len(args[1].values) == 5
+        assert isinstance(args[2], relax.ShapeExpr)
+        for i in range(3, 14):
+            if i in [10, 11]:
+                continue
+            assert isinstance(args[i], relax.PrimValue)
+            assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
+        assert isinstance(args[10], relax.StringImm)
+        assert isinstance(args[11], (relax.Constant, relax.PrimValue))
+        assert isinstance(args[14], relax.DataTypeImm)
+
+        return "mha", {
+            "max_batch_size": args[1].values[0],
+            "max_total_seq_len": args[1].values[1],
+            "prefill_chunk_size": args[1].values[2],
+            "page_size": args[1].values[3],
+            "support_sliding_window": args[1].values[4],
+            "layer_partition": args[2],
+            "num_hidden_layers": args[3].value.value,
+            "num_attention_heads": args[4].value.value,
+            "num_key_value_heads": args[5].value.value,
+            "head_dim": args[6].value.value,
+            "rope_mode": args[7].value.value,
+            "rope_scale": args[8].value.value,
+            "rope_theta": args[9].value.value,
+            "rope_scaling": json.loads(args[10].value),
+            "rope_ext_factors": args[11],
+            "rotary_dim": args[12].value.value,
+            "enable_disaggregation": bool(args[13].value.value),
+            "dtype": args[14].value,
+        }
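+    # MLA (multi-head latent attention) layout: qk_nope/qk_rope head dims, v_head_dim and
+    # kv_lora_rank replace the single head_dim and the RoPE params of the "mha" layout above.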
+    if args[0].value == "mla":
+        assert len(args) == 12
+        assert isinstance(args[1], relax.ShapeExpr)
+        assert len(args[1].values) == 5
+        assert isinstance(args[2], relax.ShapeExpr)
+        for i in range(3, 11):
+            assert isinstance(args[i], relax.PrimValue)
+            assert isinstance(args[i].value, tvm.tir.IntImm)
+        assert isinstance(args[11], relax.DataTypeImm)
+
+        return "mla", {
+            "max_batch_size": args[1].values[0],
+            "max_total_seq_len": args[1].values[1],
+            "prefill_chunk_size": args[1].values[2],
+            "page_size": args[1].values[3],
+            "support_sliding_window": args[1].values[4],
+            "layer_partition": args[2],
+            "num_hidden_layers": args[3].value.value,
+            "num_attention_heads": args[4].value.value,
+            "num_key_value_heads": args[5].value.value,
+            "qk_nope_head_dim": args[6].value.value,
+            "qk_rope_head_dim": args[7].value.value,
+            "v_head_dim": args[8].value.value,
+            "kv_lora_rank": args[9].value.value,
+            "enable_disaggregation": bool(args[10].value.value),
+            "dtype": args[11].value,
+        }
+
+    raise ValueError("Cannot reach here")


@tvm.transform.module_pass(opt_level=0, name="DispatchKVCacheCreation")
@@ -100,24 +132,38 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
        if mod.attrs is not None:
            new_mod = new_mod.with_attrs(mod.attrs)

-        kwargs = extract_creation_args(creation_func)
-        self.attach_kv_cache_metadata(kwargs)
+        kv_cache_kind, kwargs = extract_creation_args(creation_func)
+        self.attach_kv_cache_metadata(kv_cache_kind, kwargs)

        bb = relax.BlockBuilder(new_mod)
-        self.create_tir_paged_kv_cache(bb, kwargs)
-        self.create_flashinfer_paged_kv_cache(bb, kwargs)
+        self.create_tir_paged_kv_cache(bb, kv_cache_kind, kwargs)
+        self.create_flashinfer_paged_kv_cache(bb, kv_cache_kind, kwargs)
        return bb.finalize()

-    def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]):
+    def attach_kv_cache_metadata(
+        self, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
+    ):
        """Attach the KV cache metadata to model metadata."""
-        self.metadata["kv_cache"] = {
-            "num_hidden_layers": kwargs["num_hidden_layers"],
-            "num_attention_heads": kwargs["num_attention_heads"],
-            "num_key_value_heads": kwargs["num_key_value_heads"],
-            "head_dim": kwargs["head_dim"],
-        }
-
-    def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
+        if kv_cache_kind == "mha":
+            self.metadata["kv_cache"] = {
+                "num_hidden_layers": kwargs["num_hidden_layers"],
+                "num_attention_heads": kwargs["num_attention_heads"],
+                "num_key_value_heads": kwargs["num_key_value_heads"],
+                "head_dim": kwargs["head_dim"],
+            }
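+        # An MLA cache keeps one shared compressed latent per token (kv_lora_rank dims plus
+        # the decoupled qk_rope_head_dim RoPE dims), so it is reported as a single KV head.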
+        elif kv_cache_kind == "mla":
+            self.metadata["kv_cache"] = {
+                "num_hidden_layers": kwargs["num_hidden_layers"],
+                "num_attention_heads": kwargs["num_attention_heads"],
+                "num_key_value_heads": 1,
+                "head_dim": kwargs["kv_lora_rank"] + kwargs["qk_rope_head_dim"],
+            }
+        else:
+            raise ValueError("Cannot reach here.")
+
+    def create_tir_paged_kv_cache(
+        self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
+    ) -> None:
        """Create the TIR-based PagedKVCache"""
        max_batch_size = relax.Var(
            "max_batch_size_", relax.ShapeStructInfo([kwargs["max_batch_size"]])
@@ -143,16 +189,22 @@ def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, An
                support_sliding_window,
            ],
        ):
-            cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
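+            # Dispatch on the extracted kind: "mha" uses the standard TIRPagedKVCache
+            # constructor, while "mla" goes through the dedicated create_mla_kv_cache constructor.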
+            if kv_cache_kind == "mha":
+                cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
+            elif kv_cache_kind == "mla":
+                cache = kv_cache.TIRPagedKVCache.create_mla_kv_cache(target=self.target, **kwargs)
+            else:
+                raise ValueError("Cannot reach here")
            bb.emit_func_output(cache._expr)  # pylint: disable=protected-access

    def create_flashinfer_paged_kv_cache(
-        self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]
+        self, bb: relax.BlockBuilder, kv_cache_kind: Literal["mha", "mla"], kwargs: Dict[str, Any]
    ) -> None:
        """Create the FlashInfer-based PagedKVCache"""
        # Filter the cases which FlashInfer does not support.
        if (  # pylint: disable=too-many-boolean-expressions
            not self.flashinfer
+            or kv_cache_kind != "mha"
            or str(kwargs["dtype"]) != "float16"
            or kwargs["head_dim"] != 128
            or (