Edge Align is a Rust CLI that tests how f32, f16, and simulated INT8 precision affect attention behavior in Pythia-70M. It identifies which attention heads matter most for next-token prediction, which heads degrade under lower precision, and whether restoring selected heads to full precision improves performance.

Requirements:
- Rust toolchain with Cargo
- Python 3 with numpy, matplotlib, and scipy
- Enough disk space for the Pythia-70M Hugging Face files
- Optional CUDA build support for GPU execution

Build:

    cargo build --release

For CUDA:

    cargo build --release --features cuda

Run:

    cargo run --release

With explicit inputs:

    cargo run --release -- --prompts prompts.json --output results/experiment.json --device auto

Use --device cpu or --device cuda to force a device. Use --max-repair-k to set the length of the repair curve, that is, the largest k evaluated (for example, --max-repair-k 10).
By default the runner also creates a simulated INT8 comparison path. This is not native INT8 weight inference: the Pythia safetensors checkpoint is loaded into f32 tensors, and selected activations are dynamically quantized to signed int8 levels and immediately dequantized during the forward pass. Disable it with --disable-int8. To also run circuit repair on the simulated INT8 model:

    cargo run --release -- --repair-int8

To plot the results:

    python scripts/plot.py

The script reads results/experiment.json and writes:

- fig1_importance_vs_drift.png
- fig2_zero_vs_mean_ablation.png
- fig3_task_specificity.png
- fig4_circuit_repair.png
- fig5_task_degradation.png
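
The simulated INT8 path described earlier (quantize selected activations, then immediately dequantize them) can be illustrated with a small fake-quantization helper. This is a minimal sketch, assuming symmetric per-tensor scaling onto the levels [-127, 127]; fake_quant_int8 is a hypothetical name, and the CLI's actual granularity (per tensor, per row, or per token) and rounding rules may differ.

```rust
/// Hypothetical helper: symmetric, per-tensor dynamic INT8 fake quantization.
/// "Dynamic" means the scale is recomputed from each activation tensor at run
/// time instead of being calibrated ahead of time (an assumption for this sketch).
fn fake_quant_int8(x: &[f32]) -> Vec<f32> {
    // The largest magnitude in this tensor sets the scale.
    let max_abs = x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    if max_abs == 0.0 {
        // An all-zero tensor needs no quantization (and would give scale 0).
        return x.to_vec();
    }
    // Map [-max_abs, max_abs] onto the symmetric signed levels [-127, 127].
    let scale = max_abs / 127.0;
    x.iter()
        .map(|&v| {
            // Quantize: round to the nearest level and clamp into the int8 range.
            let q = (v / scale).round().clamp(-127.0, 127.0) as i8;
            // Dequantize immediately, so downstream ops still run in f32.
            q as f32 * scale
        })
        .collect()
}
```

Because each value is snapped to one of 255 levels and mapped straight back, the per-element error is at most half a quantization step (scale / 2); that rounding error is what an f32-vs-INT8 comparison would pick up.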

The CLI:

- validates factual-recall and indirect-object-identification prompts against the f32 Pythia-70M run,
- computes reference mean activations,
- measures head importance with zero-ablation and mean-ablation,
- measures f32/f16 head drift, and
- runs top-k circuit repair by injecting f32 attention-head activations into the f16 forward pass.
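
Zero-ablation and mean-ablation differ only in what replaces a head's output, and top-k repair needs a ranking of heads. The sketch below illustrates both, assuming a hypothetical per-head activation layout ([n_heads][dim] for one token) and a hypothetical importance score (the drop in clean logit difference when a head is ablated). Ablation, mean_activations, ablate_head, importance, and top_k_heads are illustrative names, not the CLI's API, and the score used to rank heads for repair is also an assumption.

```rust
/// How a head's activation is replaced during ablation.
#[derive(Clone, Copy, Debug)]
enum Ablation {
    /// Replace the head's output with zeros.
    Zero,
    /// Replace the head's output with its mean over reference prompts.
    Mean,
}

/// Mean activation per head over reference samples.
/// Hypothetical layout: each sample is [n_heads][dim]. Panics if `samples` is empty.
fn mean_activations(samples: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
    let n = samples.len() as f32;
    let mut mean: Vec<Vec<f32>> = samples[0].iter().map(|h| vec![0.0; h.len()]).collect();
    for sample in samples {
        for (mean_head, head) in mean.iter_mut().zip(sample) {
            for (m, v) in mean_head.iter_mut().zip(head) {
                *m += v / n;
            }
        }
    }
    mean
}

/// Overwrite one head's activation in place (same hypothetical layout).
fn ablate_head(acts: &mut [Vec<f32>], reference_mean: &[Vec<f32>], head: usize, mode: Ablation) {
    match mode {
        Ablation::Zero => acts[head].iter_mut().for_each(|v| *v = 0.0),
        Ablation::Mean => acts[head].copy_from_slice(&reference_mean[head]),
    }
}

/// Hypothetical importance score: how much ablating a head lowers the clean logit difference.
fn importance(clean_logit_diff: f32, ablated_logit_diff: f32) -> f32 {
    clean_logit_diff - ablated_logit_diff
}

/// Indices of the k highest-scoring heads, e.g. the heads restored to f32 during repair.
fn top_k_heads(scores: &[f32], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..scores.len()).collect();
    // Sort descending; total_cmp gives a total order even if a score is NaN.
    idx.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]));
    idx.truncate(k);
    idx
}
```

Zero-ablation can overstate a head's importance because an all-zero output may be far from anything the model normally sees; mean-ablation removes prompt-specific information while keeping the head's typical output, which is why comparing the two is informative.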
The activation-patching hook is inserted after each attention output projection and before the residual addition. Pythia uses parallel residual blocks (x + attn(ln1(x)) + mlp(ln2(x))), so the MLP reads the layer-normalized input to the block rather than the patched attention output of the same layer. A patched head therefore affects MLP computations only from the next layer onward.
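
A toy single-vector version of a parallel-residual block shows why patching does not reach the same layer's MLP. This is a minimal sketch under simplifying assumptions: LayerNorm has no learned parameters and is shared by both branches, attention and MLP are arbitrary closures, and the hook closure is a hypothetical stand-in for the CLI's patching hook.

```rust
/// Stand-in LayerNorm without learned gain or bias (a simplification).
fn layer_norm(x: &[f32]) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let denom = (var + 1e-5).sqrt();
    x.iter().map(|v| (v - mean) / denom).collect()
}

/// One parallel-residual block: out = x + hook(attn(ln(x))) + mlp(ln(x)).
fn parallel_block(
    x: &[f32],
    attn: impl Fn(&[f32]) -> Vec<f32>,
    mlp: impl Fn(&[f32]) -> Vec<f32>,
    hook: impl Fn(Vec<f32>) -> Vec<f32>,
) -> Vec<f32> {
    // Both branches read the normalized block input (the real model uses two LayerNorms).
    let normed = layer_norm(x);
    // Hook point: after the attention output projection, before the residual add.
    let attn_out = hook(attn(normed.as_slice()));
    // The MLP never sees the (possibly patched) attention output of this layer.
    let mlp_out = mlp(normed.as_slice());
    x.iter()
        .zip(&attn_out)
        .zip(&mlp_out)
        .map(|((a, b), c)| a + b + c)
        .collect()
}
```

Running the block once with an identity hook and once with a zeroing hook changes the output by exactly the attention contribution, while the MLP term is identical in both runs.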
The public EleutherAI/pythia-70m checkpoint is distributed as half-precision safetensors. The f32 run upcasts those weights to f32 tensors; upcasting is exact, so the f32 and f16 runs start from the same weight values. The comparison therefore isolates the effect of activation arithmetic precision available from the public checkpoint; it does not recover the original f32 training weights, which are not available.
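
That upcasting half-precision weights to f32 loses nothing can be checked directly: every binary16 value has an exact f32 representation. The decoder below is a minimal sketch, assuming the checkpoint's half-precision format is IEEE binary16 (f16); f16_bits_to_f32 is a hypothetical helper, and Rust code would normally get this conversion from a crate such as half (f16::to_f32).

```rust
/// Hypothetical helper: decode an IEEE 754 binary16 bit pattern into f32.
/// Layout: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = if bits & 0x8000 != 0 { -1.0f32 } else { 1.0 };
    let exp = ((bits >> 10) & 0x1F) as i32;
    let mant = (bits & 0x03FF) as f32;
    match exp {
        // Zero or subnormal: 2^-14 * (mant / 1024) = mant * 2^-24.
        0 => sign * mant * 2f32.powi(-24),
        // All-ones exponent: infinity if the mantissa is zero, otherwise NaN.
        31 => {
            if mant == 0.0 {
                sign * f32::INFINITY
            } else {
                f32::NAN
            }
        }
        // Normal: 2^(exp - 15) * (1 + mant / 1024). Each factor is exact in f32.
        _ => sign * (1.0 + mant / 1024.0) * 2f32.powi(exp - 15),
    }
}
```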
The JSON schema keeps the original f32/f16 fields and adds optional INT8 fields: drift_f32_int8, cosine_sim_f32_int8, clean_logit_diff_int8, and top-level repair_int8. Existing plots still read the original f32/f16 fields.
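
The INT8 fields sit alongside the f32/f16 fields as optional values. Below is a minimal sketch of how a per-head record could be filled; it assumes drift is a relative L2 distance and uses placeholder names for the f32/f16 fields, which this README does not spell out. HeadMetrics, relative_l2_drift, and head_metrics are hypothetical, and clean_logit_diff_int8 and the top-level repair_int8 are omitted.

```rust
/// Hypothetical per-head record. Only the *_int8 field names come from the README;
/// drift_f32_f16 and cosine_sim_f32_f16 are placeholder names for the original fields.
#[derive(Debug)]
struct HeadMetrics {
    drift_f32_f16: f32,
    cosine_sim_f32_f16: f32,
    // None when the INT8 path is disabled (assumed to map to an absent JSON field).
    drift_f32_int8: Option<f32>,
    cosine_sim_f32_int8: Option<f32>,
}

fn l2_norm(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}

/// Cosine similarity; returns 0.0 if either vector is all zeros (a convention chosen here).
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let (na, nb) = (l2_norm(a), l2_norm(b));
    if na == 0.0 || nb == 0.0 {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (na * nb)
}

/// Hypothetical drift definition: ||reference - other|| / ||reference||.
fn relative_l2_drift(reference: &[f32], other: &[f32]) -> f32 {
    let diff: Vec<f32> = reference.iter().zip(other).map(|(r, o)| r - o).collect();
    let denom = l2_norm(reference);
    if denom == 0.0 {
        l2_norm(&diff)
    } else {
        l2_norm(&diff) / denom
    }
}

/// Fill the record; INT8 metrics are computed only when an INT8 activation is provided.
fn head_metrics(f32_act: &[f32], f16_act: &[f32], int8_act: Option<&[f32]>) -> HeadMetrics {
    HeadMetrics {
        drift_f32_f16: relative_l2_drift(f32_act, f16_act),
        cosine_sim_f32_f16: cosine_sim(f32_act, f16_act),
        drift_f32_int8: int8_act.map(|b| relative_l2_drift(f32_act, b)),
        cosine_sim_f32_int8: int8_act.map(|b| cosine_sim(f32_act, b)),
    }
}
```

Keeping the INT8 values as Option is one way to leave the original f32/f16 fields unchanged while letting the INT8 ones be absent, presumably the case when a run uses --disable-int8.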