diff --git a/docs/source/features/auto_deploy/support_matrix.md b/docs/source/features/auto_deploy/support_matrix.md index a41090932f0..26c07b308b8 100644 --- a/docs/source/features/auto_deploy/support_matrix.md +++ b/docs/source/features/auto_deploy/support_matrix.md @@ -83,6 +83,7 @@ In addition, the following models have been officially validated using the defau - nvidia/Llama-3_1-Nemotron-Ultra-253B-v1-FP8 - nvidia/Llama-3_3-Nemotron-Super-49B-v1 - nvidia/Mistral-NeMo-Minitron-8B-Base +- nvidia/Nemotron-Flash-3B-Instruct - perplexity-ai/r1-1776-distill-llama-70b diff --git a/examples/auto_deploy/.gitignore b/examples/auto_deploy/.gitignore index 36cf5a6dd49..f15c233aee9 100644 --- a/examples/auto_deploy/.gitignore +++ b/examples/auto_deploy/.gitignore @@ -5,3 +5,4 @@ benchmark_results.json # ignore config files that users might put here for debugging *.yaml !nano_v3.yaml +!nemotron_flash.yaml diff --git a/examples/auto_deploy/nemotron_flash.yaml b/examples/auto_deploy/nemotron_flash.yaml new file mode 100644 index 00000000000..cbb692a70ed --- /dev/null +++ b/examples/auto_deploy/nemotron_flash.yaml @@ -0,0 +1,11 @@ +compile_backend: torch-cudagraph +max_batch_size: 384 +max_seq_len: 2097152 +max_num_tokens: 8192 +enable_chunked_prefill: true +model_factory: NemotronFlashForCausalLM +free_mem_ratio: 0.9 +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64,96, 128, 256, 320, 384] +kv_cache_config: + # disable kv_cache reuse since not supported for hybrid/ssm models + enable_block_reuse: false diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 5aab6310db2..91e6dbf75d3 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -152,6 +152,9 @@ transforms: insert_cached_causal_conv: stage: cache_init backend: cuda_causal_conv + insert_cached_delta_rule: + stage: cache_init + backend: fla_delta initialize_cache: stage: cache_init run_per_gm: false diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index b34e08178c2..818e560578a 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -34,8 +34,11 @@ class CacheConfig(BaseModel): dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.") mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.") + delta_dtype: Optional[torch.dtype] = Field( + default=torch.float32, description="Delta cache dtype. Defaults to float32." 
+ ) - @field_validator("dtype", "mamba_dtype", mode="before") + @field_validator("dtype", "mamba_dtype", "delta_dtype", mode="before") @classmethod def _coerce_dtype(cls, value): if value is None or isinstance(value, torch.dtype): diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/__init__.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/__init__.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/chunk.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/chunk.py new file mode 100644 index 00000000000..a2a898cf9f1 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/chunk.py @@ -0,0 +1,42 @@ +# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/delta_rule/chunk.py + +from typing import Optional + +import torch + +from tensorrt_llm._torch.modules.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h +from tensorrt_llm._torch.modules.fla.chunk_o import chunk_fwd_o + +from .wy_fast import prepare_wy_repr_fwd + + +def chunk_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, +): + # obtain WY representation. u is actually the new v. + w, u, A = prepare_wy_repr_fwd( + k=k, + v=v, + beta=beta, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=None, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + + o = chunk_fwd_o(q=q, k=k, v=v_new, h=h, g=None, scale=scale, cu_seqlens=cu_seqlens) + + return o, A, final_state diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/fused_recurrent.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/fused_recurrent.py new file mode 100644 index 00000000000..c1f19993f9e --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/fused_recurrent.py @@ -0,0 +1,160 @@ +# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/delta_rule/fused_recurrent.py +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def fused_recurrent_delta_rule_fwd_kernel( + q, + k, + v, + u, + beta, + o, + h0, + ht, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_BETA_HEADWISE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + p_k = k + (bos * H + i_h) * K + i_k * BK + 
tl.arange(0, BK) + p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + p_u = u + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + if IS_BETA_HEADWISE: + p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + bos * H + i_h + p_o = o + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV) + + mask_k = (i_k * BK + tl.arange(0, BK)) < K + mask_v = (i_v * BV + tl.arange(0, BV)) < V + mask_h = mask_k[None, :] & mask_v[:, None] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = ( + h0 + + i_nh * K * V + + (i_k * BK + tl.arange(0, BK)[None, :]) * V + + (i_v * BV + tl.arange(0, BV)[:, None]) + ) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_v_minus = tl.sum(b_h * b_k[None, :], axis=1) + b_v -= b_v_minus + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + tl.store(p_u, b_v.to(p_v.dtype.element_ty), mask=mask_v) + b_v *= b_beta + b_h += b_k[None, :] * b_v[:, None] + b_o = b_h * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + p_q += H * K + p_k += H * K + p_o += H * V + p_v += H * V + p_u += H * V + p_beta += H * (V if IS_BETA_HEADWISE else 1) + + if STORE_FINAL_STATE: + p_ht = ( + ht + + i_nh * K * V + + (i_k * BK + tl.arange(0, BK)[None, :]) * V + + (i_v * BV + tl.arange(0, BV)[:, None]) + ) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 1 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if output_final_state: + final_state = q.new_empty(N, H, K, V, dtype=torch.float32) + else: + final_state = None + + grid = (NV, NK, N * H) + u = torch.empty_like(v) + fused_recurrent_delta_rule_fwd_kernel[grid]( + q, + k, + v, + u, + beta, + o, + initial_state, + final_state, + cu_seqlens, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, u, final_state diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/utils.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/utils.py new file mode 100644 index 00000000000..c02e5e546aa --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/utils.py @@ -0,0 +1,11 @@ +# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py +import inspect +import os + +import triton + +FLA_CACHE_RESULTS = os.getenv("FLA_CACHE_RESULTS", "1") == "1" + + +supports_autotune_cache = "cache_results" in inspect.signature(triton.autotune).parameters +autotune_cache_kwargs = {"cache_results": FLA_CACHE_RESULTS} if supports_autotune_cache else {} diff --git 
a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/wy_fast.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/wy_fast.py new file mode 100644 index 00000000000..d207d3fead3 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/wy_fast.py @@ -0,0 +1,152 @@ +# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/delta_rule/wy_fast.py + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from tensorrt_llm._torch.modules.fla.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from tensorrt_llm._torch.modules.fla.index import prepare_chunk_indices +from tensorrt_llm._torch.modules.fla.solve_tril import solve_tril +from tensorrt_llm._torch.modules.fla.utils import check_shared_mem + +from .utils import autotune_cache_kwargs + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr( + A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0) + ) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + p_u = tl.make_block_ptr( + u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A.to(b_vb.dtype), b_vb, allow_tf32=False) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_w = tl.make_block_ptr( + w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A.to(b_kb.dtype), b_kb, allow_tf32=False) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def prepare_wy_repr_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + A = chunk_scaled_dot_kkt_fwd( + k=k, + beta=beta, + cu_seqlens=cu_seqlens, + chunk_size=64, + output_dtype=torch.float32, + ) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + 
beta=beta, + A=A, + cu_seqlens=cu_seqlens, + ) + return w, u, A + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = 64 + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) + BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + u = torch.empty_like(v) + w = torch.empty_like(k) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k, + v, + beta, + w, + u, + A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py new file mode 100644 index 00000000000..7cff09435f6 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py @@ -0,0 +1,256 @@ +"""Cached attention op for chunked delta rule using the fla kernel library.""" + +from typing import List, Tuple + +import torch +from torch._ops import OpOverloadPacket +from torch.fx import Node + +from ...utils.node_utils import extract_op_args +from ..attention_interface import ( + AttentionDescriptor, + AttentionLayout, + AttentionRegistry, + BufferInitializerDict, + CacheConfig, + CacheInitializerDict, + Constant, + MHACallable, + PrepareMetadataCallable, + SequenceInfo, +) +from .delta_rule.chunk import chunk_delta_rule_fwd +from .delta_rule.fused_recurrent import fused_recurrent_delta_rule_fwd + + +@torch.library.custom_op("auto_deploy::fla_delta_prepare_metadata", mutates_args=()) +def fla_delta_prepare_metadata( + position_ids: torch.Tensor, + seq_len: torch.Tensor, + input_pos: torch.Tensor, + cache_loc: torch.Tensor, + pages_per_seq: torch.Tensor, + slot_idx: torch.Tensor, + page_size: int, + chunk_size: int, +) -> List[torch.Tensor]: + """Prepare metadata for cached chunked delta rule. + + Returns a tuple of (cu_seq_lens, slot_idx_sanitized, use_initial_states, batch_info_tensor). + """ + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) + num_seq = len(seq_len_sanitized) + cu_seqlens = torch.zeros(num_seq + 2, dtype=torch.int32, device=seq_len_sanitized.device) + + slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long) + use_initial_states = input_pos[:num_seq] > 0 + + _, s = position_ids.shape[:2] + if s > 1: + prefill_mask = seq_len_sanitized > 1 + num_prefill = int(prefill_mask.sum().item()) + num_prefill_tokens = int(seq_len_sanitized[:num_prefill].sum().item()) + num_decode = num_seq - num_prefill + + # compute cu_seq_lens for the prefill sequences first + cu_seqlens[1 : num_prefill + 1] = torch.cumsum(seq_len_sanitized[:num_prefill], 0) + else: + num_prefill = 0 + num_prefill_tokens = 0 + num_decode = num_seq + + # decode is just arange... 
+ cu_seqlens[num_prefill + 1 :] = torch.arange( + num_decode + 1, device=cu_seqlens.device, dtype=cu_seqlens.dtype + ) + batch_info_tensor = torch.tensor( + [num_prefill, num_prefill_tokens, num_decode], dtype=torch.int32 + ) + + return cu_seqlens, slot_idx_sanitized, use_initial_states, batch_info_tensor + + +@fla_delta_prepare_metadata.register_fake +def fla_delta_prepare_metadata_fake( + position_ids, + seq_len, + input_pos, + cache_loc, + pages_per_seq, + slot_idx, + page_size, + chunk_size, +): + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) + num_seq = len(seq_len_sanitized) + cu_seq_lens = torch.empty(num_seq + 2, dtype=torch.int32, device=seq_len_sanitized.device) + return ( + cu_seq_lens, + torch.empty(num_seq, dtype=torch.long, device=slot_idx.device), + torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device), + torch.empty(3, dtype=torch.int32), # host tensor + ) + + +@torch.library.custom_op("auto_deploy::fla_cached_delta_rule", mutates_args=()) +def fla_cached_delta_rule( + # INPUTS (dense but may be flattened across sequences) + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + # METADATA + cu_seqlens: torch.Tensor, # [num_seq + 1] + slot_idx: torch.Tensor, # [num_seq] + use_initial_states: torch.Tensor, # [num_seq] + batch_info_tensor: torch.Tensor, # [3] + # CACHES + delta_cache: torch.Tensor, # [max_batch_size, H, K, V] + # CONSTANTS + scale: float, +) -> torch.Tensor: + b, s, num_heads, _ = q.shape + + # flatten it + q_flat = q.view(1, b * s, num_heads, -1) + k_flat = k.view(1, b * s, num_heads, -1) + v_flat = v.view(1, b * s, num_heads, -1) + beta_flat = beta.view(1, b * s, num_heads) + + # pre-allocate output + y = torch.empty_like(v, memory_format=torch.contiguous_format) + y_flat = y.view(1, b * s, num_heads, -1) + + num_prefill, num_prefill_tokens, num_decode = batch_info_tensor.tolist() + + if num_prefill > 0: + initial_states = None + if torch.any(use_initial_states[:num_prefill]): + initial_states = torch.where( + use_initial_states[:num_prefill, None, None, None], + delta_cache[slot_idx[:num_prefill]], + 0, + ) + + y_prefill, _, final_state = chunk_delta_rule_fwd( + q=q_flat[:, :num_prefill_tokens], + k=k_flat[:, :num_prefill_tokens], + v=v_flat[:, :num_prefill_tokens], + beta=beta_flat[:, :num_prefill_tokens], + scale=scale, + initial_state=initial_states, + output_final_state=True, + cu_seqlens=cu_seqlens[: num_prefill + 1], + ) + + y_flat[:, :num_prefill_tokens] = y_prefill.to(y_flat.dtype) + delta_cache.index_copy_(0, slot_idx[:num_prefill], final_state.to(delta_cache.dtype)) + + del y_prefill, initial_states, final_state + + if num_decode > 0: + # NOTE: avoiding state clone here and adopting the kernel to handle + # indexed initial states would give a boost + y_decode, _, final_state = fused_recurrent_delta_rule_fwd( + q=q_flat[:, num_prefill_tokens:], + k=k_flat[:, num_prefill_tokens:], + v=v_flat[:, num_prefill_tokens:], + beta=beta_flat[:, num_prefill_tokens:], + scale=scale, + initial_state=delta_cache[slot_idx[num_prefill:]].clone(), + output_final_state=True, + cu_seqlens=cu_seqlens[num_prefill + 1 :], + ) + + y_flat[:, num_prefill_tokens:] = y_decode.to(y_flat.dtype) + delta_cache.index_copy_(0, slot_idx[num_prefill:], final_state.to(delta_cache.dtype)) + + del y_decode, final_state + + return y + + +@fla_cached_delta_rule.register_fake +def fla_cached_delta_rule_fake( + # INPUTS (dense but may be flattened across sequences) + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + 
beta: torch.Tensor, + # METADATA + cu_seqlens: torch.Tensor, # [num_seq + 1] + slot_idx: torch.Tensor, # [num_seq] + use_initial_states: torch.Tensor, # [num_seq] + batch_info_tensor: torch.Tensor, # [3] + # CACHES + delta_cache: torch.Tensor, # [max_batch_size, H, K, V] + # CONSTANTS + scale: float, +) -> torch.Tensor: + return torch.empty_like(v) + + +@AttentionRegistry.register("fla_delta") +class FlaDeltaBackend(AttentionDescriptor): + @classmethod + def is_paged(cls) -> bool: + # TODO: we should refine our notion of "is_paged" --> seems counterintuitive for ssm nows + return True + + @classmethod + def get_attention_layout(cls) -> AttentionLayout: + return "bsnd" + + @classmethod + def get_num_qkv_args(cls) -> int: + # q, k, v, beta + return 4 + + @classmethod + def get_source_attention_op(cls) -> OpOverloadPacket: + return torch.ops.auto_deploy.fla_delta_rule + + @classmethod + def get_cached_attention_op(cls) -> MHACallable: + return torch.ops.auto_deploy.fla_cached_delta_rule + + @classmethod + def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: + # Returns (cu_seq_lens, slot_idx, use_initial_states, batch_info_tensor) + return torch.ops.auto_deploy.fla_delta_prepare_metadata, 4 + + @classmethod + def get_cache_initializers( + cls, source_attn_node: Node, cache_config: CacheConfig + ) -> CacheInitializerDict: + key_node = source_attn_node.args[1] + value_node = source_attn_node.args[2] + num_heads = key_node.meta["val"].shape[-2] + key_dim = key_node.meta["val"].shape[-1] + value_dim = value_node.meta["val"].shape[-1] + key_dtype = key_node.meta["val"].dtype + + def _get_delta_cache(si: SequenceInfo): + return torch.empty( + si.max_batch_size, + num_heads, + key_dim, + value_dim, + device=si.device, + dtype=cache_config.delta_dtype or key_dtype, + ) + + return {"delta_cache": _get_delta_cache} + + @classmethod + def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict: + return {} + + @classmethod + def get_constants(cls, source_attn_node: Node) -> List[Constant]: + scale = extract_op_args(source_attn_node, "scale")[0] + if scale is None: + key_node = source_attn_node.args[1] + key_dim = key_node.meta["val"].shape[-1] + scale = key_dim**-0.5 + return [scale] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_delta.py new file mode 100644 index 00000000000..d7b38f73f36 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_delta.py @@ -0,0 +1,32 @@ +"""Custom ops corresponding to fla's chunked delta rule.""" + +from typing import Optional + +import torch + +from .delta_rule.chunk import chunk_delta_rule_fwd + + +@torch.library.custom_op("auto_deploy::fla_delta_rule", mutates_args=()) +def fla_chunked_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: Optional[float] = None, +) -> torch.Tensor: + o, A, final_state = chunk_delta_rule_fwd( + q, k, v, beta, scale, initial_state=None, output_final_state=False, cu_seqlens=None + ) + return o + + +@fla_chunked_delta_rule.register_fake +def fla_chunked_delta_rule_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: Optional[float] = None, +) -> torch.Tensor: + return torch.empty_like(v) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/l2norm.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/l2norm.py new file mode 100644 index 00000000000..2bd7ab6d19f --- /dev/null +++ 
b/tensorrt_llm/_torch/auto_deploy/custom_ops/l2norm.py @@ -0,0 +1,40 @@ +"""Custom ops corresponding to l2norm.""" + +import torch + +from tensorrt_llm._torch.modules.fla.l2norm import l2norm_fwd + +# TODO: add a pattern matcher for this such that +# 1. pattern match to torch_l2norm +# 2. fuse transform to map to desired backend like fla + + +@torch.library.custom_op("auto_deploy::torch_l2norm", mutates_args=()) +def _torch_l2norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + x: (..., D) + returns: + y: (..., D) # normalized + rstd: (...,) # 1/sqrt(sum(x^2)+eps) along `dim` + """ + x_f32 = x.float() + s = (x_f32 * x_f32).sum(dim=-1, keepdim=True) # (..., 1) + rstd = torch.rsqrt(s + eps) # (..., 1) + y = (x_f32 * rstd).to(x.dtype) # cast back + return y + + +@_torch_l2norm.register_fake +def _torch_l2norm_fake(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + return torch.empty_like(x) + + +@torch.library.custom_op("auto_deploy::fla_l2norm", mutates_args=()) +def fla_l2norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + y = l2norm_fwd(x, eps) + return y + + +@fla_l2norm.register_fake +def fla_l2norm_fake(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_causal_conv.py index 86ba58833ef..a53f9840128 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_causal_conv.py @@ -21,15 +21,19 @@ def _torch_causal_conv1d( batch_size, seq_len, _ = input.shape - return F.conv1d( - input.transpose(1, 2), - weight, - bias, - stride, - padding, - dilation, - groups, - )[..., :seq_len].transpose(1, 2) + return ( + F.conv1d( + input.transpose(1, 2), + weight, + bias, + stride, + padding, + dilation, + groups, + )[..., :seq_len] + .transpose(1, 2) + .contiguous() + ) @_torch_causal_conv1d.register_fake diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py index 708449ea732..b4701781570 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py @@ -74,7 +74,7 @@ def torch_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torc input = input.to(torch.float32) variance = input.pow(2).mean(-1, keepdim=True) input = input * torch.rsqrt(variance + eps) - return weight * input.to(input_dtype) + return (weight * input.to(input_dtype)).contiguous() @torch_rmsnorm.register_fake diff --git a/tensorrt_llm/_torch/auto_deploy/models/__init__.py b/tensorrt_llm/_torch/auto_deploy/models/__init__.py index 49a1c8b9377..6eae19f23c5 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/models/__init__.py @@ -1,4 +1,4 @@ # TODO: When getting rid of the nemotron H patches, import `modeling_nemotron_h` here to ensure the # custom model implementation is registered. -from . import hf, patches +from . 
import custom, hf, nemotron_flash, patches from .factory import * diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py new file mode 100644 index 00000000000..fef9fdb1660 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py @@ -0,0 +1 @@ +from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py new file mode 100644 index 00000000000..e4f73cb465a --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py @@ -0,0 +1,1102 @@ +# Adapted from: https://huggingface.co/nvidia/Nemotron-Flash-3B-Instruct/tree/main +import copy +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PreTrainedTokenizerFast +from transformers.activations import ACT2FN +from transformers.generation.utils import GenerationMixin +from transformers.modeling_outputs import CausalLMOutput, MoeModelOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.tokenization_utils_base import BatchEncoding + +from tensorrt_llm.inputs.utils import HF_CHAT_TEMPLATE_EXCEPTIONS + +from ..nemotron_flash import NemotronFlashForCausalLMFactory + + +class NemotronFlashPreTrainedTokenizerFast(PreTrainedTokenizerFast): + def __init__(self, *args, num_memory_tokens: int = 0, vocab_size_model: int = 0, **kwargs): + super().__init__(*args, **kwargs) + self.num_memory_tokens = num_memory_tokens + self.num_dummy_tokens = max(0, vocab_size_model - len(self)) + self.add_tokens([f"" for i in range(self.num_dummy_tokens)], special_tokens=True) + + self.mem_tokens = [f"" for i in range(self.num_memory_tokens)] + self.mem_token_ids = list(range(len(self), len(self) + self.num_memory_tokens)) + self.add_tokens(self.mem_tokens, special_tokens=True) + + if getattr(self, "chat_template", None) is None: + self.chat_template = ( + "{% for m in messages %}" + "{{ m['content'] }}" + "{% if not loop.last %}\n{% endif %}" + "{% endfor %}" + ) + self.model_input_names = ["input_ids"] + + def _add_memory_tokens(self, input_ids) -> Union[List[List[int]], torch.Tensor]: + is_unbatched = True + if isinstance(input_ids, list) and isinstance(input_ids[0], int): + input_ids = [input_ids] + elif isinstance(input_ids, torch.Tensor) and input_ids.ndim == 1: + input_ids = input_ids[None] + else: + is_unbatched = False + + if isinstance(input_ids, list): + input_ids = [self.mem_token_ids + _ids for _ids in input_ids] + elif isinstance(input_ids, torch.Tensor): + mem_token_tnsr = torch.tensor( + self.mem_token_ids, device=input_ids.device, dtype=input_ids.dtype + ) + input_ids = torch.cat((mem_token_tnsr[None], input_ids), dim=1) + else: + raise ValueError(f"Unsupported input type {type(input_ids)}") + + if is_unbatched: + return input_ids[0] + else: + return input_ids + + def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding: + batch_encoding = super()._batch_encode_plus(*args, **kwargs) + batch_encoding.data["input_ids"] = self._add_memory_tokens(batch_encoding.data["input_ids"]) + return batch_encoding + + def _decode(self, token_ids: Union[int, list[int]], *args, **kwargs) -> str: + if isinstance(token_ids, list): + token_ids = [_id for _id in token_ids if _id not in self.mem_token_ids] + 
return super()._decode(token_ids, *args, **kwargs) + + +def elu_p1(x): + return (F.elu(x, 1.0, False) + 1.0).to(x) + + +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + + +class NemotronFlashRMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-5) -> "NemotronFlashRMSNorm": + super().__init__() + + self.hidden_size = hidden_size + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size)) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + + def forward(self, x): + return torch.ops.auto_deploy.triton_rms_norm(x, self.weight, self.eps).to(x.dtype) + + +class RMSNormGated(nn.Module): + def __init__( + self, + hidden_size: int, + eps: float = 1e-5, + group_size: Optional[int] = None, + norm_before_gate: bool = False, + ) -> "RMSNormGated": + super().__init__() + + self.hidden_size = hidden_size + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size)) + self.group_size = group_size or hidden_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + + def forward( + self, + x: torch.Tensor, + g: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.auto_deploy.triton_rmsnorm_gated( + x, + weight=self.weight, + gate=g, + eps=self.eps, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate, + ) + + +class CausalConv1d(nn.Conv1d): + def __init__(self, hidden_size: int, kernel_size: int, bias: bool = False): + super().__init__( + in_channels=hidden_size, + out_channels=hidden_size, + kernel_size=kernel_size, + padding=kernel_size - 1, + groups=hidden_size, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.auto_deploy.torch_causal_conv1d( + x, + self.weight, + self.bias, + self.stride[0], + self.padding[0], + self.dilation[0], + self.groups, + self.padding_mode, + ) + + +class DeltaNet(nn.Module): + def __init__( + self, + hidden_size: int = 1024, + expand_k: float = 1.0, + expand_v: float = 1.0, + num_heads: int = 4, + use_beta: bool = True, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + allow_neg_eigval: bool = False, + layer_idx: int = None, + qk_activation: str = "silu", + qk_norm: str = "l2", + norm_eps: float = 1e-5, + config=None, + **kwargs, + ) -> "DeltaNet": + super().__init__() + + self.qk_activation = qk_activation + self.qk_norm = qk_norm + + assert self.qk_activation in ["silu", "relu", "elu", "identity"] + assert self.qk_norm in ["l2", "sum"] + + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.use_beta = use_beta + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.allow_neg_eigval = allow_neg_eigval + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + self.layer_idx = layer_idx + + self.silu = nn.SiLU() + + assert self.key_dim % num_heads == 0, ( + f"key dim must be divisible by num_heads of {num_heads}" + ) + assert self.value_dim % num_heads == 0, ( + f"value dim must be divisible by num_heads of {num_heads}" + ) + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, 
bias=False) + + if self.use_beta: + self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + if self.use_short_conv: + self.q_conv1d = CausalConv1d(self.key_dim, kernel_size=conv_size, bias=conv_bias) + self.k_conv1d = CausalConv1d(self.key_dim, kernel_size=conv_size, bias=conv_bias) + self.v_conv1d = CausalConv1d(self.value_dim, kernel_size=conv_size, bias=conv_bias) + + if self.use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = RMSNormGated(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = NemotronFlashRMSNorm(self.head_v_dim, eps=norm_eps) + + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2**-2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + dtype = hidden_states.dtype + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + if self.use_short_conv: + q = self.q_conv1d(q) + k = self.k_conv1d(k) + v = self.v_conv1d(v) + + q = q.view(batch_size, seq_len, -1, self.head_k_dim) + k = k.view(batch_size, seq_len, -1, self.head_k_dim) + v = v.view(batch_size, seq_len, -1, self.head_v_dim) + + if self.qk_activation == "relu": + q, k = q.relu(), k.relu() + elif self.qk_activation == "elu": + q, k = elu_p1(q), elu_p1(k) + elif self.qk_activation == "identity": + pass + elif self.qk_activation == "silu": + q, k = self.silu(q), self.silu(k) + else: + raise NotImplementedError + + v = self.silu(v) + + if self.use_beta: + beta = self.b_proj(hidden_states) + beta = beta.sigmoid() + else: + beta = q.new_ones() + + if self.allow_neg_eigval: + beta = beta * 2.0 + + if self.qk_norm == "l2": + q = torch.ops.auto_deploy.fla_l2norm(q) + k = torch.ops.auto_deploy.fla_l2norm(k) + elif self.qk_norm == "sum": + q = sum_norm(q).to(q) + k = sum_norm(k).to(k) + else: + raise NotImplementedError(f"Not supported qk_norm `{self.qk_norm}`.") + + o = torch.ops.auto_deploy.fla_delta_rule(q, k, v, beta) + + if self.use_gate: + g = self.g_proj(hidden_states).view(batch_size, seq_len, -1, self.head_v_dim) + o = self.o_norm(o, g) + else: + o = self.o_norm(o) + o = o.view(batch_size, seq_len, -1) + o = self.o_proj(o) + + return o.to(dtype) + + +class NemotronFlashMamba2(nn.Module): + def __init__( + self, + config, + layer_idx: int, + num_groups=1, + rmsnorm=True, + dt_min=0.001, + dt_max=0.1, + dt_init_floor=1e-4, + dt_limit=(0.0, float("inf")), + bias=False, + conv_bias=True, + chunk_size=128, # NOTE: original 256 gives us IMA in _chunk_scan_fwd_kernel for mamba2 + ): + super().__init__() + + self.config = config + self.d_model = config.hidden_size + self.d_state = config.mamba_d_state + self.d_conv = config.mamba_d_conv + + self.expand = config.mamba_expand + self.d_inner = self.expand * self.d_model + self.headdim = config.mamba2_headdim + self.num_groups = num_groups + self.num_heads = self.d_inner // self.headdim + self.rmsnorm = rmsnorm + self.dt_limit = dt_limit + self.activation = "silu" + self.chunk_size = chunk_size + self.layer_idx = layer_idx + + assert self.d_inner % self.headdim == 0 + + # Order: [z, x, B, C, dt] + d_in_proj = 2 * self.d_inner + 2 * self.num_groups * self.d_state + 
self.num_heads + self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias) + + conv_dim = self.d_inner + 2 * self.num_groups * self.d_state + self.conv1d = CausalConv1d(conv_dim, kernel_size=self.d_conv, bias=conv_bias) + self.act = nn.SiLU() + + # Initialize log dt bias + dt = torch.exp( + torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + self.dt_bias._no_weight_decay = True + + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + if self.rmsnorm: + assert RMSNormGated is not None + self.norm = RMSNormGated( + self.d_inner, + eps=1e-5, + group_size=self.d_inner // num_groups, + ) + + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias) + + def forward(self, hidden_states, **kwargs): + batch_size, seq_len, _ = hidden_states.shape + dtype = hidden_states.dtype + + zxbcdt = self.in_proj(hidden_states) # (B, L, d_in_proj) or (B * L, d_in_proj) + + z, xBC, dt = torch.split( + zxbcdt, + [self.d_inner, self.d_inner + 2 * self.num_groups * self.d_state, self.num_heads], + dim=-1, + ) + + xBC = self.act(self.conv1d(xBC)) + + x, B, C = torch.split( + xBC, + [self.d_inner, self.num_groups * self.d_state, self.num_groups * self.d_state], + dim=-1, + ) + + y = torch.ops.auto_deploy.torch_ssm( + hidden_states=x.view(batch_size, seq_len, -1, self.headdim), + A=-torch.exp(self.A_log.float()), + B=B.view(batch_size, seq_len, -1, self.d_state), + C=C.view(batch_size, seq_len, -1, self.d_state), + D=self.D, + dt=dt, + dt_bias=self.dt_bias, + time_step_limit=self.dt_limit, + chunk_size=self.chunk_size, + ) + + y = y.view(batch_size, seq_len, -1) + + if self.rmsnorm: + y = self.norm(y, z) + + out = self.out_proj(y.to(dtype)) + + return out + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, config, dim, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.base = base + self.config = config + + self.rope_type = config.rope_type + + self.factor = 2 + + max_position_embeddings = self.config.max_position_embeddings + + if config.rope_type is None or config.rope_type == "default": + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim) + ) + self.max_seq_len_cached = max_position_embeddings + + elif config.rope_type == "ntk": + assert self.config.orig_max_position_embeddings is not None + orig_max_position_embeddings = self.config.orig_max_position_embeddings + + base = base * ( + (self.factor * max_position_embeddings / orig_max_position_embeddings) + - (self.factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim) + ) + + self.max_seq_len_cached = orig_max_position_embeddings + else: + raise ValueError(f"Not support rope_type: {config.rope_type}") + + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + 
position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + if q is not None: + q_embed = (q * cos) + (rotate_half(q) * sin) + + else: + q_embed = None + + if k is not None: + k_embed = (k * cos) + (rotate_half(k) * sin) + else: + k_embed = None + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class NemotronFlashAttention(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = ( + config.attn_hidden_size if config.attn_hidden_size > 0 else config.hidden_size + ) + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + self.kq_head_dim = config.kq_head_dim if config.kq_head_dim > 0 else self.head_dim + self.v_head_dim = config.v_head_dim if config.v_head_dim > 0 else self.head_dim + + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if ( + self.head_dim * self.num_heads + ) != self.hidden_size and self.kq_head_dim == self.head_dim: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear( + self.hidden_size, + self.num_heads * self.kq_head_dim, + bias=False, + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.kq_head_dim, + bias=False, + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.v_head_dim, bias=False + ) + + self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias=False) + + if self.config.kq_norm == "rms": + self.k_norm = NemotronFlashRMSNorm(self.kq_head_dim) + self.q_norm = NemotronFlashRMSNorm(self.kq_head_dim) + elif self.config.kq_norm == "none": + self.k_norm = None + self.q_norm = None + else: + raise NotImplementedError(f"Unknown kq_norm: {self.config.kq_norm}") + + if self.config.rope: + self._init_rope() + + def _init_rope(self): + self.rotary_emb = LlamaRotaryEmbedding( + config=self.config, + dim=self.kq_head_dim, + base=self.rope_theta, + device=torch.device("cuda"), + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + sliding_window: Optional[int] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.kq_head_dim) + key_states = key_states.view(bsz, q_len, -1, self.kq_head_dim) + value_states = value_states.view(bsz, q_len, -1, self.v_head_dim) + + if self.q_norm is not None: + query_states = self.q_norm(query_states) + if self.k_norm is not None: + key_states = self.k_norm(key_states) + + if self.config.rope: + cos, sin = self.rotary_emb(hidden_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=2 + ) + + attn_output = torch.ops.auto_deploy.torch_attention( + query_states, + key_states, + value_states, + is_causal=True, + dropout_p=self.attention_dropout, + sliding_window=sliding_window, + layout="bsnd", + ) + + 
attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class NemotronFlashMLP(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.config = config + self.act_fn_name = config.mlp_hidden_act + self.act_fn = ACT2FN[self.act_fn_name] + + if config.ffn_expand_ratio is not None: + self.ffn_dim = int(config.ffn_expand_ratio * config.hidden_size) // 128 * 128 + else: + self.ffn_dim = config.intermediate_size + + self.hidden_dim = config.hidden_size + + self.layer_idx = layer_idx + + if self.act_fn_name == "silu": + self.gate_proj = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.down_proj = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.up_proj = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + def forward(self, x): + if self.act_fn_name == "silu": + output = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + elif self.act_fn_name == "relu2": + output = self.down_proj(self.act_fn(self.up_proj(x))) + else: + raise NotImplementedError(f"No such hidden_act: {self.act_fn_name}") + + return output + + +class NemotronFlashAttentionDecoderLayer(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.config = config + + self.layer_idx = layer_idx + + self.self_attn = NemotronFlashAttention(config, layer_idx) + + if self.config.intermediate_size > 0: + self.ffn = NemotronFlashMLP(config, layer_idx=layer_idx) + self.pre_ffn_layernorm = NemotronFlashRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.ffn = None + self.pre_ffn_layernorm = None + + self.input_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, position_ids: torch.LongTensor, **kwargs + ) -> torch.FloatTensor: + residual = hidden_states + + if self.input_layernorm is not None: + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + hidden_states=hidden_states, position_ids=position_ids, **kwargs + ) + + hidden_states = residual + hidden_states + + if self.ffn is not None: + residual = hidden_states + if self.pre_ffn_layernorm is not None: + hidden_states = self.pre_ffn_layernorm(hidden_states) + hidden_states = self.ffn(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + return outputs + + +class FFNDecoderLayer(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.config = config + + self.layer_idx = layer_idx + + self.ffn = NemotronFlashMLP(config, layer_idx=layer_idx) + + self.pre_ffn_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor]: + residual = hidden_states + if self.pre_ffn_layernorm is not None: + hidden_states = self.pre_ffn_layernorm(hidden_states) + hidden_states = self.ffn(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + return outputs + + +class NemotronFlashMambaDecoderLayer(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.mamba = NemotronFlashMamba2(config=config, layer_idx=layer_idx) + + self.intermediate_size = config.intermediate_size + if self.intermediate_size > 0: + self.ffn = NemotronFlashMLP(config, layer_idx=layer_idx) + self.pre_ffn_layernorm = NemotronFlashRMSNorm( + config.hidden_size, 
eps=config.rms_norm_eps + ) + else: + self.ffn = None + self.pre_ffn_layernorm = None + + self.input_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, position_ids: torch.LongTensor, **kwargs + ) -> torch.FloatTensor: + residual = hidden_states + + if self.input_layernorm is not None: + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.mamba(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + if self.intermediate_size > 0: + residual = hidden_states + + if self.pre_ffn_layernorm is not None: + hidden_states = self.pre_ffn_layernorm(hidden_states) + + hidden_states = self.ffn(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + return outputs + + +class NemotronFlashHybridDecoderLayer(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + + self.config = config + + self.layer_idx = layer_idx + + if config.hybrid_decoder_layer == "mamba": + self.mamba = NemotronFlashMamba2(config=config, layer_idx=layer_idx) + if config.hybrid_decoder_layer == "deltanet": + ## this is to properly handle cache index + if config.layer_types is not None: + deltanet_idx = sum( + 1 for i in range(layer_idx) if config.layer_types[i] == "deltanet" + ) + else: + deltanet_idx = layer_idx + + self.gla = DeltaNet( + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + layer_idx=deltanet_idx, + config=self.config, + ) + else: + raise ValueError(f"Not supported: {config.hybrid_decoder_layer}") + + self.config = config + + if self.config.intermediate_size > 0: + self.ffn = NemotronFlashMLP(config, layer_idx=layer_idx) + self.pre_ffn_layernorm = NemotronFlashRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.ffn = None + self.pre_ffn_layernorm = None + + self.input_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, position_ids: torch.LongTensor, **kwargs + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + if self.config.hybrid_decoder_layer == "mamba": + hybrid_op_hidden_states = self.mamba(hidden_states=hidden_states) + + else: + hybrid_op_hidden_states = self.gla(hidden_states=hidden_states) + + hidden_states = residual + hybrid_op_hidden_states + + if self.ffn is not None: + residual = hidden_states + hidden_states = self.pre_ffn_layernorm(hidden_states) + hidden_states = self.ffn(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + return outputs + + +class NemotronFlashPreTrainedModel(PreTrainedModel): + # config_class = NemotronFlashConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["NemotronFlashAttentionDecoderLayer", "NemotronFlashMambaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class NemotronFlashModel(NemotronFlashPreTrainedModel): + 
""" + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`NemotronFlashDecoderLayer`] + + Args: + config: NemotronFlashConfig + """ + + def __init__(self, config): + super().__init__(config) + + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size + config.num_memory_tokens, + config.hidden_size, + self.padding_idx, + ) + + if self.config.num_memory_tokens > 0: + # register an appropriate pre load hook to merge memory tokens into the embedding layer + self.register_load_state_dict_pre_hook(self._merge_memory_tokens) + + decoder_layers = [] + + layer_type = [] + for i in range(config.num_hidden_layers): + if config.layer_types[i] in ["deltanet"]: + layer_type.append("m") + config_new = copy.deepcopy(config) + config_new.hybrid_decoder_layer = "deltanet" + decoder_layer = NemotronFlashHybridDecoderLayer(config_new, layer_idx=i) + elif config.layer_types[i] in ["m", "m2"]: + layer_type.append("m") + decoder_layer = NemotronFlashMambaDecoderLayer(config, layer_idx=i) + elif config.layer_types[i] == "a": + layer_type.append("a") + decoder_layer = NemotronFlashAttentionDecoderLayer(config, layer_idx=i) + elif config.layer_types[i] == "f": + layer_type.append("a") + decoder_layer = FFNDecoderLayer(config, layer_idx=i) + else: + raise ValueError(f"Unsupported layer type {config.layer_types[i]}") + + decoder_layers.append(decoder_layer) + + config.layer_type = layer_type + + if config.sliding_window is not None: + self.sliding_window = config.sliding_window + self.global_attn_idx = config.global_attn_idx + else: + self.sliding_window = None + self.global_attn_idx = None + + self.layers = nn.ModuleList(decoder_layers) + + self.final_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _merge_memory_tokens( + module: "NemotronFlashModel", + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + embed_suffix = "embed_tokens.weight" + embed_key = prefix + embed_suffix + memory_key = prefix + "memory_tokens" + + if embed_key not in state_dict or memory_key not in state_dict: + return + + embed_weight_from_state = state_dict[embed_key] + + # Already merged. Remove stale memory token weights if present. 
+ if embed_weight_from_state.shape[0] == module.get_parameter(embed_suffix).shape[0]: + state_dict.pop(memory_key, None) + return + + memory_weights = state_dict.pop(memory_key).to(embed_weight_from_state.dtype) + state_dict[embed_key] = torch.cat([embed_weight_from_state, memory_weights], dim=0) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, MoeModelOutputWithPast]: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if position_ids is None: + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1) + + hidden_states = inputs_embeds + + all_hidden_states = [] + + for i, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sliding_window=self.sliding_window is not None and i not in self.global_attn_idx, + ) + + hidden_states = layer_outputs[0] + + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states.append(hidden_states) + + return (hidden_states, all_hidden_states) + + +class TruncatedLinear(nn.Linear): + def __init__(self, *args, out_truncated: int, **kwargs): + super().__init__(*args, **kwargs) + self.out_truncated = out_truncated + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight[: self.out_truncated], self.bias) + + +class NemotronFlashForCausalLM(NemotronFlashPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + self.config = config + self.model = NemotronFlashModel(config) + self.vocab_size = config.vocab_size + self.lm_head = TruncatedLinear( + config.hidden_size, + config.vocab_size + config.num_memory_tokens, + bias=False, + out_truncated=config.vocab_size, + ) + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> Tuple[torch.Tensor]: + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + **kwargs, + ) + + hidden_states = outputs[0] + + logits = self.lm_head(hidden_states) + + if return_dict: + return CausalLMOutput( + logits=logits, + hidden_states=outputs[1] if output_hidden_states else None, + ) + + return (logits, *outputs[1:]) + + 
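+# Module-level registration: map the "NemotronFlashConfig" config class to this custom model
+# implementation in the NemotronFlash factory and add the model to the chat-template exceptions.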
+NemotronFlashForCausalLMFactory.register_custom_model_cls( + "NemotronFlashConfig", NemotronFlashForCausalLM +) +HF_CHAT_TEMPLATE_EXCEPTIONS.append("nemotron_flash") diff --git a/tensorrt_llm/_torch/auto_deploy/models/nemotron_flash.py b/tensorrt_llm/_torch/auto_deploy/models/nemotron_flash.py new file mode 100644 index 00000000000..a0a05c77d71 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/nemotron_flash.py @@ -0,0 +1,20 @@ +from .factory import ModelFactoryRegistry +from .hf import AutoModelForCausalLMFactory + + +@ModelFactoryRegistry.register("NemotronFlashForCausalLM") +class NemotronFlashForCausalLMFactory(AutoModelForCausalLMFactory): + # TODO: custom tokenizer initialization system + def init_tokenizer(self): + if self.tokenizer is None: + return None + + from .custom import NemotronFlashPreTrainedTokenizerFast + + model_config, _ = self._get_model_config() + return NemotronFlashPreTrainedTokenizerFast.from_pretrained( + self.tokenizer, + **self.tokenizer_kwargs, + num_memory_tokens=model_config.num_memory_tokens, + vocab_size_model=model_config.vocab_size, + ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py index bb30722e0af..a8e7907f725 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/ssm_cache.py @@ -13,3 +13,8 @@ class SSMCacheTransform(InsertCachedAttention): @TransformRegistry.register("insert_cached_causal_conv") class InitializeCausalConvCache(InsertCachedAttention): """A transform to handle causal conv cache operations.""" + + +@TransformRegistry.register("insert_cached_delta_rule") +class InsertCachedDeltaRule(InsertCachedAttention): + """A transform to handle delta rule cache operations."""
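For reference, a minimal, unoptimized PyTorch sketch of the recurrence that `fused_recurrent_delta_rule_fwd` implements (and that `chunk_delta_rule_fwd` computes block-wise via the WY representation) is given below. It is illustrative only and not part of the patch: the function name is ours, the `[B, T, H, D]` layout follows the custom-op signatures above, the default `scale = K**-0.5` matches `FlaDeltaBackend.get_constants`, and only the per-head `beta` path (`IS_BETA_HEADWISE=False`) is covered.

import torch


def delta_rule_reference(q, k, v, beta, scale=None, initial_state=None):
    """Naive per-token delta rule: q/k [B,T,H,K], v [B,T,H,V], beta [B,T,H]."""
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    # running state S[b, h, k, v], matching the [N, H, K, V] cache/final-state layout above
    S = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
    if initial_state is not None:
        S = S + initial_state.to(torch.float32)
    o = torch.empty_like(v)
    for t in range(T):
        q_t = q[:, t].float() * scale           # [B, H, K]
        k_t = k[:, t].float()                   # [B, H, K]
        v_t = v[:, t].float()                   # [B, H, V]
        b_t = beta[:, t].float().unsqueeze(-1)  # [B, H, 1]
        # delta update: subtract the value the state currently predicts for k_t,
        # then write the beta-scaled correction back as a rank-1 update
        v_pred = torch.einsum("bhk,bhkv->bhv", k_t, S)
        S = S + torch.einsum("bhk,bhv->bhkv", k_t, b_t * (v_t - v_pred))
        # read-out uses the state *after* this token's update, as in the Triton kernel
        o[:, t] = torch.einsum("bhk,bhkv->bhv", q_t, S).to(o.dtype)
    return o, S

The returned `S` corresponds to the per-sequence `[H, K, V]` state that `fla_cached_delta_rule` reads from and writes back to `delta_cache` at each request's slot.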