[WIP] Looped transformer with LoRA adapters#11

Closed
adityawrk wants to merge 2 commits into openai:main from adityawrk:track-a-submission

Conversation

@adityawrk

Summary

  • 3 physical transformer layers looped 5 times for an effective depth of 15
  • Rank-32 LoRA adapters on Q/K/V/O for per-loop specialization
  • SwiGLU MLP (hidden=1664), GQA 10:5, dim=640
  • U-Net skip connections at the loop level
  • Multi-token prediction auxiliary heads (training only, excluded from artifact)
  • Int8 quantized with per-row scales, zlib compressed
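The architecture above can be sketched as a minimal looped forward pass. All names and shapes here are illustrative stand-ins, not the actual `train_gpt.py` implementation: 3 shared blocks reused across 5 loops, one LoRA adapter per (loop, layer), and U-Net-style skips pairing early loops with late ones.

```python
import torch
import torch.nn as nn

DIM, N_LAYERS, N_LOOPS, RANK = 640, 3, 5, 32

class Block(nn.Module):
    """Stand-in for one transformer layer (attention + MLP collapsed)."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(DIM, DIM)
    def forward(self, x, lora):
        # Shared weights plus a loop-specific low-rank delta.
        return x + self.proj(x) + lora(x)

class LoRA(nn.Module):
    """Rank-32 adapter; B starts at zero so the initial delta is zero."""
    def __init__(self):
        super().__init__()
        self.A = nn.Linear(DIM, RANK, bias=False)
        self.B = nn.Linear(RANK, DIM, bias=False)
        nn.init.zeros_(self.B.weight)
    def forward(self, x):
        return self.B(self.A(x))

blocks = nn.ModuleList(Block() for _ in range(N_LAYERS))
# One adapter per (loop, layer): 5 loops x 3 layers = 15 specializations.
loras = nn.ModuleList(
    nn.ModuleList(LoRA() for _ in range(N_LAYERS)) for _ in range(N_LOOPS)
)

def forward(x):
    skips = []
    for loop in range(N_LOOPS):
        if loop < N_LOOPS // 2:
            skips.append(x)          # save early-loop activations
        elif skips:
            x = x + skips.pop()      # U-Net skip at the loop level
        for layer in range(N_LAYERS):
            x = blocks[layer](x, loras[loop][layer])
    return x

out = forward(torch.randn(2, 16, DIM))
assert out.shape == (2, 16, DIM)
```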

Changes

  • records/track_10min_16mb/track_a/train_gpt.py

Status

WIP -- hyperparameter tuning in progress, no validated score yet.

3 physical layers x 5 loops = 15 effective depth, dim=640, GQA 10:5,
SwiGLU MLP, rank-32 LoRA adapters per loop, U-Net skip connections,
multi-token prediction (training only). Int8 quantized + zlib compressed.
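The artifact packing mentioned above can be sketched as follows. This is illustrative only, assuming symmetric int8 quantization with one scale per weight row followed by zlib over the raw int8 bytes; the actual packing in `train_gpt.py` may differ.

```python
import zlib
import numpy as np

def quantize_per_row(w: np.ndarray):
    """Symmetric int8 quantization with one fp32 scale per row."""
    scales = np.abs(w).max(axis=1, keepdims=True) / 127.0
    scales = np.where(scales == 0, 1.0, scales)  # guard all-zero rows
    q = np.clip(np.round(w / scales), -127, 127).astype(np.int8)
    return q, scales.astype(np.float32)

def dequantize(q: np.ndarray, scales: np.ndarray) -> np.ndarray:
    return q.astype(np.float32) * scales

w = np.random.randn(640, 640).astype(np.float32)
q, scales = quantize_per_row(w)
blob = zlib.compress(q.tobytes(), level=9)  # compressed int8 payload

# Round-trip: decompress, reshape, rescale.
q_back = np.frombuffer(zlib.decompress(blob), dtype=np.int8).reshape(w.shape)
w_hat = dequantize(q_back, scales)
err = np.abs(w - w_hat).max()  # bounded by half the largest row scale
```

Per-row scales keep the quantization error proportional to each row's own magnitude, so outlier rows don't blow up the error of the rest of the matrix.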

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9d6349ebd7


Comment on lines +791 to +795
shift = i + 2  # predict t+2, t+3, ...
if shift >= target_ids.size(1):
    continue
mtp_x = x[:, :-shift, :].reshape(-1, x.size(-1))
mtp_targets = target_ids[:, shift:].reshape(-1)


P1: Shift MTP targets by one fewer token

When multi-token prediction is enabled, target_ids is already the t+1 sequence, so shift = i + 2 makes the first auxiliary head train on target_ids[:, 2:] (t+3 from x_t) instead of t+2. Every later MTP head is likewise one token too far ahead, so the new auxiliary loss is supervising the wrong future tokens rather than the advertised t+2, t+3, ... targets.
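The off-by-one can be made concrete with a minimal pure-Python sketch (illustrative values; the 1-D list slice `target_ids[shift:]` mirrors `target_ids[:, shift:]` in the diff):

```python
# target_ids is already the t+1 sequence: target_ids[j] is the token at
# position j + 1. So supervising head i on t+(i+2) requires a shift of
# i + 1 into target_ids, not i + 2.
T = 8
input_ids = list(range(T))               # token at position j is j
target_ids = [t + 1 for t in input_ids]  # t+1 targets used by the main loss

i = 0                                    # first MTP head: should predict t+2
buggy_shift = i + 2
fixed_shift = i + 1

# Target paired with the hidden state at position 0 under each shift:
buggy_first = target_ids[buggy_shift:][0]  # token 3 -> t+3, one too far
fixed_first = target_ids[fixed_shift:][0]  # token 2 -> t+2, as advertised

assert buggy_first == 3 and fixed_first == 2
```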


Comment on lines +918 to +922
).to(device).bfloat16()
for module in base_model.modules():
    if isinstance(module, CastedLinear):
        module.float()
restore_low_dim_params_to_fp32(base_model)


P2: Restore LoRA adapter weights to fp32 before Muon

With the default LORA_RANK>0, the only loop-specific weights are the new LoRA adapters, but after the global .bfloat16() cast this block only promotes CastedLinear modules back to fp32. LoRAAdapter.A/B stay bf16, and Muon.step() allocates its momentum buffer with zeros_like(g), so those adapters are trained with bf16 state/update precision while the rest of the matrix weights stay fp32. That materially weakens the per-loop specialization path this change is adding.
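The fix amounts to widening the `isinstance` check so both module types are promoted. A minimal sketch, using hypothetical stand-ins for the classes named in the diff (the real definitions live in `train_gpt.py` and may differ):

```python
import torch
import torch.nn as nn

class CastedLinear(nn.Linear):
    """Stand-in for the repo's CastedLinear."""

class LoRAAdapter(nn.Module):
    """Stand-in: rank-32 adapter with raw A/B parameter matrices."""
    def __init__(self, dim=640, rank=32):
        super().__init__()
        self.A = nn.Parameter(torch.randn(rank, dim))
        self.B = nn.Parameter(torch.zeros(dim, rank))

base_model = nn.ModuleDict({
    "proj": CastedLinear(640, 640),
    "lora": LoRAAdapter(),
}).bfloat16()  # the global cast that demotes everything to bf16

# Promote BOTH module types back to fp32, so Muon's zeros_like()
# momentum buffers are allocated in fp32 for the adapters as well
# as the dense matrices.
for module in base_model.modules():
    if isinstance(module, (CastedLinear, LoRAAdapter)):
        module.float()

assert base_model["proj"].weight.dtype == torch.float32
assert base_model["lora"].A.dtype == torch.float32
```

`nn.Module.float()` recursively casts a module's floating-point parameters, so one pass over `modules()` with the widened tuple check covers every adapter.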


- MTP shift was off-by-one (target_ids is already t+1, so shift=i+2
  predicted t+3 instead of t+2). Changed to shift=i+1.
- LoRAAdapter.A/B stayed bf16 after global .bfloat16() cast since only
  CastedLinear was restored to fp32. Now both types are restored.

2 participants