Record: Int6 + MLP 3x + STE QAT + NorMuon + sliding window (val_bpb 1.1666) #137
Open
abhishekgahlot2 wants to merge 2 commits into openai:main
Int6 mixed quantization with STE fake-int6 QAT, 3x MLP expansion, NorMuon optimizer, SWA checkpoint averaging, and sliding window eval.
what changed
MLP 3x expansion (hidden=1536): 21.8M params. Extra capacity paid for by int6 quantization.
STE fake-int6 QAT: weights fake-quantized to int6 via straight-through estimator throughout training. Reduces quantization penalty from ~0.008 to ~0.001 BPB.
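A minimal sketch of the straight-through fake-quant step. Symmetric per-row scaling with integer levels in [-31, 31] is an assumption; the PR's exact scheme may differ. The forward pass sees quantized weights while the backward pass treats the op as identity:

```python
import torch

def ste_fake_quant_int6(w: torch.Tensor) -> torch.Tensor:
    """Fake-quantize to int6 per row; gradients pass straight through (STE)."""
    # one scale per row, mapping the row's max-abs value to level 31 (assumed)
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 31.0
    q = (w / scale).round().clamp(-31, 31) * scale
    # forward uses q; backward sees d(output)/d(w) = identity
    return w + (q - w).detach()
```

In training, this wraps each quantized weight matrix before the matmul, so the model learns weights that survive int6 rounding.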
NorMuon optimizer: per-neuron row-wise RMS normalization after Newton-Schulz orthogonalization.
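Roughly, the direction computation looks like the following sketch. The quintic Newton-Schulz coefficients are the published Muon ones; presenting this as a standalone function omits the optimizer's momentum, learning-rate, and weight-decay machinery:

```python
import torch

def newton_schulz_orth(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximate orthogonalization via the quintic Newton-Schulz iteration
    (coefficients from Muon; assumes rows <= cols)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X

def normuon_direction(G: torch.Tensor) -> torch.Tensor:
    """Orthogonalize, then normalize each output neuron's row to unit RMS."""
    O = newton_schulz_orth(G)
    rms = O.pow(2).mean(dim=1, keepdim=True).sqrt().clamp(min=1e-8)
    return O / rms
```

The row-wise step equalizes per-neuron update magnitudes, which plain Newton-Schulz alone does not guarantee.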
SWA checkpoint averaging: collects checkpoints every 200 steps during warmdown and averages them.
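A running-mean sketch of the checkpoint averaging, with numpy arrays standing in for the real parameter tensors (the every-200-steps scheduling is left to the training loop):

```python
import numpy as np

class SWAAverager:
    """Incremental mean over checkpoints collected during warmdown (sketch)."""
    def __init__(self):
        self.avg = None
        self.n = 0

    def update(self, state: dict) -> None:
        self.n += 1
        if self.avg is None:
            # accumulate in float64 to limit drift over many checkpoints
            self.avg = {k: v.astype(np.float64).copy() for k, v in state.items()}
        else:
            for k, v in state.items():
                self.avg[k] += (v - self.avg[k]) / self.n

    def result(self) -> dict:
        return {k: v.astype(np.float32) for k, v in self.avg.items()}
```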
Mixed quantization: int6 per-row on MLP and attention weights, fp16 passthrough for tied embedding, zstd-22 compression.
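The int6 per-row storage format might look like this sketch: symmetric codes with one fp16 scale per row is an assumption, and the zstd-22 compression and fp16 embedding passthrough are omitted:

```python
import numpy as np

def quant_int6_rowwise(w: np.ndarray):
    """Per-row symmetric int6: codes in [-31, 31] plus one fp16 scale per row."""
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True) / 31.0, 1e-8)
    scale = scale.astype(np.float16)
    codes = np.clip(np.round(w / scale.astype(np.float32)), -31, 31).astype(np.int8)
    return codes, scale

def dequant_int6_rowwise(codes: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return codes.astype(np.float32) * scale.astype(np.float32)
```

Codes here occupy one int8 each for clarity; a real serializer would bit-pack 6-bit codes before compression.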
Sliding window eval (stride=64): each token scored with nearly full context.
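One common way to implement stride-64 scoring is to advance the window by `stride` tokens and score only the tokens not yet evaluated, so every token after the first window sees close to `seq_len` tokens of left context. A sketch of the span bookkeeping (assumed, not necessarily the PR's exact code):

```python
def sliding_window_spans(n_tokens: int, seq_len: int = 2048, stride: int = 64):
    """Return (ctx_start, score_start, score_end) triples: the model reads
    tokens [ctx_start, score_end) and is scored on [score_start, score_end)."""
    spans = []
    prev_end = 0
    for begin in range(0, n_tokens, stride):
        end = min(begin + seq_len, n_tokens)
        spans.append((begin, prev_end, end))
        prev_end = end
        if end == n_tokens:
            break
    return spans
```

Each token is scored exactly once, at the cost of roughly `seq_len / stride` forward passes per window-length of text, which is why the eval takes minutes rather than seconds.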
seq_len=2048, batch=786K, grad_clip=0.3, matrix_lr=0.02, Muon momentum=0.99, Muon WD=0.01, warmdown=3000 iters, logit softcap=15.
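The hyperparameters above, collected into an illustrative config (field names are assumptions; "786K" is assumed to mean 768 * 1024 = 786,432 tokens):

```python
from dataclasses import dataclass

@dataclass
class TrainConfig:
    # values from the PR description; field names are illustrative
    seq_len: int = 2048
    batch_tokens: int = 786_432      # "786K", assumed 768 * 1024
    grad_clip: float = 0.3
    matrix_lr: float = 0.02
    muon_momentum: float = 0.99
    muon_weight_decay: float = 0.01
    warmdown_iters: int = 3000
    logit_softcap: float = 15.0
```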
results
8xH100 80GB HBM3 (Modal, 10 min wallclock, seed 1337):
6,065 steps at 98.9ms/step. Quant loss: 0.001 BPB. Sliding window eval: 156s.
test plan
final_mixed_roundtrip_exact val_bpb: 1.18774689
final_sliding_window_exact val_bpb: 1.16658140