TL;DR
We wrote a CUDA flash attention kernel that reads raw TBQ4_0 blocks directly — no separate dequant pass, no intermediate F16 buffer. Combined with MTP speculative decoding, this achieves 80-87 tok/s with lossless 4.25 bpv KV cache at 262K context on RTX 4090 24GB.
| Config | Context | KV Cache | Speed | Draft Accept |
|---|---|---|---|---|
| MTP + Fused TBQ4 FA | 262K | TBQ4_0 (4.25 bpv) | 80–87 tok/s | 73–93% |
| MTP + Fused TBQ4 FA | 200K | TBQ4_0 (4.25 bpv) | 82–87 tok/s | 73% |
| MTP + Q4_0 KV | 200K | Q4_0 (4.5 bpv) | 92–97 tok/s | 93.6% |
| MTP + Q4_0 KV | 135K | Q4_0 (4.5 bpv) | 97–103 tok/s | 93.6% |
| MTP Draft 5 (TBQ4 FA) | 262K | TBQ4_0 (4.25 bpv) | 79.6 avg / 106 peak tok/s | 90.1% |
| Baseline (dflash + turbo4, no MTP) | 300K | Turbo4 | 41 tok/s | — |
- Key innovation: Fused quantized-KV dequant inside the flash attention inner loop — to our knowledge, no one else has done this
- How it works: Attention runs entirely in FWHT-rotated domain; inner loop needs only a 2-value centroid lookup per byte
- Code: github.com/Indras-Mirror/llama.cpp-mtp — fully buildable fork
1. The Problem: Quantized KV and Flash Attention Don't Mix
Flash attention kernels read K and V tiles from global memory, compute QKᵀ and the softmax in shared memory, then accumulate the VKQ output. The standard approach when using a quantized KV cache (Q4_0, TBQ4_0, etc.) is:
- Dequantize the entire KV cache tile to F16 in a separate pass
- Write the F16 values to an intermediate buffer
- Read the F16 buffer in the FA kernel
This works, but it's bandwidth-wasteful: you read the quantized data once, write F16, then read the F16 back. For TBQ4_0 (4.25 bits per value), that's one 4.25-bit read followed by a 16-bit write and a 16-bit read per value — each leg of the F16 round trip alone moves 3.76× as many bytes as the raw quantized data.
The dflash fork (spiritbuun) noted this limitation explicitly in their code: "Turbo forces nstages=0: cp.async can't do ALU dequant, so tiles load synchronously." Their solution was to accept nstages=0 and dequant K/V to F16 before the FA kernel sees them.
We wanted to eliminate the intermediate buffer entirely.
2. The Insight: Rotated-Domain Attention
TBQ4_0 uses a Fast Walsh-Hadamard Transform (FWHT) to decorrelate KV cache vectors before 4-bit PolarQuant encoding. The standard dequant path is:
```
// Standard TBQ4 dequant (per 128-element block):
// 1. Centroid lookup: 4-bit index → Lloyd-Max centroid value
// 2. Scale by norm: centroid × corrected_norm
// 3. Sign multiply (s2 array)
// 4. Inverse FWHT butterfly (7 stages, O(n log n))
// 5. Sign multiply (s1 array)
// Result: original F32 vector
```
Steps 3–5 exist solely to undo the rotation applied during quantization. But here's the key insight:
The Hadamard transform is orthonormal. Dot products are preserved: dot(Hx, Hy) = dot(x, y). So if we rotate Q into the same domain as K, the attention scores QKᵀ are identical — and we never need to un-rotate K or V at all. The same argument covers V by linearity: the attention output is a weighted sum of V rows, so un-rotating the accumulated output once at the end recovers the original-domain result.
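A quick host-side sanity check of the invariance the kernel relies on — a normalized FWHT (butterfly plus 1/√n scaling) leaves dot products unchanged. This is an illustrative sketch, not the fork's kernel code:

```cpp
// Orthonormal FWHT preserves dot products, so QK^T can be computed
// entirely in the rotated domain. Sketch only.
#include <cmath>
#include <cstdio>
#include <vector>

static void fwht(std::vector<float> & v) {
    const size_t n = v.size();                   // must be a power of two (128 here)
    for (size_t len = 1; len < n; len <<= 1)
        for (size_t i = 0; i < n; i += 2 * len)
            for (size_t j = i; j < i + len; ++j) {
                const float a = v[j], b = v[j + len];
                v[j] = a + b;
                v[j + len] = a - b;
            }
    const float s = 1.0f / std::sqrt((float) n); // orthonormal scaling
    for (auto & x : v) x *= s;
}

int main() {
    std::vector<float> x(128), y(128);
    for (int i = 0; i < 128; ++i) { x[i] = std::sin(0.1f * i); y[i] = std::cos(0.3f * i); }

    float dot_orig = 0.0f, dot_rot = 0.0f;
    for (int i = 0; i < 128; ++i) dot_orig += x[i] * y[i];

    fwht(x); fwht(y);                            // rotate both into the Hadamard domain
    for (int i = 0; i < 128; ++i) dot_rot += x[i] * y[i];

    printf("dot(x,y) = %f   dot(Hx,Hy) = %f\n", dot_orig, dot_rot); // equal up to FP error
}
```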
This means the per-element dequant in the FA inner loop reduces to just:
```cuda
// Fused TBQ4 dequant in FA kernel (per byte = 2 elements):
const uint8_t byte = __ldg(&blk->qs[b]);
const half lo = __float2half(d_tbq4_centroids[byte & 0xF] * norm);
const half hi = __float2half(d_tbq4_centroids[byte >> 4] * norm);
tile[...] = __halves2half2(lo, hi);
// That's it. Two multiplies, two lookups, one store.
```
The 16 Lloyd-Max centroids live in __constant__ memory (cached, broadcast to all threads). The norm is read once per 128-element block. Each thread processes one byte (2 elements) with 2 FP multiplies and 2 float-to-half conversions.
3. The Full Pipeline
```
// Fused TBQ4 Flash Attention Pipeline

Phase 1: Pre-rotate Q (separate kernel, runs once)
  k_tbq4_rotate_input:
    sign_multiply(Q, s1) → FWHT_forward(Q) → sign_multiply(Q, s2)
    // Q is now in the same rotated domain as stored K/V

Phase 2: Fused FA kernel (the hot loop)
  for each K/V tile:
      // K tile: read raw TBQ4 bytes from GMEM
      for each 128-element block:
          norm = __ldg(&blk->d)                                 // 1 read per block
          for each byte in block:
              centroid_lo = d_tbq4_centroids[byte & 0xF] * norm
              centroid_hi = d_tbq4_centroids[byte >> 4] * norm
              tile_K[...] = half2(lo, hi)
      QK = mma(Q, Kᵀ)             // tensor core matmul
      softmax(QK)
      // V tile: same centroid lookup
      VKQ += mma(softmax(QK), V)  // accumulate

Phase 3: Post-rotate output (separate kernel, runs once)
  k_tbq4_rotate_output:
    sign_multiply(VKQ, s2) → FWHT_inverse(VKQ) → sign_multiply(VKQ, s1)
    // Output is back in the original domain
```
The FWHT (7-stage butterfly, 128 elements) runs exactly twice per attention head — once for Q rotation and once for output un-rotation. It never runs inside the KV iteration loop. For a 262K context with 31K tokens in the KV cache, that's 2 FWHT calls vs the standard path's ~484 calls (one per K tile + one per V tile).
4. Optimization Journey: 43 → 82 tok/s
The fused kernel didn't start fast. Here's the progression across 5 sessions:
| Version | Change | tok/s | Speedup |
|---|---|---|---|
| v1 (naive) | Basic fused loader, one-row-per-thread | 43 | baseline |
| v2 (column-group) | Threads process one column across rows | 70 | +63% |
| v3 (direct lookup) | 2-value centroid lookup instead of 16-value precompute | 82–87 | +91% |
Dead end: an nstages=2 pipeline (raw byte staging + collaborative dequant) produced garbled output and was reverted — see 4.3.
4.1 Column-Group Access Pattern
The naive loader assigned one thread per row: thread t loads all 128 elements of K row t. Each thread therefore reads bytes from a single TBQ4 block, and adjacent threads touch non-contiguous memory — uncoalesced loads and poor L2 cache utilization.
The column-group pattern assigns each thread to one column across all rows in the tile. Adjacent threads now access adjacent bytes of the same TBQ4 block, so the loads coalesce. This change alone lifted throughput by 63% (43 → 70 tok/s); a sketch follows below.
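To make the mapping concrete, here is a minimal, hypothetical CUDA sketch of a column-group tile loader. The kernel name, `TILE_ROWS`, the `tile` layout, and the assumption of one 128-element block per row are illustrative rather than the fork's actual code; the point is the thread-to-byte mapping.

```cuda
// Illustrative column-group loader: each thread owns one byte column
// (= 2 elements) and walks down the rows, so at every step the 64 threads
// of a column group read 64 consecutive qs[] bytes of the same
// block_tbq4_0 -- a coalesced transaction.
#include <cuda_fp16.h>
#include <cstdint>

__constant__ float d_tbq4_centroids[16];   // uploaded once at init

struct block_tbq4_0 {
    half    d;        // corrected L2 norm
    uint8_t qs[64];   // packed 4-bit centroid indices
};

template <int TILE_ROWS>
__global__ void load_k_tile_column_group(const block_tbq4_0 * k_blocks,
                                         half2 * tile /* TILE_ROWS x 64 half2 */) {
    const int col = threadIdx.x & 63;                     // byte column 0..63
    for (int row = threadIdx.x >> 6; row < TILE_ROWS; row += blockDim.x >> 6) {
        const block_tbq4_0 * blk = &k_blocks[row];        // one block per 128-wide row (illustrative)
        const float   norm = __half2float(blk->d);        // same word for all 64 threads -> broadcast
        const uint8_t byte = __ldg(&blk->qs[col]);
        const half lo = __float2half(d_tbq4_centroids[byte & 0xF] * norm);
        const half hi = __float2half(d_tbq4_centroids[byte >> 4] * norm);
        tile[row * 64 + col] = __halves2half2(lo, hi);    // 2 elements per byte
    }
}
```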
4.2 Direct Centroid Lookup
The v1 loader precomputed all 16 possible centroid×norm products per block into a local array:
```cuda
// BEFORE: 16 muls + 16 float-to-half per thread, then index
half cn_h[16];
for (int c = 0; c < 16; c++)
    cn_h[c] = __float2half(d_tbq4_centroids[c] * norm);
const uint8_t byte = __ldg(&blk->qs[b]);
tile[...] = __halves2half2(cn_h[byte & 0xF], cn_h[byte >> 4]);

// AFTER: 2 muls + 2 float-to-half, no array
const uint8_t byte = __ldg(&blk->qs[b]);
const half lo = __float2half(d_tbq4_centroids[byte & 0xF] * norm);
const half hi = __float2half(d_tbq4_centroids[byte >> 4] * norm);
tile[...] = __halves2half2(lo, hi);
```
Each byte contains 2 elements. The precompute approach wastes 14 of 16 multiply results; direct lookup does exactly 2 multiplies per byte — an 8× reduction in FP multiplies. The centroids sit in __constant__ memory with hardware broadcast, so the lookups are essentially free.
4.3 The nstages=2 Dead End
We attempted a pipelined staging approach where raw TBQ4 bytes would be loaded via int copies (overlapping with compute), then dequanted from shared memory staging buffers. After fixing 15 bugs including misaligned addresses, ordering violations, and nvcc codegen failures, it compiled and ran — but produced garbled output. The synchronous nstages=0 approach was already fast enough, so we reverted.
The dflash fork independently confirmed nobody has solved nstages>0 for quantized KV flash attention: "Turbo forces nstages=0: cp.async can't do ALU dequant."
5. Tensor Sharing: Saving 682 MiB
When llama.cpp loads an MTP model, it reads the same GGUF twice — once for the trunk model and once for the MTP head. The MTP head independently allocates token_embd.weight (682 MiB at Q4_K) even though the trunk already has an identical copy on the GPU.
We added a link_shared_tensors() virtual method to llama_model that lets sibling models wire up shared tensors from the trunk after loading:
```cpp
// include/llama.h
LLAMA_API void llama_model_link_shared_tensors(
              struct llama_model * model,
        const struct llama_model * trunk);

// The MTP head's tok_embd now points to the trunk's GPU allocation
// 682 MiB saved, zero quality impact
```
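A minimal usage sketch of the new call, assuming the standard llama.cpp C loading API. In the fork, llama-server wires this up internally when the MTP head is loaded; the explicit calls below are only for illustration, and error handling is omitted:

```cpp
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 99;

    // Trunk and MTP head are loaded from the same GGUF (two passes).
    llama_model * trunk = llama_model_load_from_file("qwen3.6-mtp.gguf", mparams);
    llama_model * head  = llama_model_load_from_file("qwen3.6-mtp.gguf", mparams);

    // Re-point the head's token_embd.weight at the trunk's GPU allocation
    // instead of keeping a duplicate 682 MiB copy.
    llama_model_link_shared_tensors(head, trunk);

    // ... create contexts, run MTP speculative decoding ...

    llama_model_free(head);
    llama_model_free(trunk);
    llama_backend_free();
    return 0;
}
```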
We tried sharing output.weight too (saving another 995 MiB), but this caused 0% draft acceptance — the Q4_K token embedding and Q6_K output projection produce meaningfully different logits despite being the "same" weight matrix. The quantization error across the 5120 × 248320 matrix accumulates enough to break draft prediction.
6. Multi-Token Prediction (MTP)
Qwen3.6-27B-Heretic-v2 (our base model) was trained with 3 MTP steps baked into the Qwen3.6 architecture. Each forward pass predicts 4 tokens (1 main + 3 draft). The MTP heads are extra transformer layers (blk.64.*) that predict subsequent tokens directly from intermediate hidden states.
MTP achieves 73–93.6% draft acceptance depending on KV quantization and task type:
| KV Cache | Task | Draft Accept | Effective Speedup |
|---|---|---|---|
| Q4_0 | Coding (short gen) | 93.6% | ~2.8× |
| TBQ4_0 (fused) | Coding (short gen) | 73% | ~2.2× |
| Q4_0 | Creative (2000 tok) | 69.2% | ~2.1× |
The lower acceptance with TBQ4_0 vs Q4_0 (73% vs 93.6%) is expected — the FWHT rotation introduces a small quantization domain mismatch between the main model's attention and the MTP head's predictions. The absolute speed (80-87 tok/s) is still excellent because the lossless compression saves enough VRAM to fit 262K context comfortably with ~4 GB of headroom.
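As a sanity check on these numbers, a quick back-of-envelope calculation links per-token acceptance to the best-case tokens emitted per forward pass, assuming each of the 3 draft tokens is accepted independently and ignoring verification overhead — which is why the measured speedups above sit below these bounds:

```cpp
#include <cstdio>

int main() {
    const int n_draft = 3;
    for (float p : {0.936f, 0.73f, 0.692f}) {
        float expected = 1.0f;   // the main token is always kept
        float pk = 1.0f;
        for (int k = 0; k < n_draft; ++k) { pk *= p; expected += pk; }
        printf("p = %.3f -> at most %.2f tokens per forward pass\n", p, expected);
    }
}
// p = 0.936 -> ~3.6x upper bound (observed ~2.8x)
// p = 0.730 -> ~2.7x upper bound (observed ~2.2x)
// p = 0.692 -> ~2.5x upper bound (observed ~2.1x)
```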
7. TBQ4_0: How It Works
7.1 Quantization Pipeline (per 128-element block)
1. Compute L2 norm → normalize block to unit sphere
2. Sign multiply (s1 array, seed=42)
3. FWHT butterfly (7 stages, O(n log n), in-register)
4. Sign multiply (s2 array)
5. 4-bit PolarQuant via Lloyd-Max centroids for N(0, 1/√128)
6. Norm correction: corrected_norm = original_norm / reconstruction_norm
7. Pack 4-bit values → 64 bytes qs + 2 bytes norm = 66 bytes per 128 elements
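For readers who want the whole per-block pipeline in one place, here is a CPU reference sketch. The centroid table and the s1/s2 sign arrays are placeholders (the real centroids are Lloyd-Max optimal for N(0, 1/√128) and the signs come from seed 42); only the structure of steps 1–7 is meant to be faithful.

```cpp
#include <cmath>
#include <cstdint>

constexpr int QK_TBQ4 = 128;

// Placeholder table -- the real values are Lloyd-Max optimal for N(0, 1/sqrt(128)).
static const float tbq4_centroids[16] = {
    -0.23f, -0.17f, -0.13f, -0.10f, -0.07f, -0.05f, -0.03f, -0.01f,
     0.01f,  0.03f,  0.05f,  0.07f,  0.10f,  0.13f,  0.17f,  0.23f,
};

// 7-stage in-place butterfly with orthonormal 1/sqrt(n) scaling.
static void fwht128(float * v) {
    for (int len = 1; len < QK_TBQ4; len <<= 1)
        for (int i = 0; i < QK_TBQ4; i += 2 * len)
            for (int j = i; j < i + len; ++j) {
                const float a = v[j], b = v[j + len];
                v[j] = a + b;
                v[j + len] = a - b;
            }
    for (int i = 0; i < QK_TBQ4; ++i) v[i] *= 1.0f / std::sqrt((float) QK_TBQ4);
}

// s1/s2 are +/-1 arrays (the fork derives them from seed 42).
static void sign_multiply(float * v, const int8_t * s) {
    for (int i = 0; i < QK_TBQ4; ++i) v[i] *= (float) s[i];
}

void quantize_block_tbq4_0(const float * x, const int8_t * s1, const int8_t * s2,
                           uint8_t qs[64], float & corrected_norm) {
    float v[QK_TBQ4];

    // 1. L2 norm -> unit sphere
    float norm = 0.0f;
    for (int i = 0; i < QK_TBQ4; ++i) norm += x[i] * x[i];
    norm = std::sqrt(norm);
    for (int i = 0; i < QK_TBQ4; ++i) v[i] = norm > 0.0f ? x[i] / norm : 0.0f;

    // 2-4. rotate into the Hadamard domain
    sign_multiply(v, s1);
    fwht128(v);
    sign_multiply(v, s2);

    // 5. 4-bit PolarQuant: nearest centroid per element
    uint8_t idx[QK_TBQ4];
    float recon_norm = 0.0f;
    for (int i = 0; i < QK_TBQ4; ++i) {
        int best = 0;
        for (int c = 1; c < 16; ++c)
            if (std::fabs(v[i] - tbq4_centroids[c]) < std::fabs(v[i] - tbq4_centroids[best]))
                best = c;
        idx[i] = (uint8_t) best;
        recon_norm += tbq4_centroids[best] * tbq4_centroids[best];
    }
    recon_norm = std::sqrt(recon_norm);

    // 6. norm correction: centroid * corrected_norm reconstructs the original scale
    corrected_norm = recon_norm > 0.0f ? norm / recon_norm : 0.0f;

    // 7. pack two 4-bit indices per byte (low nibble = even element)
    for (int b = 0; b < 64; ++b) qs[b] = (uint8_t)(idx[2*b] | (idx[2*b + 1] << 4));
}
```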
7.2 Block Layout
```cpp
struct block_tbq4_0 {      // 66 bytes = 4.25 bits per value
    ggml_half d;           // corrected L2 norm (2 bytes)
    uint8_t   qs[64];      // packed 4-bit centroid indices (64 bytes)
};                         // QK_TBQ4 = 128 elements per block
```
The 16 centroids are Lloyd-Max optimal for the standard normal distribution after FWHT rotation — N(0, 1/√128). They live in CUDA __constant__ memory, cached and broadcast to all threads in a warp.
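For illustration, the one-time upload might look like this (assumed names; the fork's actual init code may differ):

```cuda
#include <cuda_runtime.h>

__constant__ float d_tbq4_centroids[16];

// One-time upload at backend init; every warp then reads the table
// through the constant cache with hardware broadcast.
static void upload_tbq4_centroids(const float host_centroids[16]) {
    cudaMemcpyToSymbol(d_tbq4_centroids, host_centroids, 16 * sizeof(float));
}
```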
7.3 Quality
TBQ4_0 at 4.25 bpv is near-lossless — better quality than Q4_0 at 4.5 bpv, and dramatically better than Q8_0 at 8 bpv in our benchmarks (see our Turbo4 KV Cache benchmark where it scored 100/100 on hardened agentic tests vs Q8_0's 91/100).
8. Build and Run
```bash
# Clone and build
git clone https://github.com/Indras-Mirror/llama.cpp-mtp
cd llama.cpp-mtp
cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=89
cmake --build build -j$(nproc) --target llama-server

# Run with fused TBQ4 FA + MTP (80-87 tok/s at 262K)
./build/bin/llama-server \
  -m your-qwen3.6-mtp.gguf \
  --spec-type mtp --spec-draft-n-max 3 \
  -ctk tbq4_0 -ctv tbq4_0 \
  -c 200000 -ngl 99 \
  --flash-attn on --mlock \
  -t 8 -ub 32 -np 1 --no-warmup

# Or with Q4_0 KV for max speed (92-97 tok/s, more VRAM)
./build/bin/llama-server \
  -m your-qwen3.6-mtp.gguf \
  --spec-type mtp --spec-draft-n-max 3 \
  -ctk q4_0 -ctv q4_0 \
  -c 200000 -ngl 99 \
  --flash-attn on --mlock \
  -ub 32 -np 1
```
8.1 Getting an MTP-capable GGUF
Standard GGUF conversion strips MTP layers. You need to graft them:
```bash
# Download MTP head GGUF (457 MB, only blk.64.* tensors)
wget https://huggingface.co/havenoammo/Qwen3.6-27B-MTP-UD-GGUF/resolve/main/MTP-Q8_0.gguf

# Graft onto your base GGUF
uv venv .venv --seed && source .venv/bin/activate
uv pip install gguf
python convert.py base-model.gguf MTP-Q8_0.gguf output-mtp.gguf
```
Or use the pre-built GGUF we use: llmfan46/Qwen3.6-27B-Heretic-v2-Native-MTP-Preserved Q4_K_M — llmfan46's Heretic-v2 uncensored fine-tune with all 15 native MTP heads preserved from Qwen3.6 training.
9. The Bug Journey: 15 Bugs Fixed
Getting fused quantized-KV flash attention working required fixing 15 bugs across 5 sessions. Some highlights:
| # | Bug | Fix |
|---|---|---|
| 3 | nvcc generates bad code from `if constexpr` dead branches in TBQ4 templates | Replaced with `#if !defined(TBQ4_KV_FUSED)` preprocessor guards |
| 4 | Q rotation caused register spill (FWHT butterfly needs 128 registers) | Moved to separate kernel — eliminates pressure on the main FA kernel |
| 7 | V dequant happened AFTER K preload (reads garbage from shared staging buffer) | Reordered: V dequant before K[next] preload (shared buffer discipline) |
| 8 | cp_async requires 16-byte alignment; TBQ4 rows are 132 bytes | Replaced cp_async with int loads (raw byte copy) |
| 14 | Last iteration of nstages=2 loop missing tbq4_staging parameter | Added staging + type params to both ncols2 paths |
Full bug tracker in HANDOFF_TBQ4.md.
10. What's Next
- nstages=2 pipeline — the staging approach should work with correct synchronization. If we can overlap raw byte loading with compute, the fused kernel could reach Q4_0-equivalent speeds (90+ tok/s) while maintaining lossless quality.
- Upstream contribution — the `link_shared_tensors()` API is backwards-compatible and solves a real 682 MiB duplication problem for any model with tied embeddings + sibling models. Candidate for a PR.
- TBQ3_0 fused path — the same technique at 3.0625 bpv would push context even further (250K+ at full ngl).
- InnerQ (quantized Q) — quantizing Q in the FA kernel could reduce register pressure and improve occupancy.
For llama.cpp Maintainers
If you're reviewing this for potential upstream inclusion:
- Tensor sharing (`link_shared_tensors`) — minimal, backwards-compatible API. Solves a real problem for any architecture with tied embeddings + MTP/speculative heads.
- Fused TBQ4 FA — demonstrates that quantized KV can be dequanted inside the attention kernel via rotated-domain computation. The technique generalizes to any FWHT-based quant type.
- 15 bugs documented — full technical handoff in `HANDOFF_TBQ4.md` covers every pitfall (nvcc codegen, cp_async alignment, staging ordering).
11. FAQ & Known Limitations
Which model are you using exactly?
Qwen3.6-27B-uncensored-heretic-v2-Native-MTP-Preserved-Q4_K_M.gguf from llmfan46 on HuggingFace. This is llmfan46's Heretic-v2 uncensored fine-tune (built with Heretic v1.3.0 using MPOA — Magnitude-Preserving Orthogonal Ablation) on top of Qwen3.6-27B. All 15 MTP heads are natively preserved from the original Qwen3.6 training (no grafting needed). Achieves 94% fewer refusals (6/100 vs 92/100 original) with minimal quality impact (0.0021 KL divergence, 85.67% MMLU vs 86.65% original).
Why not Q4_0 KV instead of TBQ4_0?
Q4_0 KV is 4.5 bpv, slightly faster (92-97 tok/s vs 80-87), but uses more VRAM — at 200K we're at 23.96 GB, leaving almost no headroom. TBQ4_0 at 4.25 bpv saves ~1.5 GB at 262K, giving us the headroom to push context higher. The quality difference is marginal in our testing. If you're at shorter contexts (135K or less), Q4_0 KV is the better choice for raw speed.
What's the max context? Does quality degrade?
We've loaded 262K tokens successfully. At this size, TBQ4_0 uses ~20 GB VRAM with ~4 GB headroom. Benchmarked May 2026 at 52K context fill: 498 tok/s prefill, 84-95 tok/s short decode, 66 tok/s sustained (400 tokens). Quality remains coherent — factual recall, code generation, math reasoning, and creative writing all produce correct, non-degraded output at high context. No collapse or repetition issues observed. The random Hadamard rotations (merged upstream as --attn-rot) help maintain quality. For users reporting degradation with Q4 cache: asymmetric compression (Q8 keys + TQ4 values) preserves more quality at a minimal VRAM cost. See the TurboQuant paper for details.
Does this work with ROCm / AMD GPUs?
The fused TBQ4 flash attention kernel is CUDA-specific (uses warp shuffle, ldmatrix, cp_async). The TBQ4 quantize/dequant kernels and MTP support are also CUDA. Porting to ROCm/HIP would require adapting the warp intrinsics and memory instructions. The link_shared_tensors() API is backend-agnostic. If someone wants to port the kernels, I'm open to PRs.
What about vision/multimodal?
Vision + MTP currently crashes — this is an upstream MTP PR bug (reported 2026-05-06), not specific to our fork. The multimodal projection tensor layout conflicts with MTP's token prediction pipeline. As a workaround, disable MTP for vision tasks: remove --spec-type mtp and use standard decoding. Vision-only (no MTP) works fine.
Why draft-n-max 3 instead of 5?
| Metric | Draft 3 | Draft 5 |
|---|---|---|
| Avg decode | 80.6 tok/s | 79.6 tok/s |
| Min decode | 62.7 tok/s | 58.1 tok/s |
| Max decode | 98.5 tok/s | 106.2 tok/s |
| Draft acceptance | 92.6% | 90.1% |
Draft 5 occasionally hits higher peaks but the overhead from verifying 5 candidate tokens + lower per-token acceptance eats the average gain. Draft 3 is the sweet spot for this model.
Does tool calling work?
Yes. Tool calling and agentic use work correctly with MTP. We use this setup daily for coding (QuetzaCodetl), agentic benchmarks, and general tool use. No degradation observed vs non-MTP decoding for structured outputs.
What about smaller/larger models?
7B models: Crash with TBQ4 due to alignment issues (nb1=264, needs 16-byte alignment; 27B has nb1=528 which works). Fix would be padding in the KV cache allocator — deferred, not a priority for this fork.
MoE models (35B-A3B): Should work but less tested. The qwen35moe_mtp architecture is implemented. If you get vector::_M_range_check on load, verify your GGUF has nextn_predict_layers in the metadata and that the MTP head tensors were properly grafted. Use --verbose to check.
Will this work with exl3/exllamav2?
No — this is a llama.cpp CUDA fork. The fused flash attention kernel is tightly coupled to ggml's tensor layout and the llama.cpp graph builder. Porting to exllamav2 would require a complete kernel rewrite for its different memory model.
12. Credits
- havenoammo — MTP graft tooling, first Qwen3.6-27B-MTP GGUF release
- spiritbuun — dflash fork with CUDA TurboQuant kernels (our FWHT kernels adapted from this)
- ggml-org/llama.cpp — PR #22673 (MTP), PR #21089 (CPU TBQ)
- llmfan46 — Qwen3.6-27B-Heretic-v2 Native-MTP-Preserved GGUF (the model we use — 15 native MTP heads, MPOA uncensoring)
- HauhauCS — Original Qwen3.6-Heretic-v2 uncensored base model
- Radamanthys11 — MTP-Q8_0 GGUF extraction
- froggeric — Fixed chat templates for Qwen3.6 + MTP
Built on RTX 4090 24GB, Ubuntu 24.04, CUDA 12.x. All benchmarks with llmfan46/Qwen3.6-27B-uncensored-heretic-v2-Native-MTP-Preserved Q4_K_M (Heretic v1.3 MPOA uncensored, 15 native MTP heads). Max context: 262K at TBQ4_0, 135K at Q4_0 KV. Draft acceptance measured by llama.cpp's internal MTP statistics. Prefill: 498-614 t/s. Benchmark results from May 2026: 84-95 tok/s decode, 66 tok/s sustained at 400 tokens. Source code: github.com/Indras-Mirror/llama.cpp-mtp
