TL;DR
We wrote a CUDA flash attention kernel that reads raw TBQ4_0 blocks directly — no separate dequant pass, no intermediate F16 buffer. Combined with MTP speculative decoding, this achieves 82+ tok/s with a near-lossless 4.25 bpv KV cache at 200K context on an RTX 4090 24GB.
| Config | Context | KV Cache | Speed | Draft Accept |
|---|---|---|---|---|
| MTP + Fused TBQ4 FA | 200K | TBQ4_0 (4.25 bpv) | 82–87 tok/s | 73% |
| MTP + Q4_0 KV | 200K | Q4_0 (4.5 bpv) | 92–97 tok/s | 93.6% |
| MTP + Q4_0 KV | 135K | Q4_0 (4.5 bpv) | 97–103 tok/s | 93.6% |
| Baseline (dflash + turbo4, no MTP) | 300K | Turbo4 | 41 tok/s | — |
- Key innovation: fused quantized-KV dequant inside the flash attention inner loop — to our knowledge, no other implementation does this
- How it works: Attention runs entirely in FWHT-rotated domain; inner loop needs only a 2-value centroid lookup per byte
- Code: github.com/Indras-Mirror/llama.cpp-mtp — fully buildable fork
1. The Problem: Quantized KV and Flash Attention Don't Mix
Flash attention kernels read K and V tiles from global memory, compute QK^T and the softmax in shared memory, then accumulate VKQ. The standard approach when using a quantized KV cache (Q4_0, TBQ4_0, etc.) is:
- Dequantize the entire KV cache tile to F16 in a separate pass
- Write the F16 values to an intermediate buffer
- Read the F16 buffer in the FA kernel
This works, but it's bandwidth-wasteful: you read the quantized data once, write F16, then read F16 again. For TBQ4_0 (4.25 bits per value), that's reading 4.25 bits per value, then writing and re-reading 16 bits each — the F16 write alone is ~3.76× the size of the raw data, and the round trip moves roughly 8.5× more bytes overall.
The dflash fork (spiritbuun) noted this limitation explicitly in their code: "Turbo forces nstages=0: cp.async can't do ALU dequant, so tiles load synchronously." Their solution was to accept nstages=0 and dequant K/V to F16 before the FA kernel sees them.
We wanted to eliminate the intermediate buffer entirely.
2. The Insight: Rotated-Domain Attention
TBQ4_0 uses a Fast Walsh-Hadamard Transform (FWHT) to decorrelate KV cache vectors before 4-bit PolarQuant encoding. The standard dequant path is:
```
// Standard TBQ4 dequant (per 128-element block):
1. Centroid lookup: 4-bit index → Lloyd-Max centroid value
2. Scale by norm: centroid × corrected_norm
3. Sign multiply (s2 array)
4. Inverse FWHT butterfly (7 stages, O(n log n))
5. Sign multiply (s1 array)
// Result: original F32 vector
```
Steps 3–5 exist solely to undo the rotation applied during quantization. But here's the key insight:
The Hadamard transform is orthonormal. Dot products are preserved: dot(Hx, Hy) = dot(x, y). So if we rotate Q into the same domain as K, the attention scores QK^T are identical — and we never need to un-rotate K or V at all.
This means the per-element dequant in the FA inner loop reduces to just:
```cuda
// Fused TBQ4 dequant in FA kernel (per byte = 2 elements):
const uint8_t byte = __ldg(&blk->qs[b]);
const half lo = __float2half(d_tbq4_centroids[byte & 0xF] * norm);
const half hi = __float2half(d_tbq4_centroids[byte >> 4] * norm);
tile[...] = __halves2half2(lo, hi);
// That's it. Two multiplies, two lookups, one store.
```
The 16 Lloyd-Max centroids live in `__constant__` memory (cached, broadcast to all threads). The norm is read once per 128-element block. Each thread processes one byte (2 elements) with 2 FP multiplies and 2 float-to-half conversions.
3. The Full Pipeline
```
// Fused TBQ4 Flash Attention Pipeline

// Phase 1: Pre-rotate Q (separate kernel, runs once)
k_tbq4_rotate_input:
    sign_multiply(Q, s1) → FWHT_forward(Q) → sign_multiply(Q, s2)
    // Q is now in the same rotated domain as stored K/V

// Phase 2: Fused FA kernel (the hot loop)
for each K/V tile:
    // K tile: read raw TBQ4 bytes from GMEM
    for each 128-element block:
        norm = __ldg(&blk->d)                            // 1 read per block
        for each byte in block:
            centroid_lo = d_tbq4_centroids[byte & 0xF] * norm
            centroid_hi = d_tbq4_centroids[byte >> 4] * norm
            tile_K[...] = half2(centroid_lo, centroid_hi)
    QK = mma(Q, K^T)                // tensor core matmul
    softmax(QK)
    // V tile: same centroid lookup
    VKQ += mma(softmax(QK), V)      // accumulate

// Phase 3: Post-rotate output (separate kernel, runs once)
k_tbq4_rotate_output:
    sign_multiply(VKQ, s2) → FWHT_inverse(VKQ) → sign_multiply(VKQ, s1)
    // Output is back in the original domain
```
The FWHT (7-stage butterfly, 128 elements) runs exactly twice per attention head — once for Q rotation and once for output un-rotation. It never runs inside the KV iteration loop. For a 200K context with 24K tokens in the KV cache, that's 2 FWHT calls vs the standard path's ~375 calls (one per K tile + one per V tile).
4. Optimization Journey: 43 → 82 tok/s
The fused kernel didn't start fast. Here's the progression across 5 sessions:
| Version | Change | tok/s | Speedup |
|---|---|---|---|
| v1 (naive) | Basic fused loader, one-row-per-thread | 43 | baseline |
| v2 (column-group) | Threads process one column across rows | 70 | +63% |
| v3 (direct lookup) | 2-value centroid lookup instead of 16-value precompute | 82–87 | +91% |
Dead end: an nstages=2 pipeline (raw byte staging + collaborative dequant) produced garbled output and was reverted — see 4.3.
4.1 Column-Group Access Pattern
The naive loader assigned one thread per row: thread t loads all 128 elements of K row t. This means each thread accesses bytes from a single TBQ4 block — poor L2 cache utilization because adjacent threads touch non-contiguous memory.
The column-group pattern assigns each thread to process one column across all rows in the tile. Adjacent threads now access adjacent bytes from the same TBQ4 block, achieving coalesced memory access. Nearly doubled throughput.
4.2 Direct Centroid Lookup
The v1 loader precomputed all 16 possible centroid×norm products per block into a local array:
```cuda
// BEFORE: 16 muls + 16 float-to-half per thread, then index
half cn_h[16];
for (int c = 0; c < 16; c++) cn_h[c] = __float2half(d_tbq4_centroids[c] * norm);
const uint8_t byte = __ldg(&blk->qs[b]);
tile[...] = __halves2half2(cn_h[byte & 0xF], cn_h[byte >> 4]);

// AFTER: 2 muls + 2 float-to-half, no array
const uint8_t byte = __ldg(&blk->qs[b]);
const half lo = __float2half(d_tbq4_centroids[byte & 0xF] * norm);
const half hi = __float2half(d_tbq4_centroids[byte >> 4] * norm);
tile[...] = __halves2half2(lo, hi);
```
Each byte contains 2 elements. The precompute approach wastes 14 of its 16 multiply results. Direct lookup does exactly 2 multiplies per byte — an 8× reduction in FP math. The centroids sit in `__constant__` memory with hardware broadcast, so the lookup is effectively free.
4.3 The nstages=2 Dead End
We attempted a pipelined staging approach where raw TBQ4 bytes would be loaded via `int` copies (overlapping with compute), then dequantized from shared-memory staging buffers. After fixing 15 bugs including misaligned addresses, ordering violations, and nvcc codegen failures, it compiled and ran — but produced garbled output. The synchronous nstages=0 approach was already fast enough, so we reverted.
The dflash fork independently confirmed nobody has solved nstages>0 for quantized KV flash attention: "Turbo forces nstages=0: cp.async can't do ALU dequant."
5. Tensor Sharing: Saving 682 MiB
When llama.cpp loads an MTP model, it reads the same GGUF twice — once for the trunk model and once for the MTP head. The MTP head independently allocates token_embd.weight (682 MiB at Q4_K) even though the trunk already has an identical copy on the GPU.
We added a link_shared_tensors() virtual method to llama_model that lets sibling models wire up shared tensors from the trunk after loading:
```cpp
// include/llama.h
LLAMA_API void llama_model_link_shared_tensors(
        struct llama_model * model,
        const struct llama_model * trunk);

// The MTP head's tok_embd now points to the trunk's GPU allocation:
// 682 MiB saved, zero quality impact.
```
We tried sharing output.weight too (saving another 995 MiB), but this caused 0% draft acceptance — the Q4_K token embedding and Q6_K output projection produce meaningfully different logits despite being the "same" weight matrix. The quantization error across the 5120 × 248320 matrix accumulates enough to break draft prediction.
6. Multi-Token Prediction (MTP)
Qwen3.6-27B was trained with 3 MTP steps baked into the architecture. Each forward pass predicts 4 tokens (1 main + 3 draft). The MTP heads are extra transformer layers (blk.64.*) that predict subsequent tokens directly from intermediate hidden states.
MTP achieves 73–93.6% draft acceptance depending on KV quantization and task type:
| KV Cache | Task | Draft Accept | Effective Speedup |
|---|---|---|---|
| Q4_0 | Coding (short gen) | 93.6% | ~2.8× |
| TBQ4_0 (fused) | Coding (short gen) | 73% | ~2.2× |
| Q4_0 | Creative (2000 tok) | 69.2% | ~2.1× |
The lower acceptance with TBQ4_0 vs Q4_0 (73% vs 93.6%) is expected — the FWHT rotation introduces a small quantization domain mismatch between the main model's attention and the MTP head's predictions. The absolute speed (82+ tok/s) is still excellent because the lossless compression saves enough VRAM to fit 200K context comfortably with ~4 GB of headroom.
7. TBQ4_0: How It Works
7.1 Quantization Pipeline (per 128-element block)
1. Compute L2 norm → normalize block to unit sphere
2. Sign multiply (s1 array, seed=42)
3. FWHT butterfly (7 stages, O(n log n), in-register)
4. Sign multiply (s2 array)
5. 4-bit PolarQuant via Lloyd-Max centroids for N(0, 1/√128)
6. Norm correction: corrected_norm = original_norm / reconstruction_norm
7. Pack 4-bit values → 64 bytes qs + 2 bytes norm = 66 bytes per 128 elements
7.2 Block Layout
```cpp
struct block_tbq4_0 {    // 66 bytes per block
    ggml_half d;         // corrected L2 norm (2 bytes)
    uint8_t   qs[64];    // packed 4-bit centroid indices (64 bytes)
};                       // QK_TBQ4 = 128 elements per block
```
The 16 centroids are Lloyd-Max optimal for the post-rotation element distribution, N(0, 1/√128). They live in CUDA `__constant__` memory, cached and broadcast to all threads in a warp.
7.3 Quality
TBQ4_0 at 4.25 bpv is near-lossless — better quality than Q4_0 at 4.5 bpv, and dramatically better than Q8_0 at 8 bpv in our benchmarks (see our Turbo4 KV Cache benchmark where it scored 100/100 on hardened agentic tests vs Q8_0's 91/100).
8. Build and Run
```sh
# Clone and build
git clone https://github.com/Indras-Mirror/llama.cpp-mtp
cd llama.cpp-mtp
cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=89
cmake --build build -j$(nproc) --target llama-server

# Run with fused TBQ4 FA + MTP (82+ tok/s at 200K)
./build/bin/llama-server \
  -m your-qwen3.6-mtp.gguf \
  --spec-type mtp --spec-draft-n-max 3 \
  -ctk tbq4_0 -ctv tbq4_0 \
  -c 200000 -ngl 99 \
  --flash-attn on --mlock \
  -t 8 -ub 32 -np 1 --no-warmup

# Or with Q4_0 KV for max speed (92-97 tok/s, more VRAM)
./build/bin/llama-server \
  -m your-qwen3.6-mtp.gguf \
  --spec-type mtp --spec-draft-n-max 3 \
  -ctk q4_0 -ctv q4_0 \
  -c 200000 -ngl 99 \
  --flash-attn on --mlock \
  -ub 32 -np 1
```
8.1 Getting an MTP-capable GGUF
Standard GGUF conversion strips MTP layers. You need to graft them:
```sh
# Download MTP head GGUF (457 MB, only blk.64.* tensors)
wget https://huggingface.co/havenoammo/Qwen3.6-27B-MTP-UD-GGUF/resolve/main/MTP-Q8_0.gguf

# Graft onto your base GGUF
uv venv .venv --seed && source .venv/bin/activate
uv pip install gguf
python convert.py base-model.gguf MTP-Q8_0.gguf output-mtp.gguf
```
Or use a pre-grafted GGUF like Qwen3.6-27B-Heretic-v2-MTP-Q4_K_M.
9. The Bug Journey: 15 Bugs Fixed
Getting fused quantized-KV flash attention working required fixing 15 bugs across 5 sessions. Some highlights:
| # | Bug | Fix |
|---|---|---|
| 3 | nvcc generates bad code from `if constexpr` dead branches in TBQ4 templates | Replaced with `#if !defined(TBQ4_KV_FUSED)` preprocessor guards |
| 4 | Q rotation caused register spill (FWHT butterfly needs 128 registers) | Moved to separate kernel — eliminates pressure on the main FA kernel |
| 7 | V dequant happened after K preload (reads garbage from shared staging buffer) | Reordered: V dequant before K[next] preload (shared buffer discipline) |
| 8 | `cp_async` requires 16-byte alignment; TBQ4 rows are 132 bytes | Replaced `cp_async` with `int` loads (raw byte copy) |
| 14 | Last iteration of nstages=2 loop missing `tbq4_staging` parameter | Added staging + type params to both ncols2 paths |
Full bug tracker in HANDOFF_TBQ4.md.
10. What's Next
- nstages=2 pipeline — the staging approach should work with correct synchronization. If we can overlap raw byte loading with compute, the fused kernel could reach Q4_0-equivalent speeds (90+ tok/s) while maintaining near-lossless quality.
- Upstream contribution — the `link_shared_tensors()` API is backwards-compatible and solves a real 682 MiB duplication problem for any model with tied embeddings + sibling models. Candidate for a PR.
- TBQ3_0 fused path — the same technique at 3.0625 bpv would push context even further (250K+ at full ngl).
- InnerQ (quantized Q) — quantizing Q in the FA kernel could reduce register pressure and improve occupancy.
For llama.cpp Maintainers
If you're reviewing this for potential upstream inclusion:
- Tensor sharing (`link_shared_tensors`) — minimal, backwards-compatible API. Solves a real problem for any architecture with tied embeddings + MTP/speculative heads.
- Fused TBQ4 FA — demonstrates that a quantized KV cache can be dequanted inside the attention kernel via rotated-domain computation. The technique generalizes to any FWHT-based quant type.
- 15 bugs documented — full technical handoff in `HANDOFF_TBQ4.md` covers every pitfall (nvcc codegen, cp_async alignment, staging ordering).
11. Credits
- havenoammo — MTP graft tooling, first Qwen3.6-27B-MTP GGUF release
- spiritbuun — dflash fork with CUDA TurboQuant kernels (our FWHT kernels adapted from this)
- ggml-org/llama.cpp — PR #22673 (MTP), PR #21089 (CPU TBQ)
- HauhauCS — Uncensored Qwen3.6 with K_P quantization
- Radamanthys11 — MTP-Q8_0 GGUF extraction
- froggeric — Fixed chat templates for Qwen3.6 + MTP
Built on RTX 4090 24GB, Ubuntu 24.04, CUDA 12.x. All benchmarks with Qwen3.6-27B-Heretic-v2-MTP-Q4_K_M. Draft acceptance measured by llama.cpp's internal MTP statistics. Source code: github.com/Indras-Mirror/llama.cpp-mtp
