feat: add MoE SwiGLU FFN and Muon optimizer #20

Merged: AlmondGod merged 9 commits into AlmondGod:main from eren23:feat/moe-muon on Apr 15, 2026

eren23 (Contributor) commented Apr 12, 2026

Summary

Implements two TODO items from the README:

  • Mixture of Experts in SwiGLU FFN — New MoESwiGLUFFN class wraps N smaller SwiGLU experts with top-k routing and load-balancing auxiliary loss. Enabled via use_moe: true in config (dynamics model only). Default: 4 experts, top-2 routing, 0.01 aux loss coefficient.
  • Muon optimizer — Newton-Schulz orthogonalized updates for 2D weight matrices, with AdamW handling biases/norms/embeddings. Standalone implementation with no external dependencies. Enabled via optimizer: "muon" in config.

Both features are opt-in with zero regression to default behavior. All existing configs work without modification.
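
To make the routing described in the summary concrete, here is a dependency-free sketch of top-k gating with a load-balancing auxiliary loss. It is not the code in `MoESwiGLUFFN`: the helper names (`route_tokens`, `load_balance_loss`), the renormalization of the top-k weights, and the loss form (number of experts times the sum over experts of dispatch fraction times mean router probability) are assumptions chosen for illustration.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route_tokens(router_logits, top_k=2):
    """router_logits: one list of expert logits per token (hypothetical layout)."""
    assignments, probs = [], []
    for logits in router_logits:
        p = softmax(logits)
        # Keep the top_k most probable experts and renormalize their weights.
        chosen = sorted(range(len(p)), key=lambda e: p[e], reverse=True)[:top_k]
        norm = sum(p[e] for e in chosen)
        assignments.append({e: p[e] / norm for e in chosen})
        probs.append(p)
    return assignments, probs

def load_balance_loss(assignments, probs, num_experts, top_k=2):
    n_tokens = len(probs)
    # Fraction of routing slots dispatched to each expert.
    frac = [sum(1 for a in assignments if e in a) / (n_tokens * top_k)
            for e in range(num_experts)]
    # Mean router probability assigned to each expert.
    mean_prob = [sum(p[e] for p in probs) / n_tokens for e in range(num_experts)]
    return num_experts * sum(f * m for f, m in zip(frac, mean_prob))

# Balanced routing: half the tokens prefer experts 0/1, half prefer 2/3.
balanced_logits = [[2, 2, 0, 0], [0, 0, 2, 2], [2, 2, 0, 0], [0, 0, 2, 2]]
# Collapsed routing: every token prefers experts 0/1.
collapsed_logits = [[5, 5, 0, 0]] * 4

bal_assign, bal_probs = route_tokens(balanced_logits)
col_assign, col_probs = route_tokens(collapsed_logits)
balanced_loss = load_balance_loss(bal_assign, bal_probs, num_experts=4)
collapsed_loss = load_balance_loss(col_assign, col_probs, num_experts=4)
```

Under this formulation the loss is lowest when tokens spread evenly across experts and grows as routing collapses onto a few of them. In training it would be scaled by `moe_aux_loss_coeff` (0.01 by default) and added to the main loss.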

Inference Visualization

MoE + Muon dynamics on PONG (2000 steps, 4 experts top-2, autoregressive generation without actions):

View inference result (GT vs Predicted frames) on W&B — file: inference_results_gt_vs_pred_no_actions_20260413_091147.png

  • 12 frames generated autoregressively (2 context + 10 predicted)
  • MSE (GT vs Pred): 0.004465
  • All predicted frames maintain coherent paddle/ball positions

Also fixed two bugs found during inference testing:

  • run_inference.py: os.path.isfile → os.path.exists to support directory-format checkpoints
  • utils.py: model_sd.get('model', {}) → model_sd for correct conditioning_dim inference from state dict keys
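
The first fix can be illustrated with the standard library alone. In this sketch the directory name `checkpoint_dir` and the helper `checkpoint_exists` are hypothetical; the `model_state_dict.pt` file name comes from the commit message for this fix.

```python
import os
import tempfile

def checkpoint_exists(path):
    # os.path.isfile(path) returns False for a directory, so a
    # directory-format checkpoint would be rejected. os.path.exists
    # accepts both single-file and directory layouts.
    return os.path.exists(path)

with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = os.path.join(tmp, "checkpoint_dir")  # hypothetical name
    os.makedirs(ckpt_dir)
    open(os.path.join(ckpt_dir, "model_state_dict.pt"), "wb").close()

    dir_is_file = os.path.isfile(ckpt_dir)      # old check
    dir_exists = checkpoint_exists(ckpt_dir)    # new check
```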

W&B Training Runs

Short runs (500 steps, RTX 4090, PicoDoom)

W&B Project Dashboard

| Test | Config | Loss 0 → 250 | W&B Run |
| --- | --- | --- | --- |
| Video Tokenizer (baseline) | AdamW | 0.318 → 0.038 | 00bov8uj |
| Video Tokenizer | Muon | 0.337 → 0.047 | n529uy9t |
| Latent Actions | Muon | 1.262 → 0.149 | ypkp3hxk |
| Dynamics | MoE + Muon | 7.246 → 2.581 | tt48o2q1 |

Longer runs: AdamW vs Muon convergence comparison (5000 steps, RTX 4090, PicoDoom)

Full dynamics-stage comparison. All four runs share the same VT/LA prerequisites (2000 steps each):

| Run | Config | Loss @ 0 | Loss @ 4000 | Δ vs Baseline |
| --- | --- | --- | --- | --- |
| Dynamics | AdamW (baseline) | 7.17 | 2.198 | |
| Dynamics | Muon | 7.12 | 2.147 | -2.3% |
| Dynamics | AdamW + MoE (4 experts, top-2) | 7.18 | 2.141 | -2.6% |
| Dynamics | Muon + MoE | 7.12 | ~2.1* | best |

*Run 4 (Muon + MoE) completed ~4990/5000 steps before spot preemption; W&B has near-complete data.
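
The "Δ vs Baseline" column is the relative change in loss at step 4000 against the AdamW baseline. A quick check of the arithmetic, using only the values reported in the table (the helper name `relative_delta_pct` is illustrative):

```python
def relative_delta_pct(loss, baseline):
    # Negative values mean the run reached a lower loss than the baseline.
    return 100.0 * (loss - baseline) / baseline

baseline = 2.198  # Dynamics, AdamW, loss @ 4000
runs = {"Muon": 2.147, "AdamW + MoE": 2.141}
deltas = {name: round(relative_delta_pct(loss, baseline), 1)
          for name, loss in runs.items()}
```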

Key findings:

  • Both Muon and MoE independently improve over the AdamW baseline
  • MoE gives slightly larger gains than Muon alone (-2.6% vs -2.3% loss at step 4000 of the 5000-step runs)
  • Per-expert token utilization logged to W&B (moe/block{i}_expert{j}_frac) — router distributes tokens across all 4 experts with no expert collapse
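
As a rough sketch of how per-expert fractions under the `moe/block{i}_expert{j}_frac` key pattern could be built and checked for collapse: the input layout (`dispatch_counts`), the function names, and the 1% collapse threshold are assumptions, not the PR's logging code.

```python
def expert_fraction_metrics(dispatch_counts):
    """dispatch_counts[i][j] = tokens routed to expert j in MoE block i."""
    metrics = {}
    for i, counts in enumerate(dispatch_counts):
        total = sum(counts)
        for j, count in enumerate(counts):
            frac = count / total if total else 0.0
            metrics[f"moe/block{i}_expert{j}_frac"] = frac
    return metrics

def collapsed_experts(metrics, threshold=0.01):
    # An expert receiving almost no tokens is a sign of router collapse.
    return sorted(key for key, frac in metrics.items() if frac < threshold)

healthy = expert_fraction_metrics([[30, 20, 25, 25]])
unhealthy = expert_fraction_metrics([[60, 40, 0, 0]])
```

A dictionary of this shape could be passed directly to `wandb.log` at each step.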

All runs visible at: W&B Project Dashboard

Changes

New files

  • models/muon.py — Standalone Muon optimizer (Newton-Schulz orthogonalization + momentum). DDP-compatible; FSDP limitation documented.
  • utils/optimizer_utils.py — Shared create_optimizer() for all training stages. Handles AdamW (single optimizer) and Muon (split: Muon for 2D weights, AdamW for biases/norms/embeddings).
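
The Muon/AdamW split described for `create_optimizer()` can be sketched without PyTorch by partitioning parameter names on their shapes. The function name, the parameter names, and the `"embed"` substring heuristic below are hypothetical; the actual selection logic in `utils/optimizer_utils.py` may differ.

```python
def split_param_names(named_shapes):
    """Partition parameters: 2D weight matrices go to Muon, the rest to AdamW."""
    muon, adamw = [], []
    for name, shape in named_shapes.items():
        is_matrix = len(shape) == 2
        is_embedding = "embed" in name  # assumption: embeddings found by name
        if is_matrix and not is_embedding:
            muon.append(name)
        else:
            adamw.append(name)  # biases, norm scales, embeddings
    return muon, adamw

# Hypothetical parameter names and shapes, for illustration only.
example_params = {
    "blocks.0.attn.qkv.weight": (384, 128),
    "blocks.0.attn.qkv.bias": (384,),
    "blocks.0.norm.weight": (128,),
    "token_embed.weight": (1024, 128),
}
muon_names, adamw_names = split_param_names(example_params)
```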

Modified files

  • models/st_transformer.py: SwiGLUExpert and MoESwiGLUFFN classes with top-k routing + load-balancing aux loss. Per-expert utilization tracking. STTransformerBlock and STTransformer accept MoE kwargs.
  • models/dynamics.py — Passes MoE config through to STTransformer.
  • utils/config.py — MoE fields (use_moe, num_experts, top_k_experts, moe_aux_loss_coeff) and optimizer fields (optimizer, muon_momentum, muon_backend_steps) added to all config dataclasses.
  • configs/training.yaml — Defaults for new config fields.
  • scripts/train_dynamics.py — Multi-optimizer loop, MoE aux loss integration, per-expert utilization W&B logging, multi-optimizer checkpoint save.
  • scripts/train_video_tokenizer.py — Multi-optimizer support.
  • scripts/train_latent_actions.py — Multi-optimizer support.
  • scripts/run_inference.py — Fix directory-format checkpoint detection.
  • utils/utils.py — Fix conditioning_dim inference from state dict, add MoE config to checkpoint loader.

Usage

# Enable MoE (dynamics only)
use_moe: true
num_experts: 4
top_k_experts: 2
moe_aux_loss_coeff: 0.01

# Enable Muon (all stages)
optimizer: "muon"
muon_momentum: 0.95
muon_backend_steps: 5
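
For reference, the same fields expressed as a Python dataclass with basic validation. This is a sketch only: the class name `MoEMuonConfig`, the `"adamw"` default string, and the validation rules are assumptions, and the real dataclasses in `utils/config.py` may be organized differently.

```python
from dataclasses import dataclass

@dataclass
class MoEMuonConfig:  # hypothetical class name
    # MoE fields (dynamics only); off by default, so existing configs are unaffected.
    use_moe: bool = False
    num_experts: int = 4
    top_k_experts: int = 2
    moe_aux_loss_coeff: float = 0.01
    # Optimizer fields (all stages); the "adamw" default string is an assumption.
    optimizer: str = "adamw"
    muon_momentum: float = 0.95
    muon_backend_steps: int = 5

    def __post_init__(self):
        if self.optimizer not in ("adamw", "muon"):
            raise ValueError(f"unknown optimizer: {self.optimizer!r}")
        if self.use_moe and not 1 <= self.top_k_experts <= self.num_experts:
            raise ValueError("top_k_experts must be between 1 and num_experts")

default_cfg = MoEMuonConfig()
moe_muon_cfg = MoEMuonConfig(use_moe=True, optimizer="muon")
```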

Known limitations

  • Muon only supports DDP, not FSDP (Newton-Schulz requires full weight matrices, not shards)
  • Muon LR/momentum may need per-architecture tuning vs AdamW defaults
  • With split optimizers (Muon + AdamW), save_training_state saves only the primary optimizer; the secondary optimizer state is saved in a separate all_optimizers.pt file
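
To show why Newton-Schulz needs the full matrix, here is a pure-Python sketch of the quintic iteration. The coefficients are the ones commonly seen in public Muon implementations; whether `models/muon.py` uses exactly these values, this normalization, or this transpose handling is an assumption. Every step forms the product of X with its transpose, which touches every row of the matrix at once, so a single FSDP shard cannot compute it locally.

```python
import math

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(row) for row in zip(*a)]

def newton_schulz(g, steps=5, eps=1e-7):
    """Approximately orthogonalize matrix g (a list of rows)."""
    a, b, c = 3.4445, -4.7750, 2.0315  # assumed coefficients
    # Normalize by the Frobenius norm so all singular values are at most 1.
    norm = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / (norm + eps) for v in row] for row in g]
    transposed = len(x) > len(x[0])
    if transposed:  # work on the wide orientation so the Gram matrix is small
        x = transpose(x)
    for _ in range(steps):
        gram = matmul(x, transpose(x))   # needs the whole matrix, not a shard
        gram2 = matmul(gram, gram)
        poly = [[b * p + c * q for p, q in zip(r1, r2)]
                for r1, r2 in zip(gram, gram2)]
        px = matmul(poly, x)
        x = [[a * v + w for v, w in zip(r1, r2)] for r1, r2 in zip(x, px)]
    return transpose(x) if transposed else x

# A diagonal input keeps the result easy to read: each diagonal entry is a
# singular value, pushed toward the neighborhood of 1.
orthogonalized = newton_schulz([[3.0, 0.0], [0.0, 1.0]], steps=5)
```

With these coefficients the iteration does not converge to singular values of exactly 1. It drives them into a band around 1, which is the usual trade-off for needing only a few steps (`muon_backend_steps: 5` here).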

Test plan

  • Unit tests: MoE gradient flow, device safety, param coverage, eval mode
  • Code review: bf16 grad corruption fix, device-safe aux loss, checkpoint completeness
  • E2E training: 4 runs on RTX 4090 with W&B logging (see table above)
  • Inference visualization: MoE+Muon on PONG (2000 steps, autoregressive)
  • Longer training runs to compare AdamW vs Muon convergence curves (5000 steps, 4 configs)
  • MoE expert utilization analysis (per-expert token fractions logged to W&B)

🤖 Generated with Claude Code

eren23 and others added 2 commits April 12, 2026 13:21
Implements two TODO items from the README:

1. **Mixture of Experts in SwiGLU FFN** — New `MoESwiGLUFFN` class wraps
   N smaller SwiGLU experts with top-k routing and load-balancing auxiliary
   loss. Enabled via `use_moe: true` in config (dynamics model only).
   Default: 4 experts, top-2 routing, 0.01 aux loss coefficient.

2. **Muon optimizer** — Newton-Schulz orthogonalized updates for 2D weight
   matrices, with AdamW handling biases/norms/embeddings. Standalone
   implementation with no external dependencies. Enabled via
   `optimizer: "muon"` in config.

Both features are opt-in with zero regression to default behavior.
All existing configs work without modification.

New files:
- models/muon.py — Standalone Muon optimizer
- utils/optimizer_utils.py — Shared optimizer creation for all stages

Modified:
- models/st_transformer.py — MoESwiGLUFFN, SwiGLUExpert classes
- models/dynamics.py — MoE kwargs passthrough
- utils/config.py — MoE + optimizer config fields
- configs/training.yaml — Default config with new options
- scripts/train_*.py — Multi-optimizer support + MoE aux loss

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…, checkpoint save

- Muon: clone gradient before bf16 cast to prevent in-place corruption
  when param dtype is already bfloat16 (AMP/FSDP mixed precision)
- MoE aux_loss: use model device for zero tensor instead of CPU default
- Dynamics training: cache aux_loss scalar before backward for accurate
  W&B logging (avoids double-call and grad_accum inflation)
- Checkpoint: save all optimizer/scheduler states when using split
  optimizers (Muon+AdamW), not just the primary
- Document FSDP limitation in Muon docstring (DDP only)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@AlmondGod (Owner)

this looks amazing! excited for it to finish.

Updated PR guidelines, main thing is can you visualize inference post-change?

also would be great to clean up the AI triple quote function docs, usually not too helpful and add noise

eren23 added 2 commits April 13, 2026 10:10
Strip triple-quote docstrings that add noise per reviewer feedback.
Add MoE config fields (use_moe, num_experts, top_k_experts,
moe_aux_loss_coeff) to load_dynamics_from_checkpoint so MoE-trained
checkpoints restore correctly.
500 steps per stage, compile off, single GPU.
Used for generating PR inference visuals.
eren23 (Contributor, Author) commented Apr 13, 2026

> this looks amazing! excited for it to finish.
>
> Updated PR guidelines, main thing is can you visualize inference post-change?
>
> also would be great to clean up the AI triple quote function docs, usually not too helpful and add noise

sounds great! im on it rn :)

eren23 added 4 commits April 13, 2026 10:28
- run_inference.py: use os.path.exists instead of os.path.isfile so
  directory-format checkpoints (model_state_dict.pt + state.pt) are
  recognized
- utils.py: search model_sd directly (not model_sd['model']) when
  inferring conditioning_dim — state dict keys have no 'model.' prefix
Ground truth vs predicted frames from 500-step smoke training
with MoE SwiGLU FFN (4 experts, top-2) + Muon optimizer.
MSE: 0.002612 over 12 frames.
eren23 (Contributor, Author) commented Apr 13, 2026

@AlmondGod removed docstrings and added post change training inference visuals to wandb and shared a link in the PR body, lmk if anything else needs to be changed :)

i also can squeeze all commits into one if you want btw

@AlmondGod (Owner)

sounds great! looks like all thats left is to complete your checklist (run the longer runs and visualize inference?)

@AlmondGod (Owner)

i think we can just do a squash and merge when merging PR!

Stores fraction_dispatched per MoE block during forward pass and logs
per-expert token fractions (moe/block{i}_expert{j}_frac) to W&B each step.
Enables router weight distribution analysis for the PR checklist.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
eren23 (Contributor, Author) commented Apr 14, 2026

Longer training results + expert utilization ✅

Ran 4 dynamics configurations at 5000 steps on RTX 4090 (PicoDoom dataset), sharing the same VT/LA prerequisites:

| Config | Loss @ 4000 | Δ vs Baseline |
| --- | --- | --- |
| AdamW (baseline) | 2.198 | |
| Muon | 2.147 | -2.3% |
| AdamW + MoE | 2.141 | -2.6% |
| Muon + MoE | ~2.1 | best |

Both features independently improve convergence. MoE expert utilization is also now logged per-block to W&B (moe/block{i}_expert{j}_frac) — no expert collapse observed, all 4 experts receive balanced token fractions.

All runs on the W&B dashboard. Updated PR body with full results table and checked off remaining items.

A note: all 4 longer training runs (5000 steps) completed and are on the W&B dashboard; convergence curves and expert utilization are all there. The pod got spot-preempted right before the 5000-step inference render though :D so we still have the 2000-step inference visualization in the PR body. Happy to re-run for a 5000-step version if you'd like, but the convergence data tells the full story.

Let me know!

@AlmondGod (Owner)

amazing results, LGTM!

AlmondGod merged commit ba5c8b0 into AlmondGod:main on Apr 15, 2026

eren23 (Contributor, Author) commented Apr 16, 2026

> amazing results, LGTM!

thanks a lot man! it was a joy working on this <3
