Replicating the grokking phenomenon in a tiny Transformer, built with MLX for Apple Silicon.
Grokking is the observation that a neural network can suddenly generalise — long after it has already memorised the training set. This repo reproduces it on modular arithmetic tasks in under a minute on an M-series Mac.
Train accuracy hits 100% by epoch ~30, then stays there for 500 epochs while val accuracy flatlines near 0%. Then, suddenly, the model groks the underlying structure and val accuracy jumps to ~100%.
Addition is a structurally simpler task. The memorisation and generalisation phases are both shorter, but the sudden jump is still clearly visible.
| Phase | Train acc | Val acc |
|---|---|---|
| Memorisation | ~100% | ~0% |
| ← long plateau → | ~100% | ~0% |
| Grokking | 100% | jumps to ~100% |
Training stops automatically as soon as val accuracy reaches the grokking threshold (configurable via --grok-threshold, default 0.95).
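The early-stop rule amounts to comparing each epoch's val accuracy against the threshold. Below is a minimal sketch, assuming a plain `>=` check after every epoch; `should_stop` and `train_until_grok` are hypothetical names rather than functions from main.py, and the list of accuracies stands in for real evaluation.

```python
def should_stop(val_acc: float, grok_threshold: float = 0.95) -> bool:
    """Return True once validation accuracy reaches the grokking threshold."""
    return val_acc >= grok_threshold


def train_until_grok(val_accs, grok_threshold=0.95, max_epochs=500):
    """Walk per-epoch val accuracies and report the epoch training would stop at.

    `val_accs` stands in for evaluating the model after each epoch.
    Returns the 1-based stopping epoch, or None if the threshold is never
    reached within `max_epochs` (the run simply ends at the epoch cap).
    """
    for epoch, acc in enumerate(val_accs, start=1):
        if epoch > max_epochs:
            break
        if should_stop(acc, grok_threshold):
            return epoch
    return None


# Accuracy crosses 0.95 on the 4th epoch, so training would stop there.
print(train_until_grok([0.01, 0.02, 0.5, 0.96, 0.99]))
```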
The periodic dips in train accuracy during the memorisation plateau are caused by weight decay continuously eroding the memorisation weights until the memorised solution breaks, followed by rapid re-memorisation. This isn't a bug; it's the engine behind grokking. During each erosion cycle, weight decay also pushes the weights toward a more weight-efficient generalising circuit that is building in the background. Once that circuit becomes dominant, grokking occurs and the dips stop.
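The erosion described above comes from decoupled weight decay: AdamW subtracts `lr * weight_decay * w` from each weight on every step, independently of the loss gradient. The sketch below is the textbook AdamW update for a single scalar weight, written for illustration only; framework implementations (including MLX's optimizer) may differ in details such as bias correction, and `adamw_step` and `decay_only` are hypothetical helpers, not code from this repo.

```python
import math


def adamw_step(w, g, m, v, t, lr=1e-3, wd=1.0, beta1=0.9, beta2=0.999, eps=1e-8):
    """One textbook AdamW update for a scalar weight `w` with gradient `g`.

    `t` is the 1-based step count used for bias correction.
    Returns the updated (w, m, v).
    """
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Decoupled decay: shrink the weight directly instead of adding wd * w to g.
    w = w - lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * w)
    return w, m, v


def decay_only(w0, steps, lr=1e-3, wd=1.0):
    """Weight after `steps` updates with zero loss gradient, i.e. pure shrinkage."""
    w, m, v = w0, 0.0, 0.0
    for t in range(1, steps + 1):
        w, m, v = adamw_step(w, 0.0, m, v, t, lr=lr, wd=wd)
    return w


print(decay_only(1.0, 1000))  # equals 0.999 ** 1000 up to rounding
```

With lr = 1e-3 and the default --weight-decay 1.0, a weight that receives no gradient shrinks by a factor of 0.999 per optimiser step, to about 37% of its value after 1,000 steps. Memorisation weights therefore have to be actively maintained by the loss gradient, which is what makes a smaller-norm generalising solution attractive.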
```bash
pip install -r requirements.txt
```

```bash
python main.py
```

Default: (a ÷ b) mod 97, 50% train split, stops when val acc ≥ 95%.

```bash
python main.py --op / --weight-decay 0.5 --train-fraction 0.3 --epochs 2000
python main.py --op + --weight-decay 2.0 --epochs 1000
```

| Flag | Default | Description |
|---|---|---|
| --op | / | Operation: + - * / |
| --p | 97 | Prime modulus |
| --train-fraction | 0.5 | Fraction of pairs used for training |
| --d-model | 128 | Embedding dimension |
| --n-heads | 1 | Number of attention heads |
| --n-layers | 2 | Number of Transformer layers |
| --epochs | 500 | Max epochs (early-stops on grokking) |
| --lr | 1e-3 | Peak learning rate |
| --weight-decay | 1.0 | AdamW weight decay (key driver of grokking) |
| --grok-threshold | 0.95 | Val accuracy that triggers early stop |
| --save-plot | grokking.png | Output plot path |
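The --op, --p and --train-fraction flags together define the dataset. The sketch below shows one way to build it, assuming every pair (a, b) with a, b in 0..p-1 is enumerated, b = 0 is dropped for division, and the split is a seeded random shuffle. Division mod a prime is multiplication by the modular inverse, computed here with Python's built-in `pow(b, -1, p)` (Python 3.8+). `modular_dataset` is a hypothetical helper, not the repo's own loader.

```python
import random


def modular_dataset(op="/", p=97, train_fraction=0.5, seed=0):
    """Build all (a, b, c) examples for `a op b = c (mod p)` and split them.

    Division is a * b^(-1) mod p, so b = 0 is excluded for "/".
    Returns (train, val) lists of (a, b, c) tuples.
    """
    ops = {
        "+": lambda a, b: (a + b) % p,
        "-": lambda a, b: (a - b) % p,
        "*": lambda a, b: (a * b) % p,
        # Modular inverse of b; exists for every b != 0 because p is prime.
        "/": lambda a, b: (a * pow(b, -1, p)) % p,
    }
    fn = ops[op]
    b_values = range(1, p) if op == "/" else range(p)
    examples = [(a, b, fn(a, b)) for a in range(p) for b in b_values]
    random.Random(seed).shuffle(examples)  # deterministic split for a given seed
    n_train = int(train_fraction * len(examples))
    return examples[:n_train], examples[n_train:]


train, val = modular_dataset()
print(len(train), len(val))
```

Under these assumptions the default division task has 97 × 96 = 9,312 examples, of which 4,656 go to training.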
A decoder-only Transformer (~419K params) with RoPE, RMSNorm, and a SiLU FFN. The input sequence is [a, op, b, =], and the model predicts the answer token at the = position.
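The ~419K figure can be roughly reproduced from the default flags. The count below is a back-of-the-envelope sketch under assumptions the README does not state: bias-free linear layers, a plain (non-gated) SiLU FFN of width 4 × d_model, one RMSNorm gain vector before attention, one before the FFN and one at the end, untied input and output embeddings, and a 102-token vocabulary (97 residues, 4 operator tokens and =). RoPE adds no learned parameters, and --n-heads does not change the count because heads split d_model. `transformer_param_count` is a hypothetical helper.

```python
def transformer_param_count(vocab, d_model=128, n_layers=2, ffn_mult=4,
                            tied_embeddings=False):
    """Rough parameter count for a bias-free decoder-only Transformer."""
    embed = vocab * d_model
    head = 0 if tied_embeddings else d_model * vocab  # output projection
    attn = 4 * d_model * d_model                      # Q, K, V, output projections
    ffn = 2 * d_model * (ffn_mult * d_model)          # up and down projections
    norms = 2 * d_model                               # two RMSNorm gains per layer
    per_layer = attn + ffn + norms
    final_norm = d_model
    return embed + head + n_layers * per_layer + final_norm


print(transformer_param_count(vocab=102))
```

Under these assumptions the total comes to 419,968, close to the ~419K quoted; the exact number shifts slightly with vocabulary size and with whether the embeddings are tied.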
Runs at ~16 epochs/sec on Apple Silicon.

