
Shubham-Rasal/grokking-mlx


grokking-mlx

Replicating the grokking phenomenon in a tiny Transformer, built with MLX for Apple Silicon.

Grokking is the observation that a neural network can suddenly generalise — long after it has already memorised the training set. This repo reproduces it on modular arithmetic tasks in under a minute on an M-series Mac.


Textbook curve — (a ÷ b) mod 97

Train accuracy hits 100% by epoch ~30, then sits there for 500 epochs while val accuracy flatlines near 0%. Then, suddenly, the model groks the underlying structure and val accuracy jumps to ~100%.
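For reference, here is a minimal sketch of how the (a ÷ b) mod 97 task could be built. It assumes division means multiplying by the modular inverse of b (so b = 0 is excluded) and uses hypothetical token ids: 0 to p-1 for numbers, p for the operator and p+1 for `=`. These ids and the helper name `make_division_dataset` are illustrative, not taken from the repo.

```python
def make_division_dataset(p: int = 97) -> list[tuple[list[int], int]]:
    """Return ([a, op, b, =], answer) pairs for (a ÷ b) mod p, with b != 0."""
    op_id = p       # hypothetical id for the "÷" token
    eq_id = p + 1   # hypothetical id for the "=" token
    examples = []
    for a in range(p):
        for b in range(1, p):  # b = 0 has no inverse mod p, so it is skipped
            # pow(b, -1, p) is the modular inverse of b (Python 3.8+)
            answer = (a * pow(b, -1, p)) % p
            examples.append(([a, op_id, b, eq_id], answer))
    return examples


data = make_division_dataset()
print(len(data))  # 9312, i.e. 97 * 96 pairs
```

Because p is prime, every nonzero b has an inverse, so each answer c satisfies c · b ≡ a (mod p).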

Division grokking curve


Fast grokking — (a + b) mod 97

Addition is a structurally simpler task. The memorisation and generalisation phases are both shorter, but the sudden jump is still clearly visible.

Addition grokking curve


What you're seeing

| Phase | Train acc | Val acc |
|---|---|---|
| Memorisation | ~100% | ~0% |
| Long plateau | ~100% | ~0% |
| Grokking | 100% | jumps to ~100% |

Training stops automatically as soon as grokking is detected, i.e. when val accuracy reaches the threshold set by --grok-threshold (default 0.95).
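A sketch of the stopping rule, assuming it simply compares each epoch's val accuracy against the threshold. The function name `first_grok_epoch` and the accuracy trace are hypothetical, and the real training script may check differently (for example, only every few epochs).

```python
from typing import Optional


def first_grok_epoch(val_accuracies: list[float],
                     grok_threshold: float = 0.95) -> Optional[int]:
    """Return the first epoch whose val accuracy reaches the threshold, else None."""
    for epoch, val_acc in enumerate(val_accuracies):
        if val_acc >= grok_threshold:
            return epoch  # a training loop would stop here
    return None


# Illustrative trace: a plateau near 0%, then a sudden jump.
history = [0.01, 0.02, 0.01, 0.03, 0.40, 0.97, 0.99]
print(first_grok_epoch(history))  # 5
```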

Why the sawtooth spikes?

The periodic dips in train accuracy during the memorisation phase are caused by weight decay continuously eroding the memorisation weights until they break, followed by rapid re-memorisation. This isn't a bug; it's the engine behind grokking. Throughout each erosion cycle, weight decay also quietly compresses the network towards a more weight-efficient generalising circuit that is building up in the background. When that circuit becomes dominant, grokking fires and the spikes stop.
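To see the erosion mechanism in isolation, here is a scalar sketch of a textbook AdamW step with decoupled weight decay, where the weight is scaled by `1 - lr * weight_decay` alongside the Adam update. With a zero gradient the Adam term vanishes and the weight shrinks geometrically; with the default `--lr 1e-3` and `--weight-decay 1.0` that is a factor of 0.999 per step. This is an illustration that ignores any learning-rate schedule, not the repo's optimizer code.

```python
import math


def adamw_step(w, grad, m, v, t, lr=1e-3, weight_decay=1.0,
               beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar AdamW step (with bias correction); returns updated (w, m, v)."""
    w = w * (1 - lr * weight_decay)          # decoupled weight decay
    m = beta1 * m + (1 - beta1) * grad       # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2  # second-moment estimate
    m_hat = m / (1 - beta1 ** t)             # bias-corrected moments
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


# A weight whose gradient is zero decays towards zero under weight decay alone.
w, m, v = 1.0, 0.0, 0.0
for t in range(1, 1001):
    w, m, v = adamw_step(w, 0.0, m, v, t)
print(round(w, 4))  # 0.3677, i.e. 0.999 ** 1000
```

In real training the loss gradient pushes back on weights the model still needs, so only weights that are not worth their norm keep shrinking.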


Quickstart

pip install -r requirements.txt
python main.py

Default: (a ÷ b) mod 97, 50% train split, stops when val acc ≥ 95%.
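A minimal sketch of the train/val split, assuming it is a seeded random shuffle of all (a, b) pairs cut at the train fraction. The seed and the helper name `split_pairs` are hypothetical. Everything outside the training set forms the validation set, so val accuracy measures performance on pairs the model has never seen.

```python
import random


def split_pairs(pairs, train_fraction=0.5, seed=0):
    """Shuffle a copy of the pairs and cut it at train_fraction."""
    shuffled = list(pairs)
    random.Random(seed).shuffle(shuffled)  # seeded, so the split is reproducible
    n_train = int(train_fraction * len(shuffled))
    return shuffled[:n_train], shuffled[n_train:]


p = 97
pairs = [(a, b) for a in range(p) for b in range(p)]  # e.g. all inputs for (a + b) mod p
train, val = split_pairs(pairs, train_fraction=0.5)
print(len(train), len(val))  # 4704 4705
```

Lowering the fraction (as in the textbook command with 0.3) leaves fewer examples to memorise from and, per the curves above, stretches the memorisation gap.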

Textbook curve (large memorisation gap)

python main.py --op / --weight-decay 0.5 --train-fraction 0.3 --epochs 2000

Fast addition grokking

python main.py --op + --weight-decay 2.0 --epochs 1000

All options

| Flag | Default | Description |
|---|---|---|
| `--op` | `/` | Operation: `+` `-` `*` `/` |
| `--p` | `97` | Prime modulus |
| `--train-fraction` | `0.5` | Fraction of pairs used for training |
| `--d-model` | `128` | Embedding dimension |
| `--n-heads` | `1` | Attention heads |
| `--n-layers` | `2` | Transformer layers |
| `--epochs` | `500` | Max epochs (early-stops on grokking) |
| `--lr` | `1e-3` | Peak learning rate |
| `--weight-decay` | `1.0` | AdamW weight decay (key driver of grokking) |
| `--grok-threshold` | `0.95` | Val accuracy that triggers early stop |
| `--save-plot` | `grokking.png` | Output plot path |
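The options table maps directly onto a standard `argparse` parser. The sketch below mirrors the flags and defaults listed above; it is a hypothetical reconstruction for illustration, not the repo's actual `main.py`, and restricting `--op` via `choices` is an assumption.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags and defaults."""
    parser = argparse.ArgumentParser(description="Grokking on modular arithmetic")
    parser.add_argument("--op", default="/", choices=["+", "-", "*", "/"],
                        help="Operation")
    parser.add_argument("--p", type=int, default=97, help="Prime modulus")
    parser.add_argument("--train-fraction", type=float, default=0.5,
                        help="Fraction of pairs used for training")
    parser.add_argument("--d-model", type=int, default=128, help="Embedding dimension")
    parser.add_argument("--n-heads", type=int, default=1, help="Attention heads")
    parser.add_argument("--n-layers", type=int, default=2, help="Transformer layers")
    parser.add_argument("--epochs", type=int, default=500,
                        help="Max epochs (early-stops on grokking)")
    parser.add_argument("--lr", type=float, default=1e-3, help="Peak learning rate")
    parser.add_argument("--weight-decay", type=float, default=1.0,
                        help="AdamW weight decay")
    parser.add_argument("--grok-threshold", type=float, default=0.95,
                        help="Val accuracy that triggers early stop")
    parser.add_argument("--save-plot", default="grokking.png", help="Output plot path")
    return parser


# Equivalent of: python main.py --op + --weight-decay 2.0 --epochs 1000
args = build_parser().parse_args(["--op", "+", "--weight-decay", "2.0", "--epochs", "1000"])
print(args.op, args.weight_decay, args.epochs, args.p)  # + 2.0 1000 97
```

Note that `argparse` turns dashed flags into underscored attributes (`--weight-decay` becomes `args.weight_decay`), and that `*` usually needs quoting in a shell, e.g. `--op '*'`.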

Model

A decoder-only Transformer (~419K params) with RoPE, RMSNorm, and SiLU FFN. Input sequence: [a, op, b, =]. Prediction: the answer token at the = position.
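As a sanity check on the ~419K figure, here is the parameter arithmetic under assumptions this README does not state: a vocabulary of p + 2 = 99 tokens (numbers, one operator token, `=`), bias-free linear layers, a non-gated SiLU FFN with hidden size 4 × d_model, an untied output head, two RMSNorm scales per layer plus a final norm, and fused Q/K/V/O projections of size d_model × d_model (so `--n-heads` does not change the count). RoPE adds no parameters. Under these assumptions the total is 419,200, consistent with ~419K, though other configurations could land nearby.

```python
def count_params(p=97, d_model=128, n_layers=2, ffn_mult=4):
    """Parameter count under the assumptions stated above (hypothetical layout)."""
    vocab = p + 2                              # assumed: numbers + one op token + "="
    embed = vocab * d_model                    # token embedding table
    attn = 4 * d_model * d_model               # Q, K, V, O projections, no biases
    ffn = 2 * d_model * (ffn_mult * d_model)   # up and down projections, no biases
    norms = 2 * d_model                        # two RMSNorm scale vectors per layer
    per_layer = attn + ffn + norms
    final_norm = d_model                       # RMSNorm before the output head
    head = d_model * vocab                     # untied output projection
    return embed + n_layers * per_layer + final_norm + head


print(count_params())  # 419200
```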

Runs at ~16 epochs/sec on Apple Silicon.


