Record: GPTQ + Legal TTT (3-seed mean val_bpb=1.1195) #529
EthanYangTW wants to merge 2 commits into openai:main
Conversation
Improvements over previous submission (1.1218):

- GPTQ quantization: Hessian-aware error compensation, -0.0024 BPB
- Early QAT (threshold 0.5): 3x more QAT steps
- EMA 0.997 (tuned from 0.9985)

3-seed results:

- Seed 1337: 1.1189 (15.96 MB)
- Seed 42: 1.1197 (15.75 MB)
- Seed 7: 1.1198 (15.54 MB)
- Mean: 1.1195 (std 0.0005)
Pull request overview
Updates train_gpt.py to a record-grade training/eval pipeline featuring GPTQ-based int6 quantization and “legal” score-first test-time training (TTT), aligning the script with the PR’s reported val_bpb improvements.
Changes:
- Added GPTQ Hessian calibration + mixed int6/int8 quantization export with optional zstd compression.
- Expanded the model/training stack (e.g., XSA, smear gate, bigram hash embedding, VE, SWA/EMA, early/late QAT) and updated hyperparameters.
- Added sliding-window evaluation and score-first TTT evaluation paths.
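The "Hessian-aware error compensation" behind the GPTQ export can be illustrated with a toy error-feedback quantizer. This is a deliberately simplified sketch, not the PR's code: real GPTQ processes columns using the inverse Hessian of the layer inputs to weight the compensation, whereas here each element's rounding error is spread uniformly over the not-yet-quantized weights.

```python
import numpy as np

def quantize_with_error_feedback(w: np.ndarray, scale: float, clip: int = 31) -> np.ndarray:
    """Toy sketch of error compensation in the GPTQ spirit: quantize weights
    one at a time and fold each rounding error into the remaining weights.
    Real GPTQ weights this update by the inverse Hessian of the layer inputs;
    a uniform update is used here purely for illustration."""
    w = w.astype(np.float64).copy()
    q = np.empty_like(w)
    n = w.size
    for i in range(n):
        # Round the current weight onto the symmetric integer grid.
        q[i] = np.clip(np.round(w[i] / scale), -clip, clip) * scale
        err = w[i] - q[i]
        if i + 1 < n:
            # Spread the rounding error over the remaining weights
            # (uniform stand-in for the inverse-Hessian row in actual GPTQ).
            w[i + 1:] += err / (n - i - 1)
    return q
```

Because the error is carried forward rather than discarded, the quantized row tracks the original row's aggregate statistics far better than independent rounding.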
```python
else:
    self.ve_shared = None
    self.ve_layer_scales = nn.ParameterList()
    self.value_embeds = nn.ModuleList()
```
self.value_embeds is initialized but never used (no subsequent reads/writes). If it’s leftover from an earlier design, remove it; if it’s intended for future per-layer VE modules, it should be populated/used to avoid dead code.
Suggested change (remove the unused line):

```diff
-self.value_embeds = nn.ModuleList()
```
```diff
-Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines.
-"""
+"""V7: GPTQ quantization + Early QAT 0.5 + EMA 0.997. Fork of #503."""
```
The PR removes the introductory docstring that documents the intent/constraints of train_gpt.py, and the file now exceeds the previously stated 1500-line hard stop (file ends at line 1579). If this script is meant to remain a newcomer-friendly baseline, consider restoring the guidance and/or moving record-grade logic into /records (or splitting into modules) so train_gpt.py stays under the line limit.
Suggested change (replace the one-line V7 docstring):

```python
"""Baseline GPT training script.

This file is intended to remain a **newcomer-friendly baseline**: a single,
readable script that can be skimmed top-to-bottom to understand how the model,
data loading, training loop, and evaluation fit together.

Design / usage notes:

- Configuration is driven primarily by environment variables (see
  `Hyperparameters` below) so that this file can stay code-only and easily
  copy-pasted.
- The implementation favors clarity over chasing absolute record scores.
  If you are adding highly experimental, benchmark- or record-focused logic,
  prefer putting it into `/records` or separate modules and calling into it
  from here instead of inlining large new subsystems.
- As a soft guideline, keep this script to ~1500 lines by factoring out
  heavy utilities into dedicated modules. This helps keep the baseline
  approachable for new contributors.

Current variant: V7, with GPTQ quantization, early QAT (0.5), and EMA 0.997.
This is a fork/evolution of the earlier #503 baseline.
"""
```
```python
w32 = self.weight.float()
row_clip = torch.quantile(w32.abs(), 0.9995, dim=1)
scale = (row_clip / 31.0).clamp_min(1.0 / 31.0)
w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype)
```
The QAT fake-quant path clamps to [-32, 31], while the rest of the int6 pipeline (e.g., clip_range=31) implies a symmetric [-31, 31] range. This mismatch changes the quantization behavior during QAT vs export and can hurt calibration/roundtrip fidelity. Align the clamp range with the export quantizer (or document why the asymmetric range is intentional).
Suggested change:

```diff
-w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype)
+w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -31, 31) * scale[:, None]).to(x.dtype)
```
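To make the mismatch concrete, here is a minimal standalone sketch of the symmetric fake-quant the suggestion implies, with the clamp aligned to the export quantizer's `clip_range=31`. The function name and the standalone form are illustrative, not code from the PR.

```python
import torch

def fake_quant_int6_symmetric(w: torch.Tensor, clip_range: int = 31) -> torch.Tensor:
    """Per-row symmetric fake-quantization to [-clip_range, clip_range].

    Mirrors the QAT path above, but clamps symmetrically so the integer grid
    matches the int6 export quantizer (63 levels for clip_range=31)."""
    w32 = w.float()
    # Robust per-row scale from the 99.95th percentile of |w|, as in the QAT path.
    row_clip = torch.quantile(w32.abs(), 0.9995, dim=1)
    scale = (row_clip / clip_range).clamp_min(1.0 / clip_range)
    q = torch.clamp(torch.round(w32 / scale[:, None]), -clip_range, clip_range)
    return (q * scale[:, None]).to(w.dtype)
```

With the symmetric range, negating the input exactly negates the output, which is what the export-side dequantizer assumes.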
```python
if not t.is_floating_point() or t.numel() <= 65536:
    result[name] = t.to(torch.float16) if t.is_floating_point() else t
    meta[name] = "passthrough"
    continue
if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
    result[name] = t.float()
    meta[name] = "passthrough_ctrl"
    continue
```
In mixed_quantize_int6_gptq, the early t.numel() <= 65536 passthrough runs before the control-tensor check. This means small control/scaling tensors matching CONTROL_TENSOR_NAME_PATTERNS will be downcast to fp16, losing precision, despite the later intent to keep them fp32. Reorder the checks so control tensors are handled before the size-based passthrough (or explicitly exempt them from downcasting).
Suggested change:

```python
# Non-floating tensors are passed through unchanged.
if not t.is_floating_point():
    result[name] = t
    meta[name] = "passthrough"
    continue
# Control/scaling tensors are kept in float32, regardless of size.
if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
    result[name] = t.float()
    meta[name] = "passthrough_ctrl"
    continue
# Small floating-point tensors are downcast to float16 and passed through.
if t.numel() <= 65536:
    result[name] = t.to(torch.float16)
    meta[name] = "passthrough"
    continue
```
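The corrected ordering is easy to verify in isolation. Below is a small sketch of the passthrough decision factored into a pure function; `CONTROL_TENSOR_NAME_PATTERNS` here is a hypothetical stand-in for the script's own list, and the function exists only for illustration.

```python
import torch

# Hypothetical patterns for illustration; the real script defines its own list.
CONTROL_TENSOR_NAME_PATTERNS = ("scale", "gate")

def passthrough_dtype(name: str, t: torch.Tensor):
    """Return the dtype a tensor would be stored in under the passthrough
    rules, with the control-tensor check ordered before the size check.
    Returns None for large float tensors that proceed to the quantizer."""
    if not t.is_floating_point():
        return t.dtype  # non-float tensors pass through unchanged
    if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
        return torch.float32  # control/scaling tensors stay fp32, whatever their size
    if t.numel() <= 65536:
        return torch.float16  # small float tensors are downcast
    return None  # large float tensors continue to int6 quantization
```

With this ordering, a small tensor whose name matches a control pattern is kept in fp32 instead of being silently downcast.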
```python
if not t.is_floating_point() or t.numel() <= 65536:
    result[name] = t.to(torch.float16) if t.is_floating_point() else t
    meta[name] = "passthrough"
    continue
if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
    result[name] = t.float()
    meta[name] = "passthrough_ctrl"
    continue
```
Same issue as above in mixed_quantize_int6: the size-based passthrough happens before the control-tensor check, so small control/scaling tensors will be downcast to fp16. Reorder/exempt control tensors to preserve fp32 precision for these parameters.
Suggested change:

```python
# Always preserve floating-point control/scaling tensors in FP32,
# regardless of their size, to avoid precision loss.
if t.is_floating_point() and any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
    result[name] = t.float()
    meta[name] = "passthrough_ctrl"
    continue
if not t.is_floating_point() or t.numel() <= 65536:
    result[name] = t.to(torch.float16) if t.is_floating_point() else t
    meta[name] = "passthrough"
    continue
```
```python
self._cos_cached = freqs.cos()[None, None, :, :]
self._sin_cached = freqs.sin()[None, None, :, :]
rd = self.rope_dims
if seq_len > self.train_seq_len:
```
Rotary.forward computes rd / (rd - 2) when seq_len > train_seq_len. If rope_dims is set to 2 via env/config, this will divide by zero and crash on longer sequences. Add a guard/validation (e.g., require rope_dims > 2 when using the NTK-style base rescaling) or handle the rd==2 case explicitly.
Suggested change:

```python
if seq_len > self.train_seq_len:
    if rd <= 2:
        raise ValueError(
            f"rope_dims must be greater than 2 when using NTK-style RoPE rescaling "
            f"for seq_len ({seq_len}) > train_seq_len ({self.train_seq_len}); got rope_dims={rd}"
        )
```
```python
# Background on Muon: https://kellerjordan.github.io/posts/muon/
grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
eval_stride = int(os.environ.get("EVAL_STRIDE", 32))
muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95))
```
Hyperparameters.muon_beta2 is defined but never used anywhere in the script (no references beyond the env read). If Muon is meant to support a beta2 term, wire it through; otherwise consider removing this hyperparameter to avoid confusion/misconfiguration.
Suggested change (remove the unused hyperparameter):

```diff
-muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95))
```
Good attempt, but your submission format is invalid (no submission.json file, for instance), your stated eval wallclock time is too long (616s > 600s), and the GPTQ calibration seems to me like it's accessing training data at eval time, which is explicitly disallowed.
Key innovations over previous submission (1.1195, PR openai#529):

1. **Parallel Muon Optimizer** — Parameter banking with async reduce-scatter/all-gather overlapping Newton-Schulz orthogonalization. 3-phase training loop: (1) launch async RS for banks, (2) all-reduce + Adam step for replicated params (overlaps with RS), (3) wait RS, NS5, async AG. Eliminates the DDP wrapper entirely. From PR openai#1120 (Rascal/Cambrian).
2. **INT5 Quantization (clip_range=15)** — 31 unique integer levels instead of 63 (INT6). Combined with GPTQ Hessian-aware error compensation, achieves ~0.476 bytes/param compression ratio vs ~0.64 for INT6. Enables fitting a larger model (MHA 8/8, MLP 3.5x, BigramHash 6144, ~32M unique params) under the 16MB artifact limit.
3. **Coprime Stride Data Loader** — Deterministic permutation-free sampling using coprime strides over memory-mapped shards. Each shard is traversed via a stride coprime to its block count, guaranteeing full coverage without storing permutation arrays. Adaptive shard selection with power-law weighting (alpha decays 0.9→0.5 over training).
4. **Wallclock-Adaptive LR Schedule** — LR warmdown triggers based on elapsed wallclock time rather than step count. Automatically adapts to varying step times across hardware, ensuring consistent convergence regardless of system performance.
5. **MHA 8/8 + MLP 3.5x + BigramHash 6144** — Larger architecture than previous submissions (was GQA 8/4, MLP 3.0x, BigramHash 2048). Full multi-head attention, wider MLP, richer bigram hash embeddings. Only possible due to INT5 compression.
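The coprime-stride loader described above rests on a simple number-theoretic fact: stepping through n blocks with a stride coprime to n visits every block exactly once before repeating. A minimal sketch (names are illustrative; the real loader's shard handling is more involved):

```python
from math import gcd

def coprime_stride_order(n_blocks: int, seed: int = 1337):
    """Yield every block index of a shard exactly once by stepping with a
    stride coprime to the block count.

    Because gcd(stride, n_blocks) == 1, the walk i -> (i + stride) % n_blocks
    cycles through all residues mod n_blocks, so full coverage is guaranteed
    without storing a permutation array."""
    # Derive a deterministic stride from the seed, bumping until coprime.
    stride = (seed % n_blocks) or 1
    while gcd(stride, n_blocks) != 1:
        stride += 1
    idx = 0
    for _ in range(n_blocks):
        yield idx
        idx = (idx + stride) % n_blocks
```

The memory saving versus a stored permutation is the point: only the seed and the stride need to be kept per shard, and the traversal stays fully deterministic across workers.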
Architecture: 11L, dim=512, MHA 8/8, MLP 3.5x (1792), LeakyReLU²(0.5), XSA all 11 layers, partial RoPE 16/64, LN scale 1/√(L+1), SmearGate, OrthoInit, BigramHash 6144, Shared VE128 (layers 9,10), U-Net skip connections, EMA 0.997, Tight SWA (every 50), Late QAT (threshold 0.15), Muon lr=0.025 WD=0.04 (momentum warmup 0.92→0.99 over 1500 steps)

Training: 94ms/step → ~6333 steps in 600s wallclock on 8×H100 SXM

Quantization: INT5 GPTQ (clip_range=15, block_size=64, 256-sample calibration) + 2% magnitude pruning + zstd-22 compression

Eval: Sliding window (stride=64) + Legal score-first AdamW TTT (5 epochs, lr=0.0001, last 2 blocks + norms + head unfrozen, 262144-token chunks)

3-seed results:

- Seed 1337: 1.1144 BPB (16.12 MB artifact)
- Seed 42: 1.1141 BPB (15.12 MB artifact)
- Seed 7: 1.1150 BPB (15.26 MB artifact)
- Mean: 1.1145 BPB (std 0.0005)
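The sliding-window evaluation arithmetic (each window scores only the tokens not already scored by the previous window, with earlier positions serving as context) can be sketched as follows. This illustrates the general technique, not the PR's exact implementation.

```python
def sliding_windows(n_tokens: int, window: int, stride: int):
    """Enumerate (start, score_from) pairs for sliding-window evaluation.

    Each window covers tokens [start, start + window); only positions from
    score_from onward contribute to the loss. The first window scores all of
    its tokens; later windows score only their final `stride` tokens, so every
    token is scored exactly once with maximal left context."""
    spans = []
    start = 0
    score_from = 0  # first window scores everything it covers
    while start + window < n_tokens:
        spans.append((start, score_from))
        start += stride
        score_from = start + window - stride  # previous window's end
    spans.append((start, score_from))  # final (possibly short) window
    return spans
```

Summing per-token losses over the scored spans and dividing by total bytes then yields the reported val_bpb.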
11L XSA11 + GPTQ + Early QAT + EMA 0.997 + Legal Score-First AdamW TTT
val_bpb (3-seed mean): 1.1195 (std: 0.0005)
Improvements over #503 (1.1218)
Architecture (unchanged from #374 base)
Legal Score-First TTT
Quantization Pipeline
Timing
Compute
8xH100 SXM, ~20 min/seed. Three seeds for verification.