[WIP] Looped transformer with LoRA adapters#11
adityawrk wants to merge 2 commits into openai:main
Conversation
3 physical layers x 5 loops = 15 effective depth, dim=640, GQA 10:5, SwiGLU MLP, rank-32 LoRA adapters per loop, U-Net skip connections, multi-token prediction (training only). Int8 quantized + zlib compressed.
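The loop-with-adapters idea above can be sketched minimally. This is an illustrative reconstruction, not the PR's actual code: `LoopedBlock` and the simplified `LoRAAdapter` are hypothetical names, and a single linear projection stands in for the real GQA attention + SwiGLU MLP block. The key point is that 3 physical layers are reused across 5 loops (15 effective layers) with a separate rank-32 LoRA adapter per (loop, layer) pair.

```python
import torch
import torch.nn as nn

class LoRAAdapter(nn.Module):
    # Low-rank residual: contributes x @ A^T @ B^T. B starts at zero so the
    # adapter is a no-op at init, the standard LoRA initialization.
    def __init__(self, dim, rank=32):
        super().__init__()
        self.A = nn.Parameter(torch.randn(rank, dim) * 0.01)
        self.B = nn.Parameter(torch.zeros(dim, rank))

    def forward(self, x):
        return x @ self.A.T @ self.B.T

class LoopedBlock(nn.Module):
    # Stand-in for one transformer layer (real PR: GQA attn + SwiGLU MLP).
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x, adapter):
        return x + self.proj(x) + adapter(x)

class LoopedTransformer(nn.Module):
    # n_layers physical layers reused n_loops times -> n_layers*n_loops
    # effective depth; each loop gets its own adapters for specialization.
    def __init__(self, dim=640, n_layers=3, n_loops=5, rank=32):
        super().__init__()
        self.layers = nn.ModuleList(LoopedBlock(dim) for _ in range(n_layers))
        self.adapters = nn.ModuleList(
            nn.ModuleList(LoRAAdapter(dim, rank) for _ in range(n_layers))
            for _ in range(n_loops)
        )

    def forward(self, x):
        for loop_adapters in self.adapters:
            for layer, adapter in zip(self.layers, loop_adapters):
                x = layer(x, adapter)
        return x
```

Because the physical weights are shared, only the small per-loop adapters grow the parameter count with depth, which is what makes this attractive under the 16 MB checkpoint budget.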
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9d6349ebd7
```python
shift = i + 2  # predict t+2, t+3, ...
if shift >= target_ids.size(1):
    continue
mtp_x = x[:, :-shift, :].reshape(-1, x.size(-1))
mtp_targets = target_ids[:, shift:].reshape(-1)
```
Shift MTP targets by one fewer token
When multi-token prediction is enabled, target_ids is already the t+1 sequence, so shift = i + 2 makes the first auxiliary head train on target_ids[:, 2:] (t+3 from x_t) instead of t+2. Every later MTP head is likewise one token too far ahead, so the new auxiliary loss is supervising the wrong future tokens rather than the advertised t+2, t+3, ... targets.
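The off-by-one is easy to verify with a toy sequence. This sketch applies the corrected `shift = i + 1` (the head count of 2 is hypothetical): since `target_ids[:, p]` already holds the token at `t+1` relative to `x` at position `p`, shifting by `i + 1` makes head `i` supervise `t+(i+2)`, as advertised.

```python
import torch

# Toy data: token value equals its position, so alignment is checkable.
seq = torch.arange(10)
input_ids = seq[:-1].unsqueeze(0)   # x_t at each position
target_ids = seq[1:].unsqueeze(0)   # already the t+1 sequence

for i in range(2):                  # two hypothetical auxiliary MTP heads
    shift = i + 1                   # fixed: i+1, not i+2
    if shift >= target_ids.size(1):
        continue
    mtp_inputs = input_ids[:, :-shift]     # truncate to match targets
    mtp_targets = target_ids[:, shift:]    # supervises t+(i+2) from x_t
    # position 0 holds token 0, so head i must see token i+2 as its target
    assert mtp_targets[0, 0] == input_ids[0, 0] + (i + 2)
```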
```python
).to(device).bfloat16()
for module in base_model.modules():
    if isinstance(module, CastedLinear):
        module.float()
restore_low_dim_params_to_fp32(base_model)
```
Restore LoRA adapter weights to fp32 before Muon
With the default LORA_RANK>0, the only loop-specific weights are the new LoRA adapters, but after the global .bfloat16() cast this block only promotes CastedLinear modules back to fp32. LoRAAdapter.A/B stay bf16, and Muon.step() allocates its momentum buffer with zeros_like(g), so those adapters are trained with bf16 state/update precision while the rest of the matrix weights stay fp32. That materially weakens the per-loop specialization path this change is adding.
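A minimal sketch of the suggested fix, using a stand-in `LoRAAdapter` (the PR's real class will differ): after the global `.bfloat16()` cast, walk the module tree and promote adapter factors back to fp32, so that optimizer state allocated with `zeros_like(g)` inherits fp32 rather than bf16.

```python
import torch
import torch.nn as nn

class LoRAAdapter(nn.Module):
    # Minimal stand-in for the PR's adapter: two low-rank factors A and B.
    def __init__(self, dim=16, rank=4):
        super().__init__()
        self.A = nn.Parameter(torch.zeros(rank, dim))
        self.B = nn.Parameter(torch.zeros(dim, rank))

def restore_low_dim_params_to_fp32(model):
    # Promote adapter factors back to fp32 after a global bf16 cast, so
    # momentum buffers created via zeros_like(grad) are also fp32.
    for module in model.modules():
        if isinstance(module, LoRAAdapter):
            module.float()

model = nn.Sequential(nn.Linear(16, 16), LoRAAdapter()).bfloat16()
restore_low_dim_params_to_fp32(model)
```

After this pass the adapters train in fp32 while the large matrix weights stay bf16, matching the precision treatment the review asks for.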
- MTP shift was off-by-one (target_ids is already t+1, so shift=i+2 predicted t+3 instead of t+2). Changed to shift=i+1.
- LoRAAdapter.A/B stayed bf16 after the global .bfloat16() cast since only CastedLinear was restored to fp32. Now both types are restored.
Summary
Changes
records/track_10min_16mb/track_a/train_gpt.py
Status
WIP -- hyperparameter tuning in progress, no validated score yet.