
Add opt-in MoD routing, SquareGLU MLP, EMA warmdown distillation, and Grokfast #1

Merged
skoustav35 merged 1 commit into main from codex/review-and-optimize-project-for-golf-challenge on Apr 1, 2026
Conversation

@skoustav35
Owner

Motivation

  • Provide opt-in, high-reward training and inference techniques (Mixture-of-Depth token routing, a gated SquareGLU MLP, EMA self-distillation during warmdown, and Grokfast gradient low-pass filtering) to improve compute efficiency and generalization within the existing N-gram backoff + legal TTT recipe.
  • Keep reported baseline behavior unchanged: all new techniques are disabled by default and configurable via environment flags, so re-runs can opt in without changing published numbers.

Description

  • Added new hyperparameters to Hyperparameters for toggling and configuring the features: MOD_*, SQUAREGLU_ENABLED, EMA_DISTILL_*, and GROKFAST_*, and documented them in the submission README and submission.json.
  • Implemented Mixture-of-Depth style token routing inside Block with per-block routers and a _make_mod_mask that produces top-k token masks used to selectively scale attention and MLP outputs while remaining legal/score-first during eval.
  • Implemented SquareGLU-style gated MLP as an opt-in path in MLP, added an mlp_gate_bank parameter bank, wired gate bank initialization, included it in bank/optimizer grouping, and passed gate slices into block forward calls.
  • Added EMA warmdown self-distillation: an EMA snapshot of the model is maintained and periodically synced as a teacher during warmdown, and a KL distillation term computed against it is added to the training loss.
  • Added a Grokfast gradient low-pass augmentation that maintains an EMA of gradients and injects a fraction of it back into the current gradients before optimizer steps.
  • Updated training-loop wiring: bank handling, TTT collect logic to include the optional mlp_gate_bank, EMA teacher creation, distillation accumulation and logging, Grokfast gradient augmentation before optimizer phases, and small logging additions for the distillation metric.
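The Grokfast augmentation described above can be sketched as follows. This is a minimal illustration, not the repo's implementation: the function name is hypothetical, and `alpha`/`lam` are illustrative defaults rather than the actual `GROKFAST_*` values.

```python
import torch

def grokfast_ema(params, grad_ema, alpha=0.98, lam=2.0):
    """Grokfast-style gradient low-pass: keep an EMA of each parameter's
    gradient and inject a fraction of it back before the optimizer step.

    grad_ema maps each parameter to its running gradient EMA (empty on the
    first call). Hypothetical sketch; names and defaults are illustrative.
    """
    for p in params:
        if p.grad is None:
            continue
        ema = grad_ema.get(p)
        if ema is None:
            # First step: seed the EMA with the current gradient.
            ema = p.grad.detach().clone()
        else:
            # Low-pass filter: ema <- alpha * ema + (1 - alpha) * grad
            ema.mul_(alpha).add_(p.grad, alpha=1 - alpha)
        grad_ema[p] = ema
        # Amplify the slow (filtered) gradient component.
        p.grad.add_(ema, alpha=lam)
```

In a training loop this would run after `loss.backward()` and before `optimizer.step()`, with `grad_ema` persisted across steps.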
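The EMA warmdown self-distillation can likewise be sketched in two pieces: an EMA weight sync for the teacher snapshot, and a KL term against the teacher's logits. A hedged illustration only; function names, the decay value, and the temperature are assumptions, not the repo's `EMA_DISTILL_*` configuration.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def update_ema(teacher, student, decay=0.999):
    """Maintain an EMA snapshot of the student's weights as a teacher.

    teacher_w <- decay * teacher_w + (1 - decay) * student_w
    """
    for pt, ps in zip(teacher.parameters(), student.parameters()):
        pt.mul_(decay).add_(ps, alpha=1 - decay)

def distill_loss(student_logits, teacher_logits, T=1.0):
    """KL distillation term between student and (frozen) teacher logits,
    added to the training loss during warmdown. Illustrative sketch."""
    log_p = F.log_softmax(student_logits / T, dim=-1)
    q = F.softmax(teacher_logits.detach() / T, dim=-1)
    return F.kl_div(log_p, q, reduction="batchmean") * (T * T)
```

During warmdown the training loop would call `update_ema` periodically and add `distill_loss(student_logits, teacher(inputs))`, scaled by a distillation weight, to the main loss.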

Testing

  • Ran a syntax check with python -m py_compile records/track_10min_16mb/2026-03-31_LeakyReLU2_LegalTTT_NGramCache_XSA/train_gpt.py, which completed with no syntax errors.
  • Updated README and submission.json to document the new opt-in flags and keep baseline-reported metrics unchanged; those documentation files were validated for consistency via simple file checks.

Codex Task

@skoustav35 skoustav35 merged commit 37fcb26 into main on Apr 1, 2026

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c577b1cf00


Comment on lines +2102 to +2106
squareglu=args.squareglu_enabled,
mod_enabled=args.mod_enabled,
mod_attn_keep_ratio=args.mod_attn_keep_ratio,
mod_mlp_keep_ratio=args.mod_mlp_keep_ratio,
mod_min_keep_tokens=args.mod_min_keep_tokens,

P1: Pass optional model flags into eval_model before strict load

These new constructor flags add extra parameters to base_model (for example mlp_gate_bank and blocks.*.mod_*_router.*), but the later eval_model = GPT(...) path still uses default squareglu=False/mod_enabled=False before load_state_dict(..., strict=True). In runs where any of these opt-in features are enabled, final int6 roundtrip evaluation will fail with missing/unexpected keys, so the training run cannot complete its export/eval pipeline.
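The failure mode can be reproduced in miniature: an opt-in flag adds extra parameters to the saved model, so an eval copy built with defaults fails a strict `load_state_dict`. The toy module below stands in for `GPT`; all names here are illustrative, not the repo's.

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Stand-in for GPT: an opt-in flag adds extra parameters, analogous
    to the per-block MoD routers and mlp_gate_bank. Hypothetical names."""
    def __init__(self, mod_enabled=False):
        super().__init__()
        self.core = nn.Linear(4, 4)
        if mod_enabled:
            # Extra state_dict keys exist only when the feature is on.
            self.router = nn.Linear(4, 1)

base = ToyModel(mod_enabled=True)

# The bug: an eval model built with default flags is missing keys,
# so strict loading raises.
eval_default = ToyModel()
try:
    eval_default.load_state_dict(base.state_dict(), strict=True)
    strict_load_failed = False
except RuntimeError:
    strict_load_failed = True

# The fix the reviewer suggests: forward the same opt-in flags when
# constructing the eval model, after which strict loading succeeds.
eval_matched = ToyModel(mod_enabled=True)
eval_matched.load_state_dict(base.state_dict(), strict=True)
```

Applied to the PR, this means the `eval_model = GPT(...)` call would need the same `squareglu=...` and `mod_*=...` arguments shown in the snippet above.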


Comment on lines +862 to +863
topv = torch.topk(scores, k=keep_k, dim=1).values[:, -1:]
mask = (scores >= topv).to(dtype=x.dtype).unsqueeze(-1)

P2: Make MoD mask differentiable to let the router learn

The routing mask is built from a hard threshold (scores >= topv), which is non-differentiable; as a result the router receives no gradient through the mask. With the current zero-initialized router weights/biases, scores start tied and this comparison keeps all tokens, so MOD_* routing is effectively inert even when keep_ratio < 1.
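One common way to restore router gradients, sketched under assumptions: keep the hard top-k mask in the forward pass but route gradients through a soft score via a straight-through estimator. This reuses the snippet's `topk` thresholding but is not the repo's code; the function name is hypothetical.

```python
import torch

def mod_mask_st(scores, keep_ratio):
    """Straight-through MoD mask: forward pass uses the hard top-k mask,
    backward pass flows through sigmoid(scores), so the router gets a
    learning signal. `scores` is (batch, seq) router logits. Illustrative."""
    keep_k = max(1, int(scores.size(1) * keep_ratio))
    # k-th largest score per row is the keep threshold (as in the snippet).
    topv = torch.topk(scores, k=keep_k, dim=1).values[:, -1:]
    hard = (scores >= topv).to(scores.dtype)   # non-differentiable 0/1 mask
    soft = torch.sigmoid(scores)               # differentiable surrogate
    # Forward value equals `hard`; gradient is that of `soft`.
    mask = hard + soft - soft.detach()
    return mask.unsqueeze(-1)
```

An alternative used by the Mixture-of-Depths paper is to multiply kept tokens' outputs by their raw router scores, which also gives the router a gradient without a surrogate.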

