[Non-Record] JEPA Self-Distillation with EMA Target Encoder for Autoregressive LM | Controlled A/B Shows No Gain Over Vanilla CE (val_bpb: 1.19) #896
Open
MVPandey wants to merge 5 commits into openai:main from
20L/512d transformer with JEPA auxiliary loss:
- context encoder (causal, trainable) + target encoder (EMA, decay=0.9995)
- predictor MLP predicts target encoder repr at next position
- VICReg on target encoder output prevents collapse
- standard tied-embedding CE head for scoring (JEPA is training-only)
- only context encoder saved in artifact (~48M params, ~14MB int6+lzma)

training: lr warmup 200 steps, cosine warmdown, JEPA weight 0.3 annealed to 0 during warmdown, Muon for banks + AdamW for scalars
submission record with readme, submission.json, train_gpt.py, and cleaned training log. results tbd pending 3 seeds.
JEPA auxiliary loss with an EMA target encoder doesn't help autoregressive LM. Vanilla CE is 40% faster (no target-encoder overhead) and converges to the same or slightly better BPB. Includes 11L and 20L comparisons, both with the same seed/hardware/wall time.
[Non-Record] JEPA Self-Distillation for Autoregressive LM | Controlled A/B Shows No Gain Over Vanilla CE | Negative Results
Track: Non-record, unlimited compute, 16MB artifact
Author: Manav Pandey (MVPandey)
val_bpb: 1.1896 (JEPA, 11L) vs 1.1841 (vanilla CE, 11L), no meaningful difference
Artifact: ~16.3MB
TL;DR
JEPA self-distillation with a moving EMA target encoder doesn't help autoregressive language modeling at this small scale. A vanilla CE baseline trains faster and converges to the same (or slightly better) BPB. I'm pushing these results in the spirit of good science and in the hopes that the implementation and findings are useful to someone exploring a similar direction. I'm also actively looking for alternative approaches that might make SSL auxiliary objectives meaningful in this setting.
The Idea
Full JEPA with an EMA target encoder as an auxiliary self-distillation objective. The context encoder is a causal transformer; the target encoder is an EMA copy (decay=0.9995) that provides slowly-evolving prediction targets. A predictor network learns to forecast what the target encoder will represent at the next token position.
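The EMA update driving the target encoder can be sketched in a few lines. This is a minimal NumPy illustration over a flat parameter dict (the actual train_gpt.py works with PyTorch modules; names here are illustrative):

```python
import numpy as np

def ema_update(target_params, context_params, decay=0.9995):
    """In-place EMA update: target <- decay * target + (1 - decay) * context.

    The target encoder receives no gradients; it only tracks a slow
    exponential moving average of the trainable context encoder, which is
    what makes its outputs usable as stable prediction targets.
    """
    for name, p_ctx in context_params.items():
        target_params[name] = decay * target_params[name] + (1.0 - decay) * p_ctx
    return target_params
```

With decay=0.9995 the target moves only 0.05% toward the context encoder per step, so the targets drift over hundreds of steps rather than per update.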
My previous experience with JEPA was in constraint satisfaction (Sudoku solving via energy-based inference with Langevin dynamics, github.com/MVPandey/Enso). Adapting it to autoregressive token prediction required rethinking the target/context encoder relationship and figuring out which components actually matter.
Only the context encoder is saved in the artifact. The target encoder, predictor, and projection heads are training-only overhead.
Architecture
```mermaid
graph LR
    X["x_{1..T}"] --> CE["Context Encoder<br/>(causal transformer, trainable)"]
    X --> TE["Target Encoder<br/>(same arch, EMA-updated, no grad)"]
    CE --> h_ctx["h_ctx"]
    TE --> h_tgt["h_tgt"]
    h_ctx --> |"ctx_proj"| z_ctx["z_ctx"]
    h_tgt --> |"tgt_proj"| z_tgt["z_tgt"]
    z_ctx --> |"predictor MLP"| z_pred["z_pred[t]"]
    z_tgt --> z_tgt_next["z_tgt[t+1]"]
    h_ctx --> |"h @ W_emb"| CE_Loss["CE Loss<br/>(standard next-token)"]
    z_pred --> JEPA_Loss["JEPA Loss<br/>(MSE in latent space)"]
    z_tgt_next --> JEPA_Loss
    z_tgt --> VICReg["VICReg<br/>(collapse prevention)"]
    CE_Loss --> Total["Total Loss"]
    JEPA_Loss --> |"× 0.3"| Total
    VICReg --> Total
    style TE stroke:#888,stroke-dasharray: 5 5
    style JEPA_Loss stroke:#c44,stroke-width:2px
    style VICReg stroke:#c44,stroke-width:2px
    style CE_Loss stroke:#4a4,stroke-width:2px
```

Backbone
JEPA Components (~393K params, training-only)
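As a sketch of what these training-only components compute, the following NumPy snippet shows the latent-space MSE (prediction at position t matched against the target encoder at t+1) plus VICReg-style variance/covariance regularizers on the target branch. Shapes, names, and the predictor callable are illustrative, not the PR's actual code:

```python
import numpy as np

def jepa_aux_loss(z_ctx, z_tgt, predictor, var_target=1.0, eps=1e-4):
    """JEPA auxiliary objective, sketched.

    z_ctx: (T, D) projected context-encoder states (trainable path)
    z_tgt: (T, D) projected target-encoder states (EMA path, no grad)
    predictor: callable (T, D) -> (T, D)
    """
    z_pred = predictor(z_ctx)
    # Prediction at position t must match the target representation at t+1.
    mse = np.mean((z_pred[:-1] - z_tgt[1:]) ** 2)

    # VICReg-style anti-collapse terms: keep per-dimension std near
    # var_target, and decorrelate dimensions of the target embedding.
    std = np.sqrt(z_tgt.var(axis=0) + eps)
    var_loss = np.mean(np.maximum(0.0, var_target - std))
    zc = z_tgt - z_tgt.mean(axis=0)
    cov = (zc.T @ zc) / (len(z_tgt) - 1)
    off_diag = cov - np.diag(np.diag(cov))
    cov_loss = (off_diag ** 2).sum() / z_tgt.shape[1]

    return mse, var_loss + cov_loss
```

In training, the MSE term would be scaled by the JEPA weight (0.3, annealed to 0) and added to the standard CE loss.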
Results: Controlled A/B Comparison
All runs: same seed (42), same hardware (8xH100), same wall time (30 min).
11L/512d (fits 16MB artifact)
20L/512d (over 16MB, for reference)
At 20L, step times were nearly identical because the JEPA path ran uncompiled while vanilla used torch.compile (an unfair advantage for vanilla that we caught via critic analysis). The 11L comparison fixes this (both uncompiled) and shows the real cost: JEPA adds ~40% step time overhead from the target encoder forward pass, resulting in 41% fewer training steps for the same wall time.
Bottom line: vanilla CE wins by 0.005 BPB at 11L while being 40% faster. JEPA is not just noise; it's actively worse when you account for compute.
Core Finding
The JEPA auxiliary loss asks "can you predict what the target encoder's latent representation of the next token will be?" But CE is already asking "can you predict the next token?" With a small, well-structured BPE vocabulary (V=1024), these two objectives produce nearly identical gradient signals. The JEPA loss ends up being a strictly less informative version of what CE already provides. In vision, JEPA helps because pixel-level prediction is wasteful and latent prediction captures semantic structure that raw reconstruction misses. That asymmetry doesn't exist here: tokens are already semantic units.
The Journey
Energy-Based Output Heads (didn't work)
I started by replacing softmax with energy-based scoring in a learned latent space (CLIP-style cosine similarity with learnable temperature). Tried VICReg on the energy head output, deterministic sharpening Langevin at eval time. All produced identical BPB to standard softmax. Softmax is already an energy model with E(v) = -logit(v).
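The equivalence is easy to see in code: a cosine-similarity head with a temperature just produces another set of logits, and the Boltzmann distribution over energies E(v) = -logit(v) is exactly softmax over those logits. A NumPy sketch (function names and the temperature value are illustrative):

```python
import numpy as np

def softmax(x):
    x = x - x.max()
    e = np.exp(x)
    return e / e.sum()

def energy_head_probs(h, vocab_emb, temperature=0.07):
    """CLIP-style scoring: cosine similarity between the hidden state and
    each vocab embedding, scaled by temperature, read as energies
    E(v) = -sim(h, v) / temperature, then exp(-E) / Z."""
    h_n = h / np.linalg.norm(h)
    v_n = vocab_emb / np.linalg.norm(vocab_emb, axis=1, keepdims=True)
    energy = -(v_n @ h_n) / temperature   # E(v) = -logit(v)
    return softmax(-energy)               # exp(-E) / Z == softmax(logits)

rng = np.random.default_rng(0)
h = rng.standard_normal(16)
W = rng.standard_normal((32, 16))
p = energy_head_probs(h, W)
# Identical to plain softmax over the cosine-similarity logits:
logits = (W / np.linalg.norm(W, axis=1, keepdims=True)) @ (h / np.linalg.norm(h)) / 0.07
```

Since the two distributions are identical, training against either gives the same gradients, which is consistent with the identical BPB observed.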
Real JEPA with Target Encoder
Pivoted to actual JEPA with an EMA target encoder. Key things I learned along the way:
The Quantization Bug
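In miniature, the failure mode described in this section looks like the following. This is a hypothetical reconstruction, not the PR's actual `_cls` code; parameter names are illustrative:

```python
def classify_dtype(param_name):
    """Route a parameter to a quantization dtype by name matching.

    Buggy matcher knew only the long-form names, so short-form names
    fell through to int8:
        if ".attn." in param_name or ".mlp." in param_name:
            return "int6"
    Fixed matcher handles both naming schemes.
    """
    if any(tag in param_name for tag in (".attn.", ".mlp.", ".a.", ".m.")):
        return "int6"
    return "int8"
```

Substring-based routing like this fails silently: misrouted tensors still serialize fine, and the only symptom is an artifact much larger than expected.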
First 20L run reported 1.15 BPB but the artifact was 39.7MB. The `_cls` function checked for `.attn.` and `.mlp.`, but unbanked names used `.a.` and `.m.`. Everything fell through to int8 instead of int6. Trained weights at int6 compress to ~0.34 bytes/param with LZMA, not the 0.16 I estimated from a random-model test.

Other Takeaways
JEPA weight and EMA decay are tightly coupled. High weight + fast EMA means the predictor can't track the target, causing rising loss and gradient competition. Low weight + slow EMA gives a stable but useless auxiliary signal.
Don't test compression ratios on random models. Trained weights have much higher entropy than randomly initialized ones. My random-model test showed 7.8MB for 48M params; the real artifact was 28MB.
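The effect is easy to demonstrate: with a fixed quantization step, a narrow random-init-like distribution uses only a handful of codes and compresses hard, while a wider trained-like distribution spreads over most of the 6-bit range. This sketch uses an illustrative fixed step, not the PR's actual quantization scheme, so the exact bytes/param differ from the 0.16 vs 0.34 figures above:

```python
import lzma
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
step = 0.01  # illustrative fixed quantization step

def quantize_int6(x, step):
    # Round to a signed 6-bit grid, stored one code per byte for simplicity.
    return np.clip(np.round(x / step), -32, 31).astype(np.int8)

# Random-init-like: small std -> few distinct codes -> LZMA compresses well.
init_like = quantize_int6(rng.normal(0.0, 0.02, n), step)
# Trained-like: wider spread -> most of the 6-bit range -> higher entropy.
trained_like = quantize_int6(rng.normal(0.0, 0.10, n), step)

bpp_init = len(lzma.compress(init_like.tobytes())) / n
bpp_trained = len(lzma.compress(trained_like.tobytes())) / n
# bpp_trained comes out well above bpp_init: compressed size tracks the
# entropy of the code distribution, not the nominal bit width.
```

This is why a random-model compression test underestimates the real artifact size: the code distribution of trained weights is simply wider.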
Concurrent Work
PR #832 independently explores JEPA for language modeling with a byte-level transformer. My approach uses a full EMA target encoder as the backbone rather than chunk-level prediction on top of a standard transformer.
Reproduction
Test Plan
On chunk-level prediction: An alternative JEPA formulation predicts representations of token chunks rather than individual next tokens, which is closer to how JEPA operates on image patches. PR #832 explores this direction and shows a consistent 0.01 BPB gain. I chose token-level prediction to test whether full JEPA self-distillation (with a moving target encoder) adds value at the most basic level. It doesn't, which suggests the marginal gains from chunk-level prediction come from the chunk aggregation itself rather than the JEPA framework.