Record: Fused MLP (Triton+CUTLASS EVT) + Fast Causal N-Gram Tilt & Subword Certainty (3-seed mean) #1105
Conversation
@abaybektursun - this is a fantastic write-up! Congrats on the SLOT improvement. If you need to free up even more room, you should check out the shrink.py script I used in PR 1089. I was able to shrink the train_gpt.py file by ~100KB. That might let you reduce pruning and/or promote one more group to int6.
Ohhh I think with newer PyTorch, performance and speed will be even better! I will try it when I can get my hands on 8×H100s
# Record: Fused MLP (Triton+CUTLASS EVT) + MLP 3.5× + Mixed int5/int6 + Brotli (3-seed mean)

Based on PR openai#1105 (abaybektursun) with improvements:
- Window attention (size=512) on layers 2,4,6,8,10 via FA3
- Mixed seq_len training: 5 GPUs at 2048x36 + 3 GPUs at 6144x10
- Train-data GPTQ calibration (14s vs 220s AR self-gen)
- Auto eval_seq_len detection from max train seq_len
- Causal n-gram fix (within_hint/word_hint prefix-only)
- Sliding window eval at seq_len=6144, stride=128

3-seed results (sliding window bpb):
seed 1337: 1.1077
seed 42: 1.1083
seed 7: 1.1091
mean: 1.1084 (vs leader 1.1147)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@abaybektursun - This is a great submission! If you haven't experimented with it yet, I'd recommend trying different negative slopes for the squared leaky ReLU activation (leaky_relu(x, slope).square()). I found ~0.3 was optimal on average with a similar architecture, but there's a depth-dependent pattern worth exploiting.

I added a learned per-channel alpha parameter (initialized at 0.3) to each MLP layer. After three iterative runs, each warm-started from the prior run's endpoints, the values converged: layer 0 settled near 0 (essentially ReLU²), middle layers stayed around 0.3, and the deepest layers preferred 0.4–0.47. One gotcha: since the activation squares the output, alpha and -alpha produce identical results; only the magnitude matters. Once converged, I hardcoded the per-layer slopes so I wasn't wasting parameters on values that had already stabilized.

Also, I wanted to share these results/learnings in case you are thinking about tokenizer experiments:
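A tiny pure-Python sketch of the idea (the per-layer slope table below is hypothetical, chosen only to illustrate the depth-dependent pattern described):

```python
def sq_lrelu(x, alpha):
    """Squared leaky ReLU with a per-layer negative slope alpha:
    lrelu(x, alpha)**2. Because the output is squared, alpha and -alpha
    give identical results; only |alpha| matters."""
    y = x if x > 0 else alpha * x
    return y * y

# Hypothetical hardcoded per-layer slopes, following the converged pattern
# described above: layer 0 near 0 (ReLU^2), middle ~0.3, deepest 0.4-0.47.
PER_LAYER_ALPHA = [0.0] + [0.3] * 8 + [0.45] * 3  # illustrative 12-layer split
```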
@abaybektursun - since you have almost 500KB of headroom on the artifact you might be able to increase MLP to 3.625x, which should help given your finding that the model was parameter starved in MLP. You can also get another ~100KB of headroom using this script to shrink your train_gpt.py before submission:
…b 1.0970 (seed 42)

SP4608 tokenizer, all-int6 (66 layers), QK_GAIN=5.0, fast causal n-gram tilt. Eval: 87.6s wall on 8×H100. 2 more seeds pending.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
I am sleep deprived so things are a bit messy right now, but the data is accurate; I will clean up in the morning.

Results: val_bpb 1.0962 (3-seed mean) | 8×H100 SXM | 600s
sp4608 tokenizer, N_INT6=66 (all int6), QK_GAIN=5.0
Post-quant BPB and Submission BPB come from two separate evaluations of the same model, not a two-pass scheme: the sliding-window eval measures the quantized neural model alone, while the n-gram eval adds the causal tilt on top. Each is a single-pass, single-run evaluation; we report both to isolate the n-gram contribution (ablation).
N-gram eval: 87.6s wall (50.1s loop). Tilted: 25.6% | Hits: 53.5%.
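A minimal sketch of the sliding-window scheme above. The scoring interface `nll_fn` and the choice to score only the trailing `stride` positions of each window are assumptions about the setup, not the submission's actual eval code:

```python
import math

def sliding_window_bpb(nll_fn, tokens, bytes_per_token, seq_len=6144, stride=128):
    """Sliding-window eval sketch: each window of length seq_len advances by
    `stride`, and only the final `stride` positions are scored, so every
    scored token sees close to the full context.

    nll_fn(context, targets) -> list of per-token NLLs in nats (hypothetical
    interface); bytes_per_token is the corpus-average byte length of a token.
    """
    total_nats, total_tokens = 0.0, 0
    start = 0
    while start + seq_len <= len(tokens):
        window = tokens[start:start + seq_len]
        # score only the trailing `stride` targets of this window
        nlls = nll_fn(window[:-stride], window[-stride:])
        total_nats += sum(nlls)
        total_tokens += stride
        start += stride
    return total_nats / (total_tokens * bytes_per_token * math.log(2))
```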
What does the improvement look like? Side-by-side generation (temp=0.8)
Prompt (50 tokens): "Insurance Company Declares Living Man Dead George Johannesen is very much alive. Which is why it was so surpr"
The old model drifts into incoherence ("Rachel Drobles... techniques of the car industry... Lyon Man is dead"). The new model stays on topic — insurance, health measurement, living man — and maintains grammatical coherence throughout. Both are wrong (the real text is about a cancelled driver's license), but the new model's errors are at least topically plausible.
Changes vs our PR 1019
1. Fused MLP Kernels: Triton TMA Forward + CUTLASS EVT Backward
Forward (Triton TMA): Fuses `F.linear(x, up_w) → LeakyReLU(0.5) → square` into a single kernel. The 302MB intermediate never touches HBM.

Backward (CUTLASS EVT): Fuses `(go @ down_w.T) * act_grad` into a single CUTLASS 3.x kernel via Epilogue Visitor Tree. The elementwise multiply runs in the GEMM epilogue while tiles are still in registers, eliminating one 302MB write + read per layer.

Key design insight (pre-computed activation gradient): We store the activation gradient in the forward pass instead of the pre-activation. The identity `post = 0.5 · act_grad · pre` holds for both signs: with slope s (1 for pre > 0, 0.5 otherwise), post = (s·pre)² and act_grad = 2·s²·pre, so 0.5 · act_grad · pre = s²·pre² = post. This eliminates all branching from the backward, reducing the CUTLASS EVT epilogue to a trivial 3-node tree: `Sm90EVT<multiplies, AccFetch, AuxLoad>`. No conditionals in the kernel. CUTLASS EVT is a hard dependency; there is no silent fallback. See Appendix A.3 and A.4 for detailed benchmarks.
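A quick numerical check of this identity (a standalone sketch in plain Python, not the kernel code):

```python
def act_forward(pre, slope=0.5):
    """Squared leaky ReLU: post = lrelu(pre, slope)**2."""
    y = pre if pre > 0 else slope * pre
    return y * y

def act_grad(pre, slope=0.5):
    """Stored activation gradient: d(post)/d(pre) = 2 * s^2 * pre,
    where s = 1 for pre > 0 and s = slope otherwise."""
    s = 1.0 if pre > 0 else slope
    return 2.0 * s * s * pre

# post == 0.5 * act_grad * pre holds for both signs, so storing act_grad
# in the forward pass loses no information and the backward needs no branch.
for pre in (-2.0, -0.5, 0.3, 4.0):
    assert abs(act_forward(pre) - 0.5 * act_grad(pre) * pre) < 1e-12
```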
2. Fast Causal N-Gram Tilt & Subword Certainty (~0.002 BPB, 4–8× faster than competition n-gram approaches)
Architecture Shift: Sparse Auxiliary Memory
This PR replaces the old eval-time n-gram mixing path with a fast, legal, single-pass causal n-gram tilt system. The core change is that the n-gram is no longer treated as a second language model. Instead, it acts as a sparse auxiliary memory that proposes a hinted token from the strict prefix, while the neural model remains the full normalized distribution. We then apply a one-token exponential tilt directly on the GPU.
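A minimal sketch of such a one-token exponential tilt (names and the tilt strength `tau` are illustrative, not the submission's actual parameters):

```python
import math

def tilt_one_token(logprobs, hint_id, tau):
    """One-token exponential tilt sketch: boost the hinted token's
    log-probability by tau, then renormalize. `logprobs` is the neural
    model's full normalized distribution in log space."""
    tilted = list(logprobs)
    tilted[hint_id] += tau
    # renormalize via log-sum-exp over the tilted scores
    m = max(tilted)
    z = m + math.log(sum(math.exp(lp - m) for lp in tilted))
    return [lp - z for lp in tilted]
```

Because only the hinted token's score moves, the n-gram never has to supply a full distribution; the neural model remains the sole normalized language model.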
Motivation & Interpretability
This work was guided by the interpretability results in our PR 1019 model autopsy and PR 1105 model autopsy. Those analyses showed that the model is not broadly weak at language modeling; it is specifically weak at exact copy/repetition. In particular, it has very limited induction capability, while much of the remaining loss is in categories like numbers, punctuation, and whitespace where generic short-order n-grams do not help much.
That changed the design target. Instead of building “better PPM everywhere,” we focused on the narrow places where n-grams are actually complementary:
The Key Insight: Mechanical Subword Certainty
Initially, within-word BPE completions seemed redundant, since the neural baseline already assigns high probability to these tokens. However, the most significant BPB drop (~0.002, characterized on the MLP 3.0× model) was unlocked by aggressively lowering the `within_threshold` from 0.80 to 0.25, allowing the expert to fire on 35.7% of positions.

Why it works: While the neural model knows subword patterns, it inherently hedges its bets by distributing probability mass across alternatives. The n-gram expert acts as a mechanical override, capturing the near-deterministic BPE completions to which the neural model refuses to assign probability 1.0.
Measured eval time (8×H100): ~120s (setup 54s + loop 65.5s, mean across 3 seeds), 4.0× faster than PR 1145 (independently developed ctypes n-gram approach, 481s — discovered after our implementation was complete). See Appendix A.5 and A.6 for full engineering details and benchmarks.
3. Brotli-11 Compression (replaces LZMA-9)
−581 KB (−5.9%) vs LZMA-9. Independently discovered; PR 1089 (mikeapedia) also uses Brotli.
4. Memmap Multi-Shard Data Pipeline + GPU Prefetch
Coprime-stride sampling, daemon thread, CUDA stream prefetch. Credit: DeepReinforce (PR 726).
5. MLP 3.5× (1536 → 1792 hidden dim)
Motivated by mechanistic analysis: SVD analysis of our PR 1019 model showed MLP at 94.4% rank utilization (fully packed) while attention Q sat at 72.6% (spare capacity). The model was parameter-starved in MLP, not attention — so we made MLP wider.
Increases hidden dim from 3.0 × 512 = 1536 to 3.5 × 512 = 1792. Model goes from 27.07M to 29.95M params (+2.88M). At uniform int6, the 29.95M model compresses to 17.36 MB — 1.36 MB over the 16 MB limit. This is what makes mixed quantization (change 6) necessary.
Impact: −0.003 BPB from capacity, +13ms/step on 2×H100 (bigger GEMMs). Credit: PR 185 (dttdrv), PR 344 (aryanbhosale).
6. Mixed int5/int6 Quantization (Hessian-based)
Motivated by mechanistic analysis: Per-matrix quantization sensitivity showed MLP accounts for 80% of int6 quantization damage (MLP_down: +0.0039 BPB total, all Q matrices: +0.0003 BPB total — a 13× gap). Giving more bits to MLP is the optimal allocation.
Instead of uniform int6 for all layers, use int5 as default and promote the top 10 most sensitive layers to int6 based on Hessian trace ranking. Sensitivity = trace(H) where H = X^TX collected during GPTQ calibration. MLP projection layers in early blocks are most sensitive — they get int6; the remaining 56 layers get int5.
Uniform int5 loses ~0.019 BPB (catastrophic). Targeted Hessian-based allocation keeps quality loss under ~0.003 BPB while saving ~1.5 MB — exactly the headroom MLP 3.5× needs to fit under 16 MB. The wider MLP also made the model 3.6× less sensitive to quantization overall — information distributed across more dimensions means no single weight is load-bearing.
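The allocation rule described above can be sketched as follows (layer names and trace values are illustrative; the real sensitivities come from H = XᵀX collected during GPTQ calibration):

```python
def allocate_bits(hessian_traces, n_int6=10):
    """Hessian-trace bit allocation sketch: rank layers by sensitivity
    trace(H), promote the top n_int6 to int6, leave the rest at int5.

    hessian_traces: dict mapping layer name -> trace(X^T X).
    Returns: dict mapping layer name -> bit width (5 or 6)."""
    ranked = sorted(hessian_traces, key=hessian_traces.get, reverse=True)
    promoted = set(ranked[:n_int6])
    return {name: (6 if name in promoted else 5) for name in hessian_traces}
```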
Credit: mixed quant concept PR 76 (Will DePue), gradient-guided PR 332 (saml212), Hessian-based PR 1089 (mikeapedia).
7. LR Floor (0.05)
During warmdown, the learning rate normally decays to 0. With `lr_floor=0.05`, it stops at 5% of peak instead. This prevents the optimizer from stalling, which helps quantization-sensitive weight distributions that are still being refined at the end of training.

Impact: ~0.001 BPB. Credit: PR 130 (mohosy).
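A sketch of a warmdown schedule with a floor. The linear-decay shape and parameter names are assumptions; only the clamp at 5% of peak is from the text:

```python
def lr_with_floor(step, total_steps, warmdown_start, peak_lr, lr_floor=0.05):
    """Warmdown-with-floor sketch: hold peak_lr until warmdown_start, then
    decay linearly toward zero, but clamp at lr_floor * peak_lr."""
    if step < warmdown_start:
        return peak_lr
    frac = (total_steps - step) / (total_steps - warmdown_start)
    return max(frac, lr_floor) * peak_lr
```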
8. Vocab 4608
Inspired by PR 1218 (Clark), which established 4096 as effective. We measured β(V) — bytes per token — across intermediate vocab sizes.
+2.2% bytes per token, +0.01 nats per-token loss, net −0.020 BPB. Artifact cost +0.12 MB (+0.8%), absorbed by removing BigramHash and SmearGate (redundant at this vocab size).
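The net effect follows directly from the bpb definition. A back-of-envelope check, assuming roughly 4.5 bytes/token and ~1.13 bpb for the baseline (illustrative guesses, not measured values):

```python
import math

def bpb(nats_per_token, bytes_per_token):
    """Bits per byte: per-token loss in nats / (bytes per token * ln 2)."""
    return nats_per_token / (bytes_per_token * math.log(2))

# Assumed numbers (NOT measured): ~4.5 bytes/token, ~1.13 bpb baseline.
# A +2.2% bytes/token gain outweighs a +0.01 nats/token loss increase,
# giving a net bpb reduction of roughly the magnitude reported above.
b0 = 4.5
n0 = 1.13 * b0 * math.log(2)                      # nats/token implied by ~1.13 bpb
delta = bpb(n0 + 0.01, b0 * 1.022) - bpb(n0, b0)  # net bpb change (negative)
```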
Negative Results
Note: This model still has known inefficiencies — the sp4608 architecture has not been fully tuned (hyperparameters, layer count, MLP ratio, and quantization bit allocation were carried over from the sp1024 stack). We believe further BPB reductions are achievable.
Appendix
A.1 Prior Results
Prior results: sp1024, val_bpb 1.1052 (3-seed mean)
Mixed quantization: 10 layers int6, 56 layers int5, no pruning needed.
Prior results (val_bpb 1.1125, 3-seed)
SLOT study (removed from submission — causality violation)
SLOT (Selective Logit Offset Tuning) optimizes a 512-dim delta vector at the last hidden layer using AdamW (lr=0.003, 5 steps) per sliding-window batch. It gave −0.0037 BPB (1.1125 → 1.1088), but violates causality: the delta has shape `[1,1,512]` and is optimized using targets at all positions, then applied to all positions, so position t's prediction is influenced by future tokens through the shared delta. Removed from submission code; results below are for reference only.

Credit: PR 609 (saml212).
Prior results: fused kernels + Brotli only (val_bpb 1.1138, 3-seed)
Delta vs PR 549: −0.00943 nats. Welch's t = −10.26, df ≈ 3.78, p < 0.01.
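For reference, the Welch statistic and Welch–Satterthwaite degrees of freedom used in this comparison can be computed as follows (a generic sketch, not the submission's script):

```python
import math

def welch_t(a, b):
    """Welch's unequal-variance t statistic and Welch-Satterthwaite
    degrees of freedom for two small samples (e.g. per-seed losses)."""
    na, nb = len(a), len(b)
    ma, mb = sum(a) / na, sum(b) / nb
    va = sum((x - ma) ** 2 for x in a) / (na - 1)  # sample variances
    vb = sum((x - mb) ** 2 for x in b) / (nb - 1)
    se2 = va / na + vb / nb
    t = (ma - mb) / math.sqrt(se2)
    df = se2 ** 2 / ((va / na) ** 2 / (na - 1) + (vb / nb) ** 2 / (nb - 1))
    return t, df
```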
A.2 Throughput Recovery
Our PR 1019 (now merged as SOTA) traded throughput for quality — full Hessian GPTQ and BigramHash 3072×112 added 3.3ms/step. Fused MLP kernels recover that regression. Mechanistic analysis of that model identified MLP as the capacity bottleneck, leading to MLP 3.5× (enabled by mixed quantization + Brotli headroom).
A.3 Kernel Benchmarks
Kernel benchmarks + incremental deltas (2×H100)
Per-layer kernel timing:
CUTLASS vs Triton: +0.032 ms/layer, +0.347 ms/step kernel-level.
End-to-end training (35 steps, seed=42):
Kernel-level 0.347ms translates to 0.43ms end-to-end (cache/scheduling interactions).
8×H100: 86.7ms (our PR 1019, unfused) → 83.5ms (this PR) = −3.2ms/step (−3.7%).
A.4 Step-Time Profile
Step-time profile — where all 313ms goes (2×H100, Nsight)
Why surgical fusion, not a full-MLP `autograd.Function`: The 21.6% from torch.compile's cross-layer fusions (RMSNorm backward, residual adds, RoPE backward) only exists because these ops are visible to the compiler. Wrapping the full MLP backward in `autograd.Function` makes it opaque to Inductor: all backward GEMMs plus cross-layer fusion run in eager mode, 2.7× slower net (identified in our PR 670). We fuse only the forward and one backward GEMM+pointwise, preserving the compiler's scope.

Top individual kernels:
Wall-clock breakdown: forward+backward compute ~94%, NCCL ~1.6%, CPU overhead ~4.1%.
A.5 N-Gram Engineering Details
Engineering Overhaul
Previous attempts at n-gram blending using flat tables and Python/NumPy logic were bottlenecked by severe hash collisions and massive FFI overhead. Initial runs with a logistic mixer yielded a catastrophic +0.210 BPB degradation because collision noise was inflating token probabilities.
By migrating to an open-addressing scheme (64M entries, 26-bit) that stores exact keys, we eliminated false positives, pushing token PPM accuracy to 82.3%. To solve the execution bottleneck, we deployed a highly optimized pipeline:
C++ extensions called in batches via nanobind (`fused_expert_blend.cpp`, `ngram_blend.cpp`).

A.6 N-Gram Benchmarks
The ~295× speedup vs naive Python was the enabling constraint: a brute-force per-token PPM over 62M tokens would take hours; our C++ open-addressing hash with batched nanobind calls runs in ~29s (n-gram lookup only), well within the 600s eval budget.
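The exact-key open-addressing idea can be sketched in Python (the real table is a 64M-entry C++ structure; the size, probe scheme, and lack of resizing here are illustrative simplifications):

```python
class ExactKeyTable:
    """Open-addressing hash table sketch that stores the full key, so a
    lookup can never return a false positive (unlike a flat hashed-bucket
    table, whose collisions inflate token probabilities)."""

    def __init__(self, size=1 << 16):
        self.size = size
        self.keys = [None] * size
        self.vals = [0] * size

    def _probe(self, key):
        # linear probing until we find the key or an empty slot
        # (no resizing here; the sketch assumes the table is never full)
        i = hash(key) % self.size
        while self.keys[i] is not None and self.keys[i] != key:
            i = (i + 1) % self.size
        return i

    def put(self, key, val):
        i = self._probe(key)
        self.keys[i] = key
        self.vals[i] = val

    def get(self, key, default=None):
        i = self._probe(key)
        return self.vals[i] if self.keys[i] == key else default
```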
By trading brute-force multi-expert agreements for single-expert confidence scaling and targeted subword overrides, this architecture runs well within the 600s eval budget while matching the quality of approaches that take 4–8× longer.
A.7 Architecture
Calibration legality: AR self-generated (64 seqs × 2048 tokens, temp=0.8). No val data, no train data accessed during quantization. Same method as our PR 1019.
A.8 Setup & Reproduction