1 change: 1 addition & 0 deletions docs/source/features/auto_deploy/support_matrix.md
@@ -83,6 +83,7 @@ In addition, the following models have been officially validated using the default
- nvidia/Llama-3_1-Nemotron-Ultra-253B-v1-FP8
- nvidia/Llama-3_3-Nemotron-Super-49B-v1
- nvidia/Mistral-NeMo-Minitron-8B-Base
- nvidia/Nemotron-Flash-3B-Instruct
- perplexity-ai/r1-1776-distill-llama-70b

</details>
1 change: 1 addition & 0 deletions examples/auto_deploy/.gitignore
@@ -5,3 +5,4 @@ benchmark_results.json
# ignore config files that users might put here for debugging
*.yaml
!nano_v3.yaml
!nemotron_flash.yaml
11 changes: 11 additions & 0 deletions examples/auto_deploy/nemotron_flash.yaml
@@ -0,0 +1,11 @@
compile_backend: torch-cudagraph
max_batch_size: 384
max_seq_len: 2097152
max_num_tokens: 8192
enable_chunked_prefill: true
model_factory: NemotronFlashForCausalLM
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 96, 128, 256, 320, 384]
kv_cache_config:
# disable KV cache block reuse since it is not supported for hybrid/SSM models
enable_block_reuse: false
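
As a quick sanity check, the config can be loaded with plain PyYAML and its fields cross-checked against each other; this is only an illustrative sketch, not the AutoDeploy entrypoint that actually consumes the file:

import yaml

with open("examples/auto_deploy/nemotron_flash.yaml") as f:
    cfg = yaml.safe_load(f)

# Chunked prefill processes at most max_num_tokens per step within the 2M-token context window.
assert cfg["max_num_tokens"] <= cfg["max_seq_len"]
# The largest captured CUDA graph covers the configured maximum batch size.
assert max(cfg["cuda_graph_batch_sizes"]) == cfg["max_batch_size"]
# Block reuse stays off because hybrid/SSM caches cannot be reused.
assert cfg["kv_cache_config"]["enable_block_reuse"] is False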
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -152,6 +152,9 @@ transforms:
insert_cached_causal_conv:
stage: cache_init
backend: cuda_causal_conv
insert_cached_delta_rule:
stage: cache_init
backend: fla_delta
initialize_cache:
stage: cache_init
run_per_gm: false
@@ -34,8 +34,11 @@ class CacheConfig(BaseModel):

dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")
delta_dtype: Optional[torch.dtype] = Field(
default=torch.float32, description="Delta cache dtype. Defaults to float32."
)

@field_validator("dtype", "mamba_dtype", mode="before")
@field_validator("dtype", "mamba_dtype", "delta_dtype", mode="before")
@classmethod
def _coerce_dtype(cls, value):
if value is None or isinstance(value, torch.dtype):
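
A minimal sketch of how the new delta_dtype field behaves, assuming CacheConfig from the hunk above is in scope (its file path is collapsed in this diff) and that the truncated _coerce_dtype body also maps dtype names given as strings:

import torch

cfg = CacheConfig()  # delta_dtype defaults to torch.float32 per the Field above
assert cfg.delta_dtype == torch.float32

# The "before" validator now covers delta_dtype as well, so torch.dtype objects
# pass through and string names are coerced (string handling assumed).
cfg = CacheConfig(dtype="bfloat16", mamba_dtype=torch.float16, delta_dtype="float16")
assert cfg.delta_dtype == torch.float16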
Empty file.
Empty file.
@@ -0,0 +1,42 @@
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/delta_rule/chunk.py

from typing import Optional

import torch

from tensorrt_llm._torch.modules.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
from tensorrt_llm._torch.modules.fla.chunk_o import chunk_fwd_o

from .wy_fast import prepare_wy_repr_fwd


def chunk_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
# obtain WY representation. u is actually the new v.
w, u, A = prepare_wy_repr_fwd(
k=k,
v=v,
beta=beta,
cu_seqlens=cu_seqlens,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=None,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)

o = chunk_fwd_o(q=q, k=k, v=v_new, h=h, g=None, scale=scale, cu_seqlens=cu_seqlens)

return o, A, final_state
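
A hedged usage sketch for the chunked path; the [batch, seq_len, heads, head_dim] tensor layout is assumed to match the fused recurrent launcher added in this PR, and dtype/device choices are illustrative only:

import torch

B, T, H, K, V = 2, 128, 4, 64, 64
q = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)
beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)  # per-token, per-head gate in [0, 1]

o, A, final_state = chunk_delta_rule_fwd(
    q=q, k=k, v=v, beta=beta,
    scale=K**-0.5,           # conventional 1/sqrt(d_k) scaling
    initial_state=None,      # fresh sequence; assumed to be accepted by the chunk kernels
    output_final_state=True,
)
# o: per-token outputs; A: intra-chunk matrix from the WY representation;
# final_state: recurrent state to carry into the next segment of the sequence.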
@@ -0,0 +1,160 @@
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/delta_rule/fused_recurrent.py
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl


@triton.heuristics(
{
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def fused_recurrent_delta_rule_fwd_kernel(
q,
k,
v,
u,
beta,
o,
h0,
ht,
cu_seqlens,
scale,
T,
B: tl.constexpr,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
IS_BETA_HEADWISE: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int64),
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
)
all = T
T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T
all = B * T

p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
p_u = u + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
if IS_BETA_HEADWISE:
p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
else:
p_beta = beta + bos * H + i_h
p_o = o + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)

mask_k = (i_k * BK + tl.arange(0, BK)) < K
mask_v = (i_v * BV + tl.arange(0, BV)) < V
mask_h = mask_k[None, :] & mask_v[:, None]

b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = (
h0
+ i_nh * K * V
+ (i_k * BK + tl.arange(0, BK)[None, :]) * V
+ (i_v * BV + tl.arange(0, BV)[:, None])
)
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

for _ in range(0, T):
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
b_v_minus = tl.sum(b_h * b_k[None, :], axis=1)
b_v -= b_v_minus
if IS_BETA_HEADWISE:
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
else:
b_beta = tl.load(p_beta).to(tl.float32)
tl.store(p_u, b_v.to(p_v.dtype.element_ty), mask=mask_v)
b_v *= b_beta
b_h += b_k[None, :] * b_v[:, None]
b_o = b_h * b_q[None, :]
b_o = tl.sum(b_o, axis=1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

p_q += H * K
p_k += H * K
p_o += H * V
p_v += H * V
p_u += H * V
p_beta += H * (V if IS_BETA_HEADWISE else 1)

if STORE_FINAL_STATE:
p_ht = (
ht
+ i_nh * K * V
+ (i_k * BK + tl.arange(0, BK)[None, :]) * V
+ (i_v * BV + tl.arange(0, BV)[:, None])
)
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


def fused_recurrent_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
B, T, H, K, V = *k.shape, v.shape[-1]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 1
num_warps = 1

o = q.new_empty(NK, *v.shape)
if output_final_state:
final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
else:
final_state = None

grid = (NV, NK, N * H)
u = torch.empty_like(v)
fused_recurrent_delta_rule_fwd_kernel[grid](
q,
k,
v,
u,
beta,
o,
initial_state,
final_state,
cu_seqlens,
scale,
T=T,
B=B,
H=H,
K=K,
V=V,
BK=BK,
BV=BV,
IS_BETA_HEADWISE=beta.ndim == v.ndim,
num_warps=num_warps,
num_stages=num_stages,
)
o = o.squeeze(0)
return o, u, final_state
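
A hedged usage sketch for the recurrent decode path; the [B, T, H, K]/[B, T, H, V] layout and the float32 state shape follow the shape unpacking in the launcher above, while dtypes and sizes are illustrative:

import torch

B, T, H, K, V = 2, 1, 4, 64, 64
q = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)
beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)      # 3-dim beta -> scalar per head (IS_BETA_HEADWISE=False)
state = torch.zeros(B, H, K, V, device="cuda", dtype=torch.float32)  # recurrent state carried across decode steps

o, u, new_state = fused_recurrent_delta_rule_fwd(
    q=q, k=k, v=v, beta=beta,
    scale=K**-0.5,
    initial_state=state,
    output_final_state=True,
)
# o: [B, T, H, V] outputs; u: values with the state's prediction subtracted
# (stored before the beta scaling); new_state: [B, H, K, V] in fp32.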
@@ -0,0 +1,11 @@
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
import inspect
import os

import triton

FLA_CACHE_RESULTS = os.getenv("FLA_CACHE_RESULTS", "1") == "1"


supports_autotune_cache = "cache_results" in inspect.signature(triton.autotune).parameters
autotune_cache_kwargs = {"cache_results": FLA_CACHE_RESULTS} if supports_autotune_cache else {}
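
A minimal sketch of how the guarded kwargs are meant to be used: splatted into a triton.autotune decorator, so autotune timings are cached on Triton builds that expose cache_results and silently ignored on older ones. The kernel below is a throwaway stand-in, not part of this PR:

import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({}, num_warps=w) for w in (1, 2, 4)],
    key=["T"],
    **autotune_cache_kwargs,  # empty dict on Triton versions without cache_results
)
@triton.jit
def _copy_kernel(x, y, T, BLOCK: tl.constexpr):
    # Trivial elementwise copy, just to show where the decorator kwargs land.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < T
    tl.store(y + offs, tl.load(x + offs, mask=mask), mask=mask)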