feat: add MoE SwiGLU FFN and Muon optimizer #20
Implements two TODO items from the README:

1. **Mixture of Experts in SwiGLU FFN**: New `MoESwiGLUFFN` class wraps N smaller SwiGLU experts with top-k routing and load-balancing auxiliary loss. Enabled via `use_moe: true` in config (dynamics model only). Default: 4 experts, top-2 routing, 0.01 aux loss coefficient.
2. **Muon optimizer**: Newton-Schulz orthogonalized updates for 2D weight matrices, with AdamW handling biases/norms/embeddings. Standalone implementation with no external dependencies. Enabled via `optimizer: "muon"` in config.

Both features are opt-in with zero regression to default behavior. All existing configs work without modification.

New files:
- models/muon.py: Standalone Muon optimizer
- utils/optimizer_utils.py: Shared optimizer creation for all stages

Modified:
- models/st_transformer.py: MoESwiGLUFFN, SwiGLUExpert classes
- models/dynamics.py: MoE kwargs passthrough
- utils/config.py: MoE + optimizer config fields
- configs/training.yaml: Default config with new options
- scripts/train_*.py: Multi-optimizer support + MoE aux loss

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
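For readers who want a concrete picture of the top-k routing and load-balancing loss described in this commit, here is a minimal pure-Python sketch. It assumes a Switch-Transformer-style auxiliary loss (number of experts times the sum over experts of dispatched fraction times mean router probability); the function names `route_top_k` and `load_balancing_loss` are hypothetical and are not taken from `models/st_transformer.py`, whose actual formulation may differ.

```python
import math

def softmax(logits):
    # Numerically stable softmax over one token's router logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route_top_k(router_logits, top_k=2):
    # For each token, keep the top_k experts and renormalize their gate weights.
    routes = []
    for logits in router_logits:
        probs = softmax(logits)
        chosen = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:top_k]
        norm = sum(probs[e] for e in chosen)
        routes.append([(e, probs[e] / norm) for e in chosen])
    return routes

def load_balancing_loss(router_logits, routes, num_experts):
    # Assumed Switch-style loss: num_experts * sum_e f[e] * p[e].
    num_tokens = len(router_logits)
    top_k = len(routes[0])
    dispatched = [0.0] * num_experts  # f[e]: share of routing slots sent to expert e
    for token_routes in routes:
        for e, _ in token_routes:
            dispatched[e] += 1.0 / (num_tokens * top_k)
    mean_prob = [0.0] * num_experts   # p[e]: mean router probability for expert e
    for logits in router_logits:
        for e, p in enumerate(softmax(logits)):
            mean_prob[e] += p / num_tokens
    return num_experts * sum(f * p for f, p in zip(dispatched, mean_prob))

# Three tokens, four experts, top-2 routing (the defaults named in this PR).
logits = [[2.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.5, 0.5], [0.3, 0.1, 0.2, 0.4]]
routes = route_top_k(logits, top_k=2)
aux_loss = 0.01 * load_balancing_loss(logits, routes, num_experts=4)  # 0.01 = default coeff
```

Under this assumed formulation the unscaled loss is smallest when dispatch and router probability are spread evenly across experts, and it grows as routing concentrates on a few experts.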
…, checkpoint save
- Muon: clone gradient before bf16 cast to prevent in-place corruption when param dtype is already bfloat16 (AMP/FSDP mixed precision)
- MoE aux_loss: use model device for zero tensor instead of CPU default
- Dynamics training: cache aux_loss scalar before backward for accurate W&B logging (avoids double-call and grad_accum inflation)
- Checkpoint: save all optimizer/scheduler states when using split optimizers (Muon+AdamW), not just the primary
- Document FSDP limitation in Muon docstring (DDP only)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
this looks amazing! excited for it to finish. Updated PR guidelines, main thing is can you visualize inference post-change? also would be great to clean up the AI triple quote function docs, usually not too helpful and add noise
Strip triple-quote docstrings that add noise per reviewer feedback. Add MoE config fields (use_moe, num_experts, top_k_experts, moe_aux_loss_coeff) to load_dynamics_from_checkpoint so MoE-trained checkpoints restore correctly.
500 steps per stage, compile off, single GPU. Used for generating PR inference visuals.
sounds great! im on it rn :)
This reverts commit 918a6ac.
- run_inference.py: use os.path.exists instead of os.path.isfile so directory-format checkpoints (model_state_dict.pt + state.pt) are recognized
- utils.py: search model_sd directly (not model_sd['model']) when inferring conditioning_dim, since state dict keys have no 'model.' prefix
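The first fix can be reproduced with the standard library alone. The sketch below builds a throwaway directory-format checkpoint (the directory name `checkpoint_dir` and the helper `checkpoint_exists` are made up for illustration) and shows why `os.path.isfile` rejects it while `os.path.exists` accepts it.

```python
import os
import tempfile

def checkpoint_exists(path):
    # os.path.isfile is False for a directory, so a directory-format checkpoint
    # would be skipped; os.path.exists accepts both single files and directories.
    return os.path.exists(path)

with tempfile.TemporaryDirectory() as root:
    ckpt_dir = os.path.join(root, "checkpoint_dir")  # hypothetical checkpoint name
    os.makedirs(ckpt_dir)
    # File names taken from the commit message above; contents left empty here.
    for name in ("model_state_dict.pt", "state.pt"):
        open(os.path.join(ckpt_dir, name), "wb").close()
    dir_isfile = os.path.isfile(ckpt_dir)
    dir_exists = checkpoint_exists(ckpt_dir)
```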
Ground truth vs predicted frames from 500-step smoke training with MoE SwiGLU FFN (4 experts, top-2) + Muon optimizer. MSE: 0.002612 over 12 frames.
This reverts commit 13db565.
@AlmondGod removed docstrings and added post change training inference visuals to wandb and shared a link in the PR body, lmk if anything else needs to be changed :) i also can squeeze all commits into one if you want btw

sounds great! looks like all thats left is to complete your checklist (run the longer runs and visualize inference?)

i think we can just do a squash and merge when merging PR!
Stores fraction_dispatched per MoE block during forward pass and logs
per-expert token fractions (moe/block{i}_expert{j}_frac) to W&B each step.
Enables router weight distribution analysis for the PR checklist.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
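A rough sketch of how per-expert fractions could be turned into the `moe/block{i}_expert{j}_frac` keys named in this commit. The helper names and the input layout (one list of chosen expert indices per token) are assumptions for illustration, not the code in `scripts/train_dynamics.py`.

```python
def fraction_dispatched(expert_indices, num_experts):
    # expert_indices: one list of chosen expert ids per token (assumed layout).
    counts = [0] * num_experts
    for chosen in expert_indices:
        for e in chosen:
            counts[e] += 1
    total = sum(counts)
    return [c / total for c in counts]

def expert_fraction_metrics(fractions_per_block):
    # Flatten per-block fractions into a dict keyed like the W&B metric names.
    metrics = {}
    for i, fractions in enumerate(fractions_per_block):
        for j, frac in enumerate(fractions):
            metrics[f"moe/block{i}_expert{j}_frac"] = float(frac)
    return metrics

# Two MoE blocks, four tokens each, top-2 routing over four experts.
block0 = fraction_dispatched([[0, 1], [2, 3], [0, 2], [1, 3]], num_experts=4)
block1 = fraction_dispatched([[0, 1], [0, 1], [0, 2], [0, 3]], num_experts=4)
metrics = expert_fraction_metrics([block0, block1])
# In a training loop this dict would typically be passed to wandb.log(metrics, step=step).
```

A fraction that stays near zero for one expert across many steps is the usual sign of expert collapse.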
Longer training results + expert utilization ✅
Ran 4 dynamics configurations at 5000 steps on RTX 4090 (PicoDoom dataset), sharing the same VT/LA prerequisites:
Both features independently improve convergence. MoE expert utilization is also now logged per-block to W&B (`moe/block{i}_expert{j}_frac`). All runs are on the W&B dashboard. Updated PR body with full results table and checked off remaining items.

A note: all 4 longer training runs (5000 steps) completed and are on the W&B dashboard; convergence curves and expert utilization are all there. The pod got spot-preempted right before the 5000-step inference render though :D so we still have the 2000-step inference visualization in the PR body. Happy to re-run for a 5000-step version if you'd like, but the convergence data tells the full story. Let me know!
amazing results, LGTM! |
thanks a lot man! it was a joy working on this <3 |
Summary
Implements two TODO items from the README:
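1. **Mixture of Experts in SwiGLU FFN**: New `MoESwiGLUFFN` class wraps N smaller SwiGLU experts with top-k routing and load-balancing auxiliary loss. Enabled via `use_moe: true` in config (dynamics model only). Default: 4 experts, top-2 routing, 0.01 aux loss coefficient.
2. **Muon optimizer**: Newton-Schulz orthogonalized updates for 2D weight matrices, with AdamW handling biases/norms/embeddings. Standalone implementation with no external dependencies. Enabled via `optimizer: "muon"` in config.

Both features are opt-in with zero regression to default behavior. All existing configs work without modification.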
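To make "Newton-Schulz orthogonalized updates" concrete, here is a small pure-Python sketch of the classical cubic Newton-Schulz iteration, X ← 1.5·X − 0.5·X·Xᵀ·X, applied to a Frobenius-normalized gradient matrix. This is a simplified stand-in: Muon implementations commonly use a tuned higher-order polynomial and run in bfloat16, and the step count here is an arbitrary choice rather than the PR's `muon_backend_steps` default. The helper names are hypothetical.

```python
import math

def matmul(a, b):
    # Plain nested-list matrix product.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def newton_schulz_orthogonalize(grad, steps=8):
    # Dividing by the Frobenius norm bounds every singular value by 1, which
    # keeps the cubic iteration inside its convergence region.
    fro = math.sqrt(sum(v * v for row in grad for v in row))
    x = [[v / (fro + 1e-7) for v in row] for row in grad]
    for _ in range(steps):
        xxt_x = matmul(matmul(x, transpose(x)), x)
        # Cubic Newton-Schulz step: pushes each singular value toward 1.
        x = [[1.5 * x[i][j] - 0.5 * xxt_x[i][j] for j in range(len(x[0]))]
             for i in range(len(x))]
    return x

grad = [[3.0, 1.0], [1.0, 2.0]]
update = newton_schulz_orthogonalize(grad)
gram = matmul(update, transpose(update))  # close to the identity after convergence
```

The result keeps the gradient's singular directions while equalizing its singular values, which is the property the Muon update relies on for 2D weight matrices.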
Inference Visualization
MoE + Muon dynamics on PONG (2000 steps, 4 experts top-2, autoregressive generation without actions):
View inference result (GT vs Predicted frames) on W&B, file: `inference_results_gt_vs_pred_no_actions_20260413_091147.png`

Also fixed two bugs found during inference testing:
- `run_inference.py`: `os.path.isfile` → `os.path.exists` to support directory-format checkpoints
- `utils.py`: `model_sd.get('model', {})` → `model_sd` for correct conditioning_dim inference from state dict keys

W&B Training Runs
Short runs (500 steps, RTX 4090, PicoDoom)
W&B Project Dashboard
Longer runs — AdamW vs Muon convergence comparison (5000 steps, RTX 4090, PicoDoom)
Full dynamics-stage comparison with shared VT/LA prerequisites (2000 steps each):
*Run 4 completed ~4990/5000 steps before spot preemption; W&B has near-complete data.
Key findings:
- Expert utilization is logged per block to W&B (`moe/block{i}_expert{j}_frac`): the router distributes tokens across all 4 experts with no expert collapse

All runs visible at: W&B Project Dashboard
Changes
New files
- `models/muon.py`: Standalone Muon optimizer (Newton-Schulz orthogonalization + momentum). DDP-compatible; FSDP limitation documented.
- `utils/optimizer_utils.py`: Shared `create_optimizer()` for all training stages. Handles AdamW (single optimizer) and Muon (split: Muon for 2D weights, AdamW for biases/norms/embeddings).
Modified files
- `models/st_transformer.py`: `SwiGLUExpert`, `MoESwiGLUFFN` classes with top-k routing + load-balancing aux loss. Per-expert utilization tracking. `STTransformerBlock` and `STTransformer` accept MoE kwargs.
- `models/dynamics.py`: Passes MoE config through to `STTransformer`.
- `utils/config.py`: MoE fields (`use_moe`, `num_experts`, `top_k_experts`, `moe_aux_loss_coeff`) and optimizer fields (`optimizer`, `muon_momentum`, `muon_backend_steps`) added to all config dataclasses.
- `configs/training.yaml`: Defaults for new config fields.
- `scripts/train_dynamics.py`: Multi-optimizer loop, MoE aux loss integration, per-expert utilization W&B logging, multi-optimizer checkpoint save.
- `scripts/train_video_tokenizer.py`: Multi-optimizer support.
- `scripts/train_latent_actions.py`: Multi-optimizer support.
- `scripts/run_inference.py`: Fix directory-format checkpoint detection.
- `utils/utils.py`: Fix conditioning_dim inference from state dict, add MoE config to checkpoint loader.
Usage
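As a hedged illustration of the opt-in fields and the Muon/AdamW split, the sketch below uses a hypothetical `DynamicsConfig` dataclass and a hypothetical `split_param_names` helper. The field names and the MoE defaults come from this PR; the `"adamw"` default string and the rule that any parameter whose name contains `embed` goes to AdamW are assumptions, not the behavior of `create_optimizer()`.

```python
from dataclasses import dataclass

@dataclass
class DynamicsConfig:  # hypothetical stand-in for the repo's config dataclass
    use_moe: bool = False             # opt-in, off by default
    num_experts: int = 4
    top_k_experts: int = 2
    moe_aux_loss_coeff: float = 0.01
    optimizer: str = "adamw"          # assumed default name; "muon" enables the split

def split_param_names(param_shapes):
    # 2D weight matrices go to Muon; biases, norms, and embeddings go to AdamW.
    muon, adamw = [], []
    for name, shape in param_shapes.items():
        is_matrix = len(shape) == 2
        is_embedding = "embed" in name  # assumed naming convention
        (muon if is_matrix and not is_embedding else adamw).append(name)
    return muon, adamw

cfg = DynamicsConfig(use_moe=True, optimizer="muon")
shapes = {  # illustrative parameter names and shapes
    "blocks.0.ffn.w1.weight": (256, 64),
    "blocks.0.ffn.w1.bias": (256,),
    "blocks.0.norm.weight": (64,),
    "token_embed.weight": (1024, 64),
}
if cfg.optimizer == "muon":
    muon_names, adamw_names = split_param_names(shapes)
else:
    muon_names, adamw_names = [], list(shapes)
```

In the YAML config the same switches are written as `use_moe: true` and `optimizer: "muon"`, as noted in the summary.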
Known limitations
`save_training_state`; secondary optimizer state saved in separate `all_optimizers.pt` file
Test plan
🤖 Generated with Claude Code