Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
# ruff: noqa: E501
import torch
from einops import rearrange
import functools
import os
from lightllm.utils.log_utils import init_logger

from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
Expand All @@ -19,6 +22,36 @@
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd

logger = init_logger(__name__)


@functools.lru_cache(maxsize=1)
def _flashqla_chunk_gated_delta_rule():
if os.environ.get("LIGHTLLM_DISABLE_FLASHQLA", "0").lower() in ["1", "true", "yes"]:
return None
try:
import flash_qla
except ImportError:
return None
if not torch.cuda.is_available():
return None
if torch.cuda.get_device_capability() < (9, 0):
return None
tv = torch.__version__.split("+")[0].split(".")
if (int(tv[0]), int(tv[1])) < (2, 8):
return None
cv = torch.version.cuda
if cv is None:
return None
cv_parts = cv.split(".")
if (int(cv_parts[0]), int(cv_parts[1])) < (12, 8):
return None
Comment on lines +40 to +48
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The version requirements for PyTorch (2.8) and CUDA (12.8) appear to be typos, as these versions are either not yet released or do not exist (CUDA 12.8). This will cause the FlashQLA backend to be disabled on all current environments. Additionally, the parsing logic is fragile and may raise IndexError or ValueError depending on the version string format (e.g., if it contains non-numeric suffixes like rc1).

    try:
        tv = torch.__version__.split("+")[0].split(".")
        if len(tv) < 2 or (int(tv[0]), int(tv[1])) < (2, 4):
            return None
        cv = torch.version.cuda
        if cv is None:
            return None
        cv_parts = cv.split(".")
        if len(cv_parts) < 2 or (int(cv_parts[0]), int(cv_parts[1])) < (12, 1):
            return None
    except (ValueError, IndexError):
        return None

logger.info(
"qwen3next chunk_gated_delta_rule: using FlashQLA backend (flash_qla.chunk_gated_delta_rule); "
"set LIGHTLLM_DISABLE_FLASHQLA=1 to fall back to the FLA Triton kernels."
)
return flash_qla.chunk_gated_delta_rule


def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
Expand Down Expand Up @@ -183,6 +216,22 @@ def chunk_gated_delta_rule(
cu_seqlens=cu_seqlens
)
"""
flashqla_fn = _flashqla_chunk_gated_delta_rule()
if flashqla_fn is not None and not head_first:
return flashqla_fn(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state.contiguous() if initial_state is not None else None,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
head_first=head_first,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
Comment on lines +221 to +233
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If scale is None, it is passed directly to flashqla_fn. The fallback Triton path explicitly calculates scale as k.shape[-1] ** -0.5. To ensure consistency and avoid potential issues if the flash_qla library does not handle None defaults, the scale should be explicitly provided.

Suggested change
return flashqla_fn(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state.contiguous() if initial_state is not None else None,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
head_first=head_first,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
return flashqla_fn(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale if scale is not None else k.shape[-1] ** -0.5,
initial_state=initial_state.contiguous() if initial_state is not None else None,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
head_first=head_first,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)


assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
Expand Down
Loading
Loading