@@ -20,25 +20,31 @@ def rotary_embedding(
 ):
     r"""
     Applies RotaryEmbedding (see https://huggingface.co/papers/2104.09864)
-    on the `query ` or `key` before their multi-head attention computation.
+    on the `query ` or `key` before their multi-head attention computation.
+
     Args:
-    - query, key (torch.Tensor) : inputs to be applied with position embeddings, taking shape of
-        [batch size, sequence length, num_head/num_kv_head, head_dim]
-        or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
-    - sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key.
-    - rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
-    - head_dim (int) : head dim from the input shape.
-    - rotary_half (bool) : if False. e.g., GPT-J 6B/ChatGLM, cos/sin is applied to the neighboring 2 elements,
-        so the offset is 1.
-        if True, e.g., for llama, cos/sin is applied to the neighboring rotary_dim elements,
-        so the offset is rotary_dim/2.
-    - position_ids (torch.Tensor): Default is None and optional if sin/cos is provided. the according position_ids
-        for the input. The shape should be [batch size, sequence length].
+        query, key (torch.Tensor): inputs to be applied with position embeddings,
+            taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim]
+            or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
+        sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor
+            generated to be applied on query/key.
+        rotary_ndims (int): the rotary dimension, e.g., 64 for GPT-J; head size for Llama.
+        head_dim (int): head dim from the input shape.
+        rotary_half (bool): if False (e.g., GPT-J 6B/ChatGLM), cos/sin is applied to the neighboring 2 elements,
+            so the offset is 1.
+
+            if True (e.g., Llama), cos/sin is applied to the neighboring rotary_dim elements,
+            so the offset is rotary_dim/2.
+
+        position_ids (torch.Tensor): the corresponding position_ids for the input.
+            Optional; defaults to None when sin/cos are provided. The shape should be [batch size, sequence length].
+
     Return
-    - query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
-        or [num_tokens, num_head/num_kv_head, head_dim].
+        query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
+            or [num_tokens, num_head/num_kv_head, head_dim].
 
     """
+
     return RotaryEmbedding.apply_function(
         query, key, sin, cos, rotary_dim, rotary_half, position_ids
     )
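
For readers who want to see the documented `rotary_half=True` semantics in isolation, here is a minimal pure-PyTorch sketch of the rotate-half application, where cos/sin pair element i with element i + rotary_dim/2. The sin/cos table construction and the commented `ipex.llm.functional.rotary_embedding` call are illustrative assumptions, not part of this change.

import torch

def rotate_half(x):
    # pair element i with element i + rotary_dim/2 (offset = rotary_dim/2)
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def apply_rope_reference(query, key, sin, cos):
    # sin/cos: [num_tokens, rotary_dim], broadcast over the head dimension
    sin = sin[:, None, :]
    cos = cos[:, None, :]
    q = query * cos + rotate_half(query) * sin
    k = key * cos + rotate_half(key) * sin
    return q, k

num_tokens, num_head, head_dim = 8, 4, 64
rotary_dim = head_dim  # Llama-style: rotate the whole head
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
angles = torch.arange(num_tokens).float()[:, None] * inv_freq      # [num_tokens, rotary_dim/2]
sin = torch.sin(torch.cat((angles, angles), dim=-1))               # [num_tokens, rotary_dim]
cos = torch.cos(torch.cat((angles, angles), dim=-1))

query = torch.randn(num_tokens, num_head, head_dim)
key = torch.randn(num_tokens, num_head, head_dim)
q_rot, k_rot = apply_rope_reference(query, key, sin, cos)
# The fused call documented in this hunk would be (assumed entry point and order):
#   q_rot, k_rot = ipex.llm.functional.rotary_embedding(query, key, sin, cos, rotary_dim, True)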
@@ -48,12 +54,14 @@ def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float):
4854 r"""
4955 Applies RMSnorm on the input (hidden states).
5056 (see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L76)
57+
5158 Args:
52- - hidden_states(torch.Tensor) : the input tensor to apply RMSNorm.
53- - weight (torch.Tensor): the weight to apply RMSnorm.
54- - eps (float) : the variance_epsilon to apply RMSnorm.
59+ hidden_states(torch.Tensor) : the input tensor to apply RMSNorm.
60+ weight (torch.Tensor): the weight to apply RMSnorm.
61+ eps (float) : the variance_epsilon to apply RMSnorm.
5562
5663 """
64+
5765 return RMSNorm .apply_function (hidden_states , weight , eps )
5866
5967
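The RMSNorm computation the fused op implements mirrors the LlamaRMSNorm module linked above; a minimal pure-PyTorch reference follows. The commented fused call is an assumed entry point, shown only for comparison.

import torch

def rms_norm_reference(hidden_states, weight, eps):
    # x * rsqrt(mean(x^2) + eps) * weight, accumulated in fp32 as in LlamaRMSNorm
    input_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)

hidden_states = torch.randn(2, 8, 4096)
weight = torch.ones(4096)
out = rms_norm_reference(hidden_states, weight, eps=1e-6)
# Fused path documented above (assumed import path):
#   import intel_extension_for_pytorch as ipex
#   fused = ipex.llm.functional.rms_norm(hidden_states, weight, 1e-6)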
@@ -67,12 +75,14 @@ def fast_layer_norm(
6775 r"""
6876 Applies PyTorch Layernorm (see https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html)
6977 on the input (hidden states).
78+
7079 Args:
71- - hidden_states(torch.Tensor) : the input tensor to apply normalization.
72- - normalized_shape (int or list) or torch.Size) input shape from an expected input of size.
73- - weight (torch.Tensor): the weight to apply normalization.
74- - bias (torch.Tensor): an additive bias for normalization.
75- - eps (float): a value added to the denominator for numerical stability.
80+ hidden_states(torch.Tensor) : the input tensor to apply normalization.
81+ normalized_shape (int or list) or torch.Size) input shape from an
82+ expected input of size.
83+ weight (torch.Tensor): the weight to apply normalization.
84+ bias (torch.Tensor): an additive bias for normalization.
85+ eps (float): a value added to the denominator for numerical stability.
7686
7787 """
7888
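Since the docstring defines this op as PyTorch LayerNorm on the input, plain torch.nn.functional.layer_norm serves as a reference; the sketch below shows matching arguments. The commented fused call assumes the entry point and argument order documented above.

import torch
import torch.nn.functional as F

hidden_states = torch.randn(2, 8, 4096)
normalized_shape = (4096,)
weight = torch.ones(4096)
bias = torch.zeros(4096)
eps = 1e-5

# Plain PyTorch layer_norm as the reference for what the fused op computes.
reference = F.layer_norm(hidden_states, normalized_shape, weight=weight, bias=bias, eps=eps)
# Fused path (assumed entry point and argument order):
#   import intel_extension_for_pytorch as ipex
#   out = ipex.llm.functional.fast_layer_norm(hidden_states, normalized_shape, weight, bias, eps)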
@@ -103,33 +113,49 @@ def indirect_access_kv_cache_attention(
     buffers(key and value use different buffers) to store all key/value hidden states and beam index information.
     It can use beam index history to decide which beam should be used by a timestamp and this information will
     generate an offset to access the kv_cache buffer.
+
     Data Format:
-    - The shape of the pre-allocated key(value) buffer is [max_seq, beam*batch, head_num, head_size],
-      the hidden state of key/value which is the shape of [beam*batch, head_num, head_size] is stored token by token.
-      All beam idx information of every timestamp is also stored in a Tensor with the shape of [max_seq, beam*batch].
-
-    forward
-    - query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-    - key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-    - value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-    - scale_attn (float):scale used by the attention layer. should be the sqrt(head_size).
-    - layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
-      key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
-      value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
-      beam-idx: history beam idx, shape:(max_seq, beam*batch);
-      seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).
-    - head_mask (torch.Tensor): Head mask tensor which is not supported by kernel yet.
-    - attention_mask(torch.Tensor): Attention mask information.
-    - text_max_length (int) : the max length of kv cache to be used for generation (allocate the pre-cache buffer).
+
+    The shape of the pre-allocated key(value) buffer is [max_seq, beam*batch, head_num, head_size];
+    the hidden state of key/value, which has the shape [beam*batch, head_num, head_size], is stored token by token.
+    All beam idx information for every timestamp is also stored in a Tensor with the shape [max_seq, beam*batch].
+
+    Args:
+        query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        scale_attn (float): scale used by the attention layer; should be the sqrt(head_size).
+        layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
+
+            - key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
+
+            - value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
+
+            - beam-idx: history beam idx, shape: (max_seq, beam*batch);
+
+            - seq_info: sequence info tensor, shape: (1, 1, max_seq, max_seq).
+
+        head_mask (torch.Tensor): Head mask tensor, which is not supported by the kernel yet.
+        attention_mask (torch.Tensor): Attention mask information.
+        text_max_length (int): the max length of kv cache to be used for generation
+            (allocate the pre-cache buffer).
 
     Return:
-    - attn_output: weighted value which is the output of scale dot product. shape (beam*batch, seq_len, head_num, head_size).
-    - attn_weights: The output tensor of the first matmul in scale dot product which is not supported by kernel now.
-    - new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).
+        attn_output: weighted value which is the output of scaled dot product;
+            shape: (beam*batch, seq_len, head_num, head_size).
+
+        attn_weights: the output tensor of the first matmul in scaled dot product,
+            which is not supported by the kernel yet.
+
+        new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).
 
     Notes:
-    - How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on llama model
-      see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318)
+        How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on the llama model,
+        see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318)
+
+        .. highlight:: python
+        .. code-block:: python
+
             def _reorder_cache(
                 self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
             ) -> Tuple[Tuple[torch.Tensor]]:
@@ -141,6 +167,7 @@ def _reorder_cache(
                 return past_key_values
 
     """
+
     return IndirectAccessKVCacheAttention.apply_function(
         query,
         key,
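
The "Data Format" paragraph above describes a [max_seq, beam*batch, head_num, head_size] cache plus a beam-idx history that is turned into an offset into that cache. The sketch below is a plain-PyTorch illustration of that layout under the assumption that beam_idx[t, b] records the parent beam of b at step t; it is not the IPEX kernel and the exact semantics of beam-idx are an assumption.

import torch

max_seq, beam, batch, head_num, head_size = 16, 4, 1, 2, 8
bb = beam * batch
cur_len = 5  # tokens written so far

key_cache = torch.randn(max_seq, bb, head_num, head_size)  # token-by-token storage
beam_idx = torch.zeros(max_seq, bb, dtype=torch.long)      # assumed: parent beam of each slot per step
beam_idx[3] = torch.tensor([0, 0, 1, 3])                   # e.g. a beam reordering at step 3

def gather_beam_keys(key_cache, beam_idx, beam_id, cur_len):
    # Walk the recorded beam history backwards to find, for every timestep,
    # which cache row holds the key state that belongs to `beam_id` now.
    rows = []
    row = beam_id
    for t in reversed(range(cur_len)):
        rows.append(key_cache[t, row])
        row = beam_idx[t, row].item()
    return torch.stack(rows[::-1])  # [cur_len, head_num, head_size]

keys_for_beam_2 = gather_beam_keys(key_cache, beam_idx, beam_id=2, cur_len=cur_len)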
@@ -174,23 +201,30 @@ def varlen_attention(
 ):
     r"""
     Applies PyTorch scaled_dot_product_attention on the inputs of query, key and value
-    (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
-    and accept the variant (different) sequence length among the query, key and value.
+    (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
+    and accepts variable (different) sequence lengths among the query, key and value.
+
+    This module does not have args for `module init`.
+
+    `forward()`
 
     Args:
-    module init: this module does not have args for module init
-    forward:
-    - query (torch.Tensor): shape [query_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-    - key (torch.Tensor): shape [key_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-    - value (torch.Tensor): shape [value_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-    - out (torch.Tensor): buffer to get the results, the shape is the same as query.
-    - seqlen_q (torch.Tensor): shape [batch_size + 1], points the current query_tokens among total sequence length.
-    - seqlen_k (torch.Tensor): shape [batch_size + 1], points the current key_tokens among total sequence length.
-    - max_seqlen_q (int): max/total sequence length of query.
-    - max_seqlen_k (int): max/total sequence length of key.
-    - pdropout (float): dropout probability; if greater than 0.0, dropout is applied, default is 0.0.
-    - softmax_scale (float): scaling factor applied is prior to softmax.
-    - is_causal (bool): whether to apply causal attention masking, default is True.
+        query (torch.Tensor): shape [query_tokens, num_head, head_size],
+            where query_tokens is the total sequence length across the batch.
+        key (torch.Tensor): shape [key_tokens, num_head, head_size],
+            where key_tokens is the total sequence length across the batch.
+        value (torch.Tensor): shape [value_tokens, num_head, head_size],
+            where value_tokens is the total sequence length across the batch.
+        out (torch.Tensor): buffer to get the results; the shape is the same as query.
+        seqlen_q (torch.Tensor): shape [batch_size + 1],
+            pointing to the current query_tokens within the total sequence length.
+        seqlen_k (torch.Tensor): shape [batch_size + 1],
+            pointing to the current key_tokens within the total sequence length.
+        max_seqlen_q (int): max/total sequence length of query.
+        max_seqlen_k (int): max/total sequence length of key.
+        pdropout (float): dropout probability; if greater than 0.0, dropout is applied; default is 0.0.
+        softmax_scale (float): scaling factor applied prior to softmax.
+        is_causal (bool): whether to apply causal attention masking; default is True.
 
     """
     return VarlenAttention.apply_function(
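
As a reference for the packed (varlen) layout documented above: the sketch below runs plain per-sequence scaled_dot_product_attention over two sequences packed along the token dimension, treating seqlen_q/seqlen_k as cumulative [batch_size + 1] boundaries. That interpretation and the loop itself are illustrative assumptions, not the fused IPEX kernel.

import torch
import torch.nn.functional as F

num_head, head_size = 4, 32
seq_lens = [5, 7]                    # two sequences packed along the token dimension
total_tokens = sum(seq_lens)

query = torch.randn(total_tokens, num_head, head_size)
key = torch.randn(total_tokens, num_head, head_size)
value = torch.randn(total_tokens, num_head, head_size)
out = torch.empty_like(query)        # result buffer, same shape as query

# [batch_size + 1] boundaries of each sequence inside the packed token dimension
seqlen_q = torch.tensor([0, 5, 12], dtype=torch.int32)
seqlen_k = seqlen_q.clone()

for b in range(len(seq_lens)):
    s, e = seqlen_q[b].item(), seqlen_q[b + 1].item()
    q = query[s:e].transpose(0, 1)   # [num_head, tokens_b, head_size]
    k = key[s:e].transpose(0, 1)
    v = value[s:e].transpose(0, 1)
    # default SDPA scale is 1/sqrt(head_size), matching softmax_scale above
    o = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out[s:e] = o.transpose(0, 1)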