Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/features/auto_deploy/support_matrix.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ In addition, the following models have been officially validated using the defau
- nvidia/Llama-3_1-Nemotron-Ultra-253B-v1-FP8
- nvidia/Llama-3_3-Nemotron-Super-49B-v1
- nvidia/Mistral-NeMo-Minitron-8B-Base
- nvidia/Nemotron-Flash-3B-Instruct
- perplexity-ai/r1-1776-distill-llama-70b

</details>
Expand Down
1 change: 1 addition & 0 deletions examples/auto_deploy/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ benchmark_results.json
# ignore config files that users might put here for debugging
*.yaml
!nano_v3.yaml
!nemotron_flash.yaml
11 changes: 11 additions & 0 deletions examples/auto_deploy/nemotron_flash.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
compile_backend: torch-cudagraph
max_batch_size: 384
max_seq_len: 2097152
max_num_tokens: 8192
enable_chunked_prefill: true
model_factory: NemotronFlashForCausalLM
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64,96, 128, 256, 320, 384]
kv_cache_config:
# disable kv_cache reuse since not supported for hybrid/ssm models
enable_block_reuse: false
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ transforms:
insert_cached_causal_conv:
stage: cache_init
backend: cuda_causal_conv
insert_cached_delta_rule:
stage: cache_init
backend: fla_delta
initialize_cache:
stage: cache_init
run_per_gm: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ class CacheConfig(BaseModel):

dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")
delta_dtype: Optional[torch.dtype] = Field(
default=torch.float32, description="Delta cache dtype. Defaults to float32."
)

@field_validator("dtype", "mamba_dtype", mode="before")
@field_validator("dtype", "mamba_dtype", "delta_dtype", mode="before")
@classmethod
def _coerce_dtype(cls, value):
if value is None or isinstance(value, torch.dtype):
Expand Down
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/delta_rule/chunk.py

from typing import Optional

import torch

from tensorrt_llm._torch.modules.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
from tensorrt_llm._torch.modules.fla.chunk_o import chunk_fwd_o

from .wy_fast import prepare_wy_repr_fwd


def chunk_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
# obtain WY representation. u is actually the new v.
w, u, A = prepare_wy_repr_fwd(
k=k,
v=v,
beta=beta,
cu_seqlens=cu_seqlens,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=None,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)

o = chunk_fwd_o(q=q, k=k, v=v_new, h=h, g=None, scale=scale, cu_seqlens=cu_seqlens)

return o, A, final_state
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/delta_rule/fused_recurrent.py
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl


@triton.heuristics(
{
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def fused_recurrent_delta_rule_fwd_kernel(
q,
k,
v,
u,
beta,
o,
h0,
ht,
cu_seqlens,
scale,
T,
B: tl.constexpr,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
IS_BETA_HEADWISE: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int64),
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
)
all = T
T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T
all = B * T

p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
p_u = u + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
if IS_BETA_HEADWISE:
p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
else:
p_beta = beta + bos * H + i_h
p_o = o + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)

mask_k = (i_k * BK + tl.arange(0, BK)) < K
mask_v = (i_v * BV + tl.arange(0, BV)) < V
mask_h = mask_k[None, :] & mask_v[:, None]

b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = (
h0
+ i_nh * K * V
+ (i_k * BK + tl.arange(0, BK)[None, :]) * V
+ (i_v * BV + tl.arange(0, BV)[:, None])
)
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

for _ in range(0, T):
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
b_v_minus = tl.sum(b_h * b_k[None, :], axis=1)
b_v -= b_v_minus
if IS_BETA_HEADWISE:
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
else:
b_beta = tl.load(p_beta).to(tl.float32)
tl.store(p_u, b_v.to(p_v.dtype.element_ty), mask=mask_v)
b_v *= b_beta
b_h += b_k[None, :] * b_v[:, None]
b_o = b_h * b_q[None, :]
b_o = tl.sum(b_o, axis=1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

p_q += H * K
p_k += H * K
p_o += H * V
p_v += H * V
p_u += H * V
p_beta += H * (V if IS_BETA_HEADWISE else 1)

if STORE_FINAL_STATE:
p_ht = (
ht
+ i_nh * K * V
+ (i_k * BK + tl.arange(0, BK)[None, :]) * V
+ (i_v * BV + tl.arange(0, BV)[:, None])
)
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


def fused_recurrent_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 1
num_warps = 1

o = q.new_empty(NK, *v.shape)
if output_final_state:
final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
else:
final_state = None

grid = (NV, NK, N * H)
u = torch.empty_like(v)
fused_recurrent_delta_rule_fwd_kernel[grid](
q,
k,
v,
u,
beta,
o,
initial_state,
final_state,
cu_seqlens,
scale,
T=T,
B=B,
H=H,
K=K,
V=V,
BK=BK,
BV=BV,
IS_BETA_HEADWISE=beta.ndim == v.ndim,
num_warps=num_warps,
num_stages=num_stages,
)
o = o.squeeze(0)
return o, u, final_state
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
import inspect
import os

import triton

FLA_CACHE_RESULTS = os.getenv("FLA_CACHE_RESULTS", "1") == "1"


supports_autotune_cache = "cache_results" in inspect.signature(triton.autotune).parameters
autotune_cache_kwargs = {"cache_results": FLA_CACHE_RESULTS} if supports_autotune_cache else {}
Loading
Loading