11L Int4 MLP QAT + BigramHash(10240) + SWA #314
Open
aravhawk wants to merge 1 commit into openai:main from
Conversation
Adds an 11th transformer layer funded by int4 MLP quantization savings, with STE quantization-aware training. Built on @thwu1's SOTA.
Summary
An 11-layer transformer using int4 quantization-aware training (QAT) for MLP weights, building on @thwu1's SOTA (10L int5 MLP, 1.14276 bpb). Switching MLP weights from int5 to int4 with STE fake quantization saves ~2MB of compressed artifact space, funding an 11th layer within the 16MB budget.
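The ~2MB figure is consistent with simple bit-count arithmetic: dropping from int5 to int4 saves one bit per MLP weight. The PR does not state the exact MLP parameter count, so the figure below is a hypothetical illustration, not the real number.

```python
# Hypothetical MLP parameter count for illustration only; the PR does not
# state the exact figure. Going from int5 to int4 saves 1 bit per weight.
mlp_params = 16_000_000                 # assumed ~16M MLP weights
saved_bits = mlp_params * (5 - 4)       # 1 bit saved per weight
saved_mb = saved_bits / 8 / 1_000_000   # bits -> bytes -> MB
print(f"{saved_mb:.1f} MB saved")       # raw savings, before zstd compression
```

Under that assumption the raw savings is 2.0 MB; the actual compressed savings depends on how well zstd 22 compresses the int4 versus int5 bit patterns.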
Core changes from SOTA (7 targeted modifications):
- `num_layers`: 10 to 11
- `CastedLinear.forward()` changed for MLP layers only
- `blocks.9` instead of `blocks.8`
- `warmdown_iters`: 3000 to 2500 (fewer total steps with deeper model)
- new `fake_quantize_per_row` function for STE QAT

All SOTA innovations preserved: SmearGate, BigramHash(10240), orthogonal init, U-Net skip connections, SWA (start_frac=0.4), sliding window eval (stride=64), zstd 22 compression.
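The PR's `fake_quantize_per_row` implementation is not shown in this excerpt; a minimal sketch of symmetric per-row STE fake quantization, under the assumption of an absmax per-row scale, could look like:

```python
import torch

def fake_quantize_per_row(w: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Per-row symmetric fake quantization with a straight-through estimator.

    Forward: each row of `w` is rounded to a signed `bits`-bit grid using a
    per-row absmax scale, then dequantized back to float.
    Backward: the round() is detached, so gradients pass through unchanged
    (the STE), letting the optimizer update the full-precision weights.
    """
    qmax = 2 ** (bits - 1) - 1                        # 7 for int4
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = (w / scale).round().clamp(-qmax - 1, qmax)    # snap to the int grid
    w_q = q * scale                                   # dequantize
    # STE: forward value is w_q, but d(out)/d(w) is the identity
    return w + (w_q - w).detach()
```

Because the round/clamp live only inside a `.detach()`, the graph torch.compile sees is a handful of elementwise ops, which is consistent with the `fullgraph=True` compatibility claimed in the notes below.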
Architecture
Expected outcome
Conservative: 1.135 to 1.139 bpb (vs SOTA 1.1428). The extra layer is expected to improve bpb by ~0.004 to 0.008, while QAT minimizes int4 degradation. Pending 3-seed eval on 8xH100.
Notes
QAT pattern validated against an existing submission in PR #162 (MLP3x QAT Int6 SlidingWindow), confirming `torch.compile(fullgraph=True)` compatibility. Script is 1252 lines (under the 1500-line limit), syntax verified.