int5 GPTQ + 33.6M model: 1.1179 BPB (3-seed mean) #544
EthanYangTW wants to merge 2 commits into openai:main
Conversation
33.6M params (MHA 8/8, BigramHash 8192, MLP 3.5x), quantized to int5 with GPTQ error compensation. The artifact fits under 16MB (15.3–15.5MB).

- Seed 1337: de843ef6 (TTT 1.1170)
- Seed 42: b6560b60 (TTT 1.1182)
- Seed 7: c1c18644 (TTT 1.1184)
Pull request overview
Updates train_gpt.py to implement and export a new 33.6M-parameter GPT variant targeting int5-style (clip_range=15) GPTQ-assisted quantization under the 16MB artifact limit, with added sliding-window evaluation and score-first TTT evaluation.
Changes:
- Expands the model architecture (11 layers, BigramHash embedding, optional XSA, optional value embeddings, RoPE tweaks, smear gating, optional DTG/ln scaling).
- Adds new evaluation modes (separate eval sequence length, sliding-window BPB eval, and “legal score-first” TTT eval).
- Replaces the prior int8 export path with a mixed int8/int5-like (named “int6” in code) quantization pipeline including GPTQ calibration, optional pruning, and zstd/zlib compression.
```python
Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines.
"""

"""V23: int5 GPTQ + 33.6M model (MHA 8/8, BigramHash 8192, MLP 3.5x)."""
```
train_gpt_mlx.py’s module docstring states a hard stop that both train_gpt.py and train_gpt_mlx.py should not exceed 1500 lines for newcomer readability, but train_gpt.py is now 1580 lines. Please move record-specific / experimental code (e.g., GPTQ/TTT helpers) into /records or otherwise reduce the core script length to stay within that stated limit.
```python
rd = self.rope_dims
if seq_len > self.train_seq_len:
    scale = seq_len / self.train_seq_len
    new_base = self.base * (scale ** (rd / (rd - 2)))
    inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
else:
    inv_freq = self.inv_freq.to(device)
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
```
Rotary.forward() computes new_base = self.base * (scale ** (rd / (rd - 2))), which will divide by zero when rope_dims is 2 (and behaves poorly when rope_dims <= 2). Add validation that rope_dims is either 0 (meaning full head_dim) or an even value >= 4 to avoid runtime errors.
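A minimal sketch of the validation this asks for, assuming a standalone helper (the name `validate_rope_dims` is hypothetical) called from `Rotary.__init__`:

```python
def validate_rope_dims(rope_dims: int) -> int:
    """Allow 0 (meaning full head_dim) or an even value >= 4.

    rope_dims == 2 makes the exponent rd / (rd - 2) divide by zero in
    Rotary.forward(); odd values break the arange(0, rd, 2) frequency table.
    """
    if rope_dims != 0 and (rope_dims < 4 or rope_dims % 2 != 0):
        raise ValueError(f"rope_dims must be 0 or an even value >= 4, got {rope_dims}")
    return rope_dims
```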
```python
if rope_dims > 0:
    head_dim = model_dim // num_heads
    for block in self.blocks:
        block.attn.rope_dims = rope_dims
        block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims)
self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else []
```
When rope_dims > 0, it’s applied without checking it’s even and <= head_dim. If rope_dims is odd or larger than head_dim, Rotary will build cos/sin tables with incompatible last-dimension sizes and apply_rotary_emb() can error at runtime. Consider validating 0 < rope_dims <= head_dim and rope_dims % 2 == 0 up front (e.g., in GPT.__init__).
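A sketch of the up-front check suggested here (the helper name `check_rope_dims` is hypothetical; in practice this would run in `GPT.__init__` before any `Rotary` is built):

```python
def check_rope_dims(rope_dims: int, model_dim: int, num_heads: int) -> None:
    """Reject rope_dims values that would produce incompatible cos/sin tables."""
    head_dim = model_dim // num_heads
    if not (0 < rope_dims <= head_dim) or rope_dims % 2 != 0:
        raise ValueError(f"rope_dims={rope_dims} must be even and in (0, {head_dim}]")
```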
```python
self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
self.rotary = Rotary(self.head_dim, base=rope_base)

def forward(self, x: Tensor) -> Tensor:

self.rope_dims = 0
self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024)
self.use_xsa = False
```
CausalSelfAttention hard-codes train_seq_len=1024 when constructing Rotary, but the default Hyperparameters.train_seq_len is now 2048. This means RoPE will always take the “seq_len > train_seq_len” scaling branch during training/eval at 2048, which is easy to do unintentionally. If the scaling should be tied to the actual training context length, thread args.train_seq_len (or a dedicated hyperparameter) into Rotary construction.
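A dependency-free sketch of the threading this suggests. The `Rotary` stub and the `build_attention_rotary` helper are hypothetical stand-ins; the point is only that the construction site receives the real training context length instead of a hard-coded 1024:

```python
class Rotary:
    """Stub that only records the extrapolation threshold (the real module also builds cos/sin tables)."""
    def __init__(self, head_dim: int, base: float = 10000.0, train_seq_len: int = 1024):
        self.head_dim = head_dim
        self.base = base
        self.train_seq_len = train_seq_len

def build_attention_rotary(head_dim: int, rope_base: float, train_seq_len: int) -> Rotary:
    # Pass args.train_seq_len through instead of hard-coding 1024, so that
    # training/eval at seq_len == train_seq_len never enters the
    # "seq_len > train_seq_len" NTK scaling branch by accident.
    return Rotary(head_dim, base=rope_base, train_seq_len=train_seq_len)
```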
```python
def bigram_hash(self, tokens: Tensor) -> Tensor:
    t = tokens.to(torch.int32)
    mod = self.bigram_vocab_size - 1
    out = torch.empty_like(t)
    out[..., 0] = mod
    out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
    return out.long()
```
BigramHashEmbedding.bigram_hash() sets mod = bigram_vocab_size - 1 and then does % mod, which will raise a division-by-zero error if BIGRAM_VOCAB_SIZE is 1 (and produces negative indices if it’s 0 but still instantiated). Add input validation in BigramHashEmbedding.__init__ (or where it’s constructed) to require bigram_vocab_size >= 2 when enabled.
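A minimal sketch of the requested constructor check (the helper name `validate_bigram_vocab_size` is hypothetical):

```python
def validate_bigram_vocab_size(bigram_vocab_size: int) -> int:
    """Require >= 2 so mod = bigram_vocab_size - 1 is a valid modulus.

    bigram_hash() computes `% (bigram_vocab_size - 1)`, which raises a
    division-by-zero error for size 1 and misbehaves for size 0.
    """
    if bigram_vocab_size < 2:
        raise ValueError(f"bigram_vocab_size must be >= 2 when enabled, got {bigram_vocab_size}")
    return bigram_vocab_size
```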
```python
if CastedLinear._qat_enabled and self.training and w.ndim == 2:
    with torch.no_grad():
        w32 = self.weight.float()
        row_clip = torch.quantile(w32.abs(), 0.9995, dim=1)
```
The QAT path in CastedLinear.forward() computes torch.quantile(..., dim=1) for every 2D weight on every forward pass when enabled. This is extremely expensive (quantile involves sorting / heavy reductions) and is likely to blow the 10-minute training budget once late-QAT turns on. Consider precomputing per-row clip values periodically, using a cheaper statistic (e.g., amax / EMA), or restricting QAT to a small subset of layers.
Suggested change:

```diff
-row_clip = torch.quantile(w32.abs(), 0.9995, dim=1)
+# Use a cheaper per-row amax-based EMA for clipping instead of per-forward quantile.
+if not hasattr(self, "_qat_row_clip_ema"):
+    # Initialize EMA with current per-row maximum absolute values.
+    self._qat_row_clip_ema = w32.abs().amax(dim=1)
+else:
+    current = w32.abs().amax(dim=1)
+    ema_decay = 0.9
+    self._qat_row_clip_ema.mul_(ema_decay).add_(current, alpha=1.0 - ema_decay)
+row_clip = self._qat_row_clip_ema
```
```python
with torch.no_grad():
    for name, t in base_model.state_dict().items():
        ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay)
```
ema_state is initialized and updated as GPU float32 copies of the full state_dict() each step. This can significantly increase GPU memory usage (potentially causing OOM) and adds per-step overhead. Consider keeping EMA on CPU (and updating from .detach().cpu()), or limiting EMA to a subset of parameters.
Suggested change:

```diff
-ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay)
+# Keep EMA state on CPU to avoid maintaining a full float32 copy on GPU.
+ema_t = ema_state[name]
+if ema_t.device != torch.device("cpu"):
+    ema_t = ema_t.detach().cpu()
+    ema_state[name] = ema_t
+t_cpu = t.detach().float().cpu()
+ema_t.mul_(ema_decay).add_(t_cpu, alpha=1.0 - ema_decay)
```
```python
if not t.is_floating_point() or t.numel() <= 65536:
    result[name] = t.to(torch.float16) if t.is_floating_point() else t
    meta[name] = "passthrough"
    continue
if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
    result[name] = t.float()
    meta[name] = "passthrough_ctrl"
    continue
```
In mixed_quantize_int6_gptq, the early branch if not t.is_floating_point() or t.numel() <= 65536 runs before the control-tensor check, so small control tensors (matching CONTROL_TENSOR_NAME_PATTERNS) will be downcast to fp16 and never get the intended passthrough_ctrl handling. If control tensors are meant to stay fp32 for stability/accuracy, check CONTROL_TENSOR_NAME_PATTERNS before the numel() <= 65536 passthrough/downcast logic.
Suggested change:

```diff
-if not t.is_floating_point() or t.numel() <= 65536:
-    result[name] = t.to(torch.float16) if t.is_floating_point() else t
-    meta[name] = "passthrough"
-    continue
-if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
-    result[name] = t.float()
-    meta[name] = "passthrough_ctrl"
-    continue
+# Non-floating tensors are always passed through as-is.
+if not t.is_floating_point():
+    result[name] = t
+    meta[name] = "passthrough"
+    continue
+# Control tensors are kept in float32 for stability/accuracy, regardless of size.
+if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
+    result[name] = t.float()
+    meta[name] = "passthrough_ctrl"
+    continue
+# Small floating-point tensors are passed through in float16.
+if t.numel() <= 65536:
+    result[name] = t.to(torch.float16)
+    meta[name] = "passthrough"
+    continue
```
```python
def quantize_int6_per_row(t: Tensor, clip_range: int = 15) -> tuple[Tensor, Tensor]:
    t32 = t.float()
    if t32.ndim == 2:
        best_q, best_s, best_err = None, None, float('inf')
        for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]:
            if pct < 1.0:
                row_clip = torch.quantile(t32.abs(), pct, dim=1)
            else:
                row_clip = t32.abs().amax(dim=1)
            s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
            q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8)
            recon = q.float() * s.float()[:, None]
            err = (t32 - recon).pow(2).mean().item()
            if err < best_err:
                best_q, best_s, best_err = q, s, err
        return best_q, best_s
    amax = t32.abs().max().item()
    scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8)
    return q, scale
```
Several helpers and artifacts are named int6 (e.g., quantize_int6_per_row, final_model.int6.ptz), but the PR description/docstring calls this “int5 GPTQ”. Since clip_range=15 yields 31 signed levels (an int5-like scheme), please rename the functions/files/metadata to match the actual quantization format to avoid confusion for readers and future tooling.
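The mismatch is easy to verify: a symmetric range [-clip_range, clip_range] has 2*clip_range + 1 levels, so clip_range=15 gives 31 levels and needs 5 bits, not 6. A small helper (hypothetical name, not part of the PR) makes this concrete:

```python
import math

def signed_bits_needed(clip_range: int) -> int:
    """Bits required for the symmetric integer range [-clip_range, clip_range]."""
    levels = 2 * clip_range + 1  # 31 levels for clip_range=15
    return math.ceil(math.log2(levels))
```

By this count, clip_range=15 is a 5-bit scheme, so the `int6` names in the code understate the compression and mislabel the format.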
Summary
33.6M parameter model quantized to int5 with GPTQ error compensation, fitting under 16MB. First submission to achieve int5 quantization on a 33.6M model within the artifact size limit.
Architecture: 11L, 512d, MHA 8/8, MLP 3.5x (1792), BigramHash 8192, XSA all layers
Quantization: int5 per-row GPTQ (clip_range=15) + Early QAT (threshold 0.5) + EMA 0.997
TTT: Legal score-first AdamW, chunk=131072, last 2 blocks unfrozen
Results
Logs
Seed 1337
Seed 42
Seed 7
Reproduction