feat: Qwen3.5 / Qwen3.5-MoE MTP speculative decoding #1330
Code Review
This pull request adds support for Qwen3.5 and Qwen3.5-MoE MTP models, integrating speculative decoding with CUDA graph support and linear attention. It introduces a spec-decode capable causal_conv1d update kernel, refactors CUDA graph capture and warmup logic to support both normal and MTP verify decode layouts, and adjusts memory management to handle dedicated draft KV slots and widened GPU conv buffers. The only feedback is a typo in fp.py: the function name _nomarl_prefill_att should be _normal_prefill_att.
  def _nomarl_prefill_att(
-     self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
+     self,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     att_control: AttControl,
+     alloc_func=torch.empty,
  ) -> torch.Tensor:
There appears to be a typo in the function name: _nomarl_prefill_att should likely be _normal_prefill_att. Correcting it improves readability and maintainability. The call site for this function needs the same rename.
- def _nomarl_prefill_att(
+ def _normal_prefill_att(
      self,
      q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
      att_control: AttControl,
      alloc_func=torch.empty,
  ) -> torch.Tensor:
Make the linear-attention (GDN) cache able to serve a speculative verify pass over multiple draft tokens without corrupting the canonical per-request state:
- conv-state shape splits into a widened GPU slot (holds the in-flight verify window) vs the narrow slot that is persisted/restored, while the SSM state keeps an (S+1) block so each draft position has a slot.
- snapshot/restore helpers read the committed conv window + SSM block slot and reset the carried accept_len, so the next step reads from the canonical offset-0 / block-0 pointer.
- relax the ReqManagerForMamba / CPU-cache MTP gates for hybrid models (draft KV is not persisted) and enforce the S<=7 bound.
Covered by conv-state shape-split, snapshot-split, and mamba req-manager gate unit tests.
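The shape split described above can be sketched as follows. This is only an illustration under stated assumptions: the function name, the argument names, and the rule that the widened slot is the narrow window plus S extra columns are all hypothetical and not taken from the PR; only the (S+1) SSM block and the S<=7 bound come from the commit message.

```python
from dataclasses import dataclass

MAX_MTP_STEP = 7  # the S<=7 bound named in the commit message


@dataclass(frozen=True)
class GdnCacheShapes:
    conv_gpu: tuple        # widened slot: holds the in-flight verify window
    conv_persisted: tuple  # narrow slot: what is snapshotted and restored
    ssm: tuple             # (S+1) blocks: committed state plus one per draft position


def gdn_cache_shapes(conv_dim, kernel_size, mtp_step, num_heads, head_k, head_v):
    """Hypothetical helper: derive per-request GDN cache shapes for S = mtp_step."""
    if not 0 <= mtp_step <= MAX_MTP_STEP:
        raise ValueError(f"mtp_step must be in [0, {MAX_MTP_STEP}], got {mtp_step}")
    narrow = kernel_size - 1   # assumption: a causal conv keeps the last K-1 inputs
    wide = narrow + mtp_step   # assumption: one extra column per draft token
    return GdnCacheShapes(
        conv_gpu=(conv_dim, wide),
        conv_persisted=(conv_dim, narrow),
        ssm=(mtp_step + 1, num_heads, head_k, head_v),
    )
```

With S = 0 the two conv shapes coincide, which matches the intent that non-MTP runs see the unchanged narrow layout.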
Add the Gated DeltaNet (qwen3next) verify forward used by MTP:
- vendor a spec-decode causal_conv1d_update kernel (causal_conv1d_spec) so multiple draft positions can advance the conv state in one launch.
- add the _gdn_verify kernel + MTP-verify dispatch branch, building the verify cu_seqlens, SSM index rows, conv indices and is_mtp_verify flag in infer_struct, and allocate non-colliding GPU draft full-attn slots.
- run the hybrid MTP decode eagerly so the GDN verify path is honored.
Unit tests assert the GDN verify state equals sequential T=1 decode, cover prefill conv indices, the spec conv kernel, and draft-slot layout.
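The equivalence the unit tests check (a multi-position update matches sequential T=1 updates) can be illustrated for the conv state alone with a single-channel, pure-Python reference. This is a minimal sketch, not the vendored Triton kernel; the function names and the list-based state layout are assumptions made for illustration.

```python
def conv1d_update(state, x, weight):
    """Advance a single-channel causal conv by one token.

    state:  the last K-1 inputs, oldest first
    weight: K filter taps, oldest first
    Returns (output, new_state).
    """
    window = state + [x]
    y = sum(w * v for w, v in zip(weight, window))
    return y, window[1:]


def conv1d_update_spec(state, xs, weight):
    """Process T draft tokens in one call.

    Returns the output at every position and the conv state after every
    position, so a caller can commit the state matching the accepted prefix.
    """
    k = len(weight)
    window = state + list(xs)
    outputs = [
        sum(w * v for w, v in zip(weight, window[t:t + k]))
        for t in range(len(xs))
    ]
    states = [window[t + 1:t + k] for t in range(len(xs))]
    return outputs, states


def sequential_reference(state, xs, weight):
    """Run T separate single-token updates for comparison."""
    outputs, states = [], []
    for x in xs:
        y, state = conv1d_update(state, x, weight)
        outputs.append(y)
        states.append(state)
    return outputs, states
```

Keeping the state after every position, rather than only the final one, is what lets a partial accept pick the correct window without rerunning the conv.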
Wire MTP into the base model decode path:
- capture/replay decode CUDA graphs for the MTP verify step and thread b_num_accepted_tokens through ModelInput / InferStateInfo.
- add the MTP-verify dispatch in basemodel and pass the per-position draft index into the FA3 attention backends (fp / fp8 / mla).
Covered by the MTP decode CUDA-graph unit test.
Drive the draft/verify loop from the scheduler:
- carry a canonical InferReq.mtp_accept_len pointer and persist the per-request accept_len across steps; build per-req b_num_accepted_tokens in decode_mtp and commit it in phase 2 so the next step reads a fresh count.
- extend the chunked_prefill backend / base_backend with the MTP verify dispatch and the partial-accept read offset.
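For readers unfamiliar with how an accept count arises, greedy verification can be sketched as below. This is a generic illustration of greedy speculative-decoding acceptance, not the scheduler code from this PR; the function name is hypothetical, and counting the bonus token in the returned length is an assumption that may differ from how b_num_accepted_tokens is defined.

```python
def greedy_accept_len(draft_tokens, verify_argmax):
    """Compare draft tokens against the base model's greedy choices.

    draft_tokens:  the mtp_step tokens proposed by the draft head
    verify_argmax: the base model's greedy token at each of the
                   mtp_step + 1 verify positions
    Returns (accept_len, emitted_tokens). accept_len counts the matched
    draft tokens plus the one token the base model supplies itself.
    """
    if len(verify_argmax) != len(draft_tokens) + 1:
        raise ValueError("verify pass must cover mtp_step + 1 positions")
    matched = 0
    while matched < len(draft_tokens) and draft_tokens[matched] == verify_argmax[matched]:
        matched += 1
    # The token at the first mismatch (or after a full match) comes from the
    # base model, so the emitted sequence equals plain greedy decoding.
    emitted = list(verify_argmax[: matched + 1])
    return matched + 1, emitted
```

Because every emitted token is the base model's own greedy choice, the output is identical to non-speculative greedy decoding; the draft only changes how many tokens are produced per step.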
Add the MTP draft model packages and register them:
- qwen3_5_mtp: a forced single full-attn-layer draft model, with the MTP pre-layer infer (embed/hidden norm + fc fusion) and pre/post + transformer-layer weight loaders reading the mtp.* namespace.
- qwen3_5_moe_mtp: the MoE variant draft weight loaders + model.
- register qwen3_5 / qwen3_5_moe MTP draft models with per-block draft_idx, plus the qwen3_5 verify infer_struct.
Unit tests scaffold the MTP draft layer and the hybrid verify forward.
The write-only layer_infer._draft_kv_slot was never read anywhere; the KV-slot mapping is fully expressed via layer_num_ = draft_kv_slot * interval.
…commit
- Revert local reformatting to match upstream/main exactly, minimizing PR diff
- Inline _commit_mtp_accept_len into decode_mtp (phase-2 ordering preserved)
- Drop redundant inline comments
- Dispatch to MTP bench whenever mtp_mode is set (was dead-coded to 'deepseekv3')
- init_mtp_model: dispatch by config model_type (deepseek_v3/qwen3_moe/mistral/glm4_moe_lite/qwen3_5/qwen3_5_moe), handle eagle (1 instance) vs vanilla (mtp_step instances); fix mem_faction typo; pass full att/kv/quant kvargs
- run_forward_once: adapt to new ModelInput API (mem_indexes_cpu + CPU tensors, max_q/kv_seq_len, b_mtp_index, b_prefill_start_loc); reuse draft instances via _step % num_instances; pad/truncate draft_ids to mtp_step+1
- Cap max_req_num at 512 to avoid GDN req-state cache OOM under MTP
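The instance reuse and the pad/truncate step in run_forward_once can be sketched as two small helpers. Both function names and the pad_id default are hypothetical and chosen only for illustration; the commit message states only the modulo rule and the mtp_step+1 target length.

```python
def pick_draft_instance(step, num_instances):
    """Select which draft model instance serves this draft step.

    With one instance (eagle) every step maps to index 0; with
    mtp_step instances (vanilla) the steps cycle through them.
    """
    return step % num_instances


def fit_draft_ids(draft_ids, mtp_step, pad_id=0):
    """Pad or truncate draft_ids to exactly mtp_step + 1 entries.

    pad_id is an assumed filler value, not one taken from the PR.
    """
    target = mtp_step + 1
    return (list(draft_ids) + [pad_id] * target)[:target]
```

For example, with mtp_step = 3 a two-token list is padded to four entries, while a five-token list with mtp_step = 2 is cut to its first three.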
Summary
Adds Multi-Token Prediction (MTP) speculative decoding for Qwen3.5 and Qwen3.5-MoE (hybrid full-attention + Gated Delta Net / qwen3next linear-attention models). The draft head proposes mtp_step tokens per step and the base model verifies them in a single fused forward, giving exact greedy output with single-stream/low-concurrency latency speedups.
Implemented in 5 logical layers:
- _gdn_verify_kernel + hybrid dispatch and a vendored spec-decode causal_conv1d_spec Triton kernel.
- lcm(unit, mtp_step+1); capture-safe (no D2H syncs).
- accept_len pointer plumbing; accept-count carry committed in phase 2 (pre-forward-release) to avoid a one-step-stale read under the overlap scheduler.
- qwen3_5_mtp / qwen3_5_moe_mtp packages (full-attn draft layer, mrope, inline mtp.* weights), non-colliding draft KV slots, backend registration.
Test Plan
- pytest unit_tests/common unit_tests/models/qwen3_5 unit_tests/models/qwen3next
- black clean
Notes