juchengshen/air


AIR β€” Asymmetric Input Recurrence

One model, two roles: emergent specialization in a shared recurrent Transformer


"One Model, Two Roles: Emergent Specialization in a Shared Recurrent Transformer"
Jucheng Shen*, Wenyi Su*, Anastasios Kyrillidis
arXiv:2605.17811

Can a shared-weight recurrent Transformer develop two distinct internal roles without being partitioned into separate modules? AIR (Asymmetric Input Recurrence) is a minimal two-state reasoning architecture in which the same Transformer block is reused for both updates, and the only built-in asymmetry is that the encoded input is injected during L-updates but not H-updates:

$$ \textbf{L-update:}\ \ \mathbf{z}_L \leftarrow f\big(\mathbf{z}_L + \mathbf{z}_H + \tilde{\mathbf{x}};\,\theta\big), \qquad \textbf{H-update:}\ \ \mathbf{z}_H \leftarrow f\big(\mathbf{z}_H + \mathbf{z}_L;\,\theta\big). $$
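The two updates above can be sketched in a few lines of plain Python. This is an illustrative sketch, not the repository's implementation: states are lists of floats rather than tensors, `toy_block` is a hypothetical stand-in for the shared block $f(\cdot;\theta)$, and the `l_steps` parameter (how many L-updates run before each H-update) is an assumption, since this README does not state the schedule.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def add(*vs: Vector) -> Vector:
    """Element-wise sum of equally sized vectors."""
    return [sum(xs) for xs in zip(*vs)]


def air_cycle(
    z_l: Vector,
    z_h: Vector,
    x_tilde: Vector,
    block: Callable[[Vector], Vector],
    l_steps: int = 1,  # hypothetical: L-updates per H-update
) -> Tuple[Vector, Vector]:
    """One hypothetical AIR cycle: l_steps L-updates, then one H-update.

    Both updates call the SAME `block` (shared parameters theta); the only
    asymmetry is that x_tilde enters the L-update and not the H-update.
    """
    for _ in range(l_steps):
        z_l = block(add(z_l, z_h, x_tilde))  # L-update: encoded input injected
    z_h = block(add(z_h, z_l))               # H-update: no input
    return z_l, z_h


def toy_block(v: Vector) -> Vector:
    """Toy stand-in for f(.; theta): halves every coordinate."""
    return [0.5 * x for x in v]


z_l, z_h = air_cycle([0.0, 0.0], [0.0, 0.0], [1.0, 2.0], toy_block)
print(z_l, z_h)  # [0.5, 1.0] [0.25, 0.5]
```

Because one `block` object serves both calls, there is exactly one set of parameters. Adding `x_tilde` to the H-update as well would give both updates the same form, which is what the symmetric variants in the ablation do.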

Across Sudoku-Extreme and Maze-30×30, this single architectural detail causes the shared model to specialize: $\mathbf{z}_H$ behaves like a fully committed proposal state, while $\mathbf{z}_L$ acts as a shifting scratchpad. With half the Transformer parameters of the two-network HRM baseline, the best AIR variants match or exceed its accuracy on both tasks.

AIR architecture overview

✨ Key Features

  • Shared parameters, distinct roles: one Transformer block, two latent states with mechanistically different functional roles
  • Half the parameters of HRM: the best AIR variants match or exceed the two-network baseline (Wang et al. 2025) on Sudoku (60.0% vs 55.0%) and Maze (75.6% vs 74.5%)
  • Asymmetric input is the necessary signal: symmetric variants drop to group means of 51.5% on Sudoku and 69.9% on Maze; the ~7.6-point gap between group means on Sudoku is the cost of removing the asymmetry
  • Prepend-and-strip level token recovers most of the gap: adding a structurally separable state-identity token to the symmetric base lifts Sudoku from 50.9% to 57.5%
  • Mechanistic split in attention: at the deepest layer, L-updates concentrate ~47% more attention mass inside the constraint neighbourhood than H-updates; on Sudoku, deeper layers additionally route attention to violated cells
  • State coupling, not redundancy: freeze experiments collapse final accuracy to 0% on both tasks; the two states are load-bearing in a coupled feedback loop
  • Open-source reproduction: every figure, table, and ablation in the paper has a corresponding shell script under experiment_*/

What the two states look like


Sudoku. Left columns: zH (fully committed). Right columns: zL (some cells held as BLANK; the held-back set shifts across sub-steps).


Maze-30×30: zH commits to a full layout; zL holds regions as PAD and revises them locally as the rollout progresses.
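The held-back set described in these captions can be tracked with a little bookkeeping on the decoded grids. A minimal sketch, assuming each decoded state is already a flat list of per-cell tokens (the decoding step itself is not shown) and using hypothetical sentinel strings `BLANK` (Sudoku) and `PAD` (Maze); the function names are illustrative, not the repository's.

```python
from typing import List, Sequence, Set

BLANK = "BLANK"  # hypothetical token for an uncommitted Sudoku cell
PAD = "PAD"      # hypothetical token for an undecided Maze region


def held_back(decoded: Sequence[str], blank: str = BLANK) -> Set[int]:
    """Indices of cells that this decoded state leaves uncommitted."""
    return {i for i, tok in enumerate(decoded) if tok == blank}


def is_fully_committed(decoded: Sequence[str], blank: str = BLANK) -> bool:
    """True for a proposal-like state such as z_H: every cell has a value."""
    return not held_back(decoded, blank)


def held_back_shifts(rollout: Sequence[Sequence[str]], blank: str = BLANK) -> List[Set[int]]:
    """Cells entering or leaving the held-back set between consecutive sub-steps."""
    sets = [held_back(step, blank) for step in rollout]
    # Symmetric difference: cells that were held back before XOR are held back now.
    return [prev ^ cur for prev, cur in zip(sets, sets[1:])]
```

A shifting scratchpad shows up as non-empty sets from `held_back_shifts`, while a committed proposal state keeps `is_fully_committed` true at every sub-step.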

🚀 Quick Start

Prerequisites

Python 3.10+. Install Python dependencies:

pip install -r requirements.txt
unzip adam_atan2.zip          # bundles the AdamATan2 optimizer

If CUDA 12.6 and matching PyTorch wheels are not already installed, one tested setup is:

# CUDA 12.6
CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run
wget -O cuda_installer.run "$CUDA_URL"
sudo sh cuda_installer.run --silent --toolkit --override
export CUDA_HOME=/usr/local/cuda-12.6

# PyTorch with CUDA 12.6
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

# Build helpers for the CUDA extensions
pip install packaging ninja wheel setuptools setuptools-scm

Training scripts log to Weights & Biases:

export WANDB_API_KEY=your_wandb_api_key

Build the datasets

python dataset/build_sudoku_dataset.py \
  --output-dir data/sudoku-extreme-1k-aug-1000 \
  --subsample-size 1000 --num-aug 1000

python dataset/build_maze_dataset.py \
  --output-dir data/maze-30x30-hard-1k

If you already have prebuilt datasets, place them at data/sudoku-extreme-1k-aug-1000 and data/maze-30x30-hard-1k.
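The Sudoku build command keeps 1,000 base puzzles (`--subsample-size 1000`) and, judging by the flag name, creates 1,000 augmented copies of each (`--num-aug 1000`). The builder's exact transforms are not described in this README; as an assumption, the sketch below uses two standard validity-preserving Sudoku symmetries, digit relabelling and transposition. The function names are hypothetical.

```python
import random
from typing import List

Grid = List[List[int]]  # 9x9 grid, 0 = empty cell, 1..9 = digits


def augment_sudoku(grid: Grid, rng: random.Random) -> Grid:
    """Apply a random digit relabelling and an optional transpose.

    Both transforms map valid Sudoku grids to valid Sudoku grids and keep
    empty cells (0) empty. The input grid is not modified.
    """
    perm = list(range(1, 10))
    rng.shuffle(perm)
    relabel = {0: 0, **{d: perm[d - 1] for d in range(1, 10)}}
    out = [[relabel[v] for v in row] for row in grid]
    if rng.random() < 0.5:
        out = [list(col) for col in zip(*out)]  # transpose
    return out


def is_valid_solution(grid: Grid) -> bool:
    """Check that every row, column, and 3x3 box contains 1..9 exactly once."""
    target = set(range(1, 10))
    rows = all(set(r) == target for r in grid)
    cols = all(set(c) == target for c in zip(*grid))
    boxes = all(
        {grid[r][c] for r in range(br, br + 3) for c in range(bc, bc + 3)} == target
        for br in (0, 3, 6)
        for bc in (0, 3, 6)
    )
    return rows and cols and boxes
```

Applying the same transform to a puzzle and its solution keeps the pair consistent, which is why symmetries like these are a cheap way to multiply a small training set.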

Train the default AIR variant (Lx_H, the canonical $(n_L, n_H) = (1, 0)$ asymmetric model)

# Sudoku-Extreme
bash experiment_input-injection-specialization-sudoku/train_sudoku_Lx_H.sh

# Maze-30Γ—30
bash experiment_input-injection-specialization-maze/train_maze_Lx_H.sh

Reproduce the full injection-asymmetry ablation (all 8 variants × 5 seeds)

bash experiment_input-injection-specialization-sudoku/run_all_input_injection_specialization.sh
bash experiment_input-injection-specialization-maze/run_all_input_injection_specialization.sh

If you are not on a Slurm cluster, run individual train_*.sh scripts directly (each is a single-GPU launch). Variant naming follows the paper: L_Hx, Lx_H, L_H2x, L2x_H, Lx_H2x, L2x_Hx, Lx_Hx, L2x_H2x.
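The variant names encode how the input enters each update. Reading the names against the injection-asymmetry results table, a bare `L`/`H` means $n = 0$, `x` means $n = 1$, and `2x` means $n = 2$, and the table's $\Delta$ equals $|n_L - n_H|$. The parser below encodes that reading; it is a hypothetical helper, not part of the repository.

```python
import re
from typing import Tuple

# One name part: the state letter, optionally followed by 'x' or '2x'.
_PART = re.compile(r"^(L|H)(2?x)?$")
_COUNT = {None: 0, "x": 1, "2x": 2}


def parse_variant(name: str) -> Tuple[int, int, int]:
    """Map a variant name such as 'L2x_Hx' to (n_L, n_H, delta).

    Assumed reading: no suffix -> 0, 'x' -> 1, '2x' -> 2, and
    delta = |n_L - n_H|, which matches the values in the results table.
    """
    counts = {}
    for part in name.split("_"):
        m = _PART.match(part)
        if m is None:
            raise ValueError(f"unrecognised variant part: {part!r}")
        counts[m.group(1)] = _COUNT[m.group(2)]
    if set(counts) != {"L", "H"}:
        raise ValueError(f"expected one L part and one H part: {name!r}")
    n_l, n_h = counts["L"], counts["H"]
    return n_l, n_h, abs(n_l - n_h)


for variant in ["L_Hx", "Lx_H", "L_H2x", "L2x_H", "Lx_H2x", "L2x_Hx", "Lx_Hx", "L2x_H2x"]:
    print(variant, parse_variant(variant))
```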

Level-token control (Section 4 of the paper)

bash experiment_addition-prepend-strip-no-strip/run_all_addition_prepend_strip_no_strip.sh

Operator-form control (Appendix)

bash experiment_operator-form-control/run_all_operator_form_control.sh

πŸ“ Repository Structure

β”œβ”€β”€ pretrain.py                                # Training entry point (Hydra-configured)
β”œβ”€β”€ evaluate.py                                # Held-out evaluation
β”œβ”€β”€ puzzle_dataset.py                          # Puzzle-token dataset wrapper
β”œβ”€β”€ config/                                    # Default architecture + training configs
β”œβ”€β”€ dataset/                                   # Sudoku-Extreme + Maze-30Γ—30 builders
β”œβ”€β”€ models/
β”‚   └── air/                                   # AIR architecture variants (shared block, two states)
β”œβ”€β”€ experiment_input-injection-specialization-sudoku/  # 8 AIR variants Γ— 5 seeds on Sudoku
β”œβ”€β”€ experiment_input-injection-specialization-maze/    # 8 AIR variants Γ— 5 seeds on Maze
β”œβ”€β”€ experiment_addition-prepend-strip-no-strip/        # Level-token controls (addition / prepend-strip / prepend-no-strip)
β”œβ”€β”€ experiment_operator-form-control/                  # Linear, nonlinear, Hadamard, sign-flip input transforms
β”œβ”€β”€ experiment_visual-sudoku-decoded-freeze/           # Decoded zH/zL rollouts + freeze interventions (Sudoku)
β”œβ”€β”€ experiment_visual-maze-decoded-freeze/             # Decoded zH/zL rollouts + freeze interventions (Maze)
β”œβ”€β”€ experiment_attention-analysis-sudoku/              # Attention contrasts (bar charts + example heatmaps)
β”œβ”€β”€ experiment_attention-analysis-maze/                # Maze counterpart
β”œβ”€β”€ adam_atan2.zip                             # Bundled AdamATan2 optimizer; unzip before training
└── requirements.txt

πŸ’» Hardware

| Hardware | Use | Notes |
|---|---|---|
| 1 × NVIDIA H200 (141 GB) | Single Sudoku training run | ~4 hours per run |
| 8 × NVIDIA H200 (141 GB) | Single Maze training run | ~3 hours per run |
| Any Ampere or newer GPU | Inference / freeze / attention analysis | ≤ 16 GB of GPU memory is sufficient for batched eval |

Total compute reported in the paper (all sweeps + preliminary runs): ~500 GPU-hours.

πŸ“Š Results Highlights

Injection-asymmetry ablation: asymmetric variants match the two-network baseline at half the parameters

All numbers are mean ± standard deviation across 5 training seeds, evaluated on the full held-out test sets (422,786 Sudoku puzzles, 1,000 mazes). Bold marks the per-column winner.

| Variant | $(n_L, n_H)$ | $\Delta$ | Sudoku (%) | Maze (%) |
|---|---|---|---|---|
| *Asymmetric ($\Delta > 0$)* | | | | |
| L_Hx | (0, 1) | 1 | 58.7 ± 3.3 | 75.3 ± 3.2 |
| Lx_H (default) | (1, 0) | 1 | **60.0 ± 2.0** | 71.0 ± 6.3 |
| L_H2x | (0, 2) | 2 | 58.6 ± 1.9 | **75.6 ± 1.9** |
| L2x_H | (2, 0) | 2 | 59.1 ± 2.4 | 71.1 ± 6.4 |
| Lx_H2x | (1, 2) | 1 | 59.6 ± 0.9 | 70.9 ± 2.4 |
| L2x_Hx | (2, 1) | 1 | 58.6 ± 2.9 | 74.5 ± 1.6 |
| Group mean (asym) | - | - | 59.1 | 73.1 |
| *Symmetric ($\Delta = 0$)* | | | | |
| Lx_Hx | (1, 1) | 0 | 52.1 ± 1.6 | 69.4 ± 2.5 |
| L2x_H2x | (2, 2) | 0 | 50.9 ± 2.9 | 70.3 ± 4.2 |
| Group mean (sym) | - | - | 51.5 | 69.9 |
| *Two-network baseline* | | | | |
| HRM (Wang et al. 2025) | - | - | 55.0 | 74.5 |

Headline: the asymmetric group averages ~7.6 points higher than the symmetric group on Sudoku and ~3.2 points higher on Maze, and the best AIR variants match or exceed HRM with half the Transformer parameters.
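The group means and headline gaps can be recomputed from the per-variant means in the table. This check uses `Decimal` with half-up rounding so that ties such as the symmetric Maze mean (69.85) reproduce the table's one-decimal values; binary floats do not guarantee that.

```python
from decimal import ROUND_HALF_UP, Decimal

# Per-variant means copied from the table: (Sudoku %, Maze %).
ASYMMETRIC = {
    "L_Hx": ("58.7", "75.3"), "Lx_H": ("60.0", "71.0"), "L_H2x": ("58.6", "75.6"),
    "L2x_H": ("59.1", "71.1"), "Lx_H2x": ("59.6", "70.9"), "L2x_Hx": ("58.6", "74.5"),
}
SYMMETRIC = {"Lx_Hx": ("52.1", "69.4"), "L2x_H2x": ("50.9", "70.3")}


def group_mean(group: dict, column: int) -> Decimal:
    """Mean of one column, rounded half-up to one decimal place."""
    values = [Decimal(row[column]) for row in group.values()]
    return (sum(values) / len(values)).quantize(Decimal("0.1"), rounding=ROUND_HALF_UP)


for task, col in (("Sudoku", 0), ("Maze", 1)):
    asym, sym = group_mean(ASYMMETRIC, col), group_mean(SYMMETRIC, col)
    print(task, asym, sym, asym - sym)
# Sudoku 59.1 51.5 7.6
# Maze 73.1 69.9 3.2
```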

Level-token control: a structurally separable state-identity signal recovers most of the gap

| Level-token strategy | Mechanism | Sudoku (%) |
|---|---|---|
| L2x_H2x (symmetric base, no token) | - | 50.9 ± 2.9 |
| + Addition | element-wise add to every token | 50.0 ± 1.9 |
| + Prepend (strip) | prepend, attend, then strip | 57.5 ± 1.3 |
| + Prepend (no strip) | prepend, persist across cycles | 47.8 ± 1.6 |
| For reference: asymmetric $\Delta = 1$ | - | ~59.0 |

A level token that occupies its own sequence position (prepend + strip) recovers most of the asymmetric-injection benefit. Mixing the signal into every content token (addition) or letting the token persist and accumulate content across cycles (no strip) does not help: both stay at or below the symmetric base (50.0% and 47.8%).
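The three level-token strategies differ only in where the token lives in the sequence. A minimal sketch, assuming a sequence is a list of per-position vectors and `block` is any sequence-to-sequence map standing in for the shared Transformer block; `toy_mixing_block`, `level_token`, and the function names are hypothetical, and the paper's implementation may differ in detail.

```python
from typing import Callable, List

Vector = List[float]
Block = Callable[[List[Vector]], List[Vector]]


def toy_mixing_block(seq: List[Vector]) -> List[Vector]:
    """Toy stand-in for attention: every position adds the mean of all positions."""
    dim = len(seq[0])
    avg = [sum(v[d] for v in seq) / len(seq) for d in range(dim)]
    return [[x + m for x, m in zip(v, avg)] for v in seq]


def with_addition(seq: List[Vector], level_token: Vector, block: Block) -> List[Vector]:
    """Addition: mix the level token into every content position."""
    return block([[x + t for x, t in zip(v, level_token)] for v in seq])


def with_prepend_strip(seq: List[Vector], level_token: Vector, block: Block) -> List[Vector]:
    """Prepend-and-strip: the token gets its own position for this update only.

    Content positions can attend to it, but it is dropped from the output,
    so it cannot accumulate content across cycles.
    """
    out = block([list(level_token)] + seq)
    return out[1:]


def with_prepend_no_strip(seq: List[Vector], level_token: Vector, block: Block) -> List[Vector]:
    """Prepend without strip: the token position stays in the returned sequence.

    A caller would carry this extra position into later cycles (and not
    prepend again), so whatever the block writes there persists.
    """
    return block([list(level_token)] + seq)


print(with_prepend_strip([[1.0], [3.0]], [2.0], toy_mixing_block))  # [[3.0], [5.0]]
```

Only the strip variant keeps the identity signal separable from content on both sides: it never overwrites content positions (unlike addition) and never becomes a content-carrying slot itself (unlike no strip).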

Mechanistic split in attention: L is consistently more local than H

| Layer | $\Delta_{\mathrm{nbr}}$ (control) | $\Delta_{\mathrm{ent}}$ (control) | $\Delta_{\mathrm{viol}}$ (violation-adj.) |
|---|---|---|---|
| 0 | 0.050 ± 0.001 | 0.053 ± 0.000 | −0.018 ± 0.002 |
| 1 | 0.015 ± 0.001 | 0.023 ± 0.000 | 0.006 ± 0.001 |
| 2 | 0.138 ± 0.001 | 0.125 ± 0.001 | 0.028 ± 0.001 |
| 3 | 0.244 ± 0.002 | 0.182 ± 0.002 | 0.037 ± 0.002 |

L-updates put ~47% more attention mass inside the constraint neighbourhood than H-updates at the deepest layer (Sudoku, control queries). Violation-specific routing emerges only in the deeper layers.

A single Sudoku puzzle showing the L-update attention concentrating inside the constraint neighbourhood while the H-update spreads attention across the full board

One blank query cell (puzzle p0121, query r2c6, layer 0). The L-update places about 0.81 of its attention mass inside the constraint neighbourhood; the H-update places only 0.24. Same puzzle, same query, same head β€” only the update type differs.
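The neighbourhood statistic can be made concrete for a single query cell. This sketch rests on assumptions that are our reading, not the repository's code: the constraint neighbourhood of a cell is its 20 Sudoku peers (same row, column, or 3×3 box, excluding the cell itself), an attention row is a length-81 distribution over board cells, and the per-query contrast is the L-minus-H difference in neighbourhood mass. All names are hypothetical.

```python
from typing import List, Set


def sudoku_peers(cell: int) -> Set[int]:
    """Cells sharing a row, column, or 3x3 box with `cell` (0..80), excluding itself."""
    r, c = divmod(cell, 9)
    br, bc = 3 * (r // 3), 3 * (c // 3)
    peers = {r * 9 + j for j in range(9)} | {i * 9 + c for i in range(9)}
    peers |= {i * 9 + j for i in range(br, br + 3) for j in range(bc, bc + 3)}
    peers.discard(cell)
    return peers


def neighbourhood_mass(attn_row: List[float], cell: int) -> float:
    """Fraction of one query's attention that lands inside its constraint neighbourhood."""
    return sum(attn_row[k] for k in sudoku_peers(cell))


def delta_nbr(attn_l: List[float], attn_h: List[float], cell: int) -> float:
    """L-update mass minus H-update mass for the same query (positive = L more local)."""
    return neighbourhood_mass(attn_l, cell) - neighbourhood_mass(attn_h, cell)
```

Under this reading, the illustrative query in the caption corresponds to a per-query difference of 0.81 - 0.24 = 0.57, while a uniform attention row would place 20/81 ≈ 0.25 of its mass on the neighbourhood.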

Freeze interventions: both states are load-bearing

| Task | Intervention | Total content change | Final accuracy |
|---|---|---|---|
| Sudoku | normal | $\mathbf{z}_L$: 1,235 / $\mathbf{z}_H$: 275 | 55.1% |
| Sudoku | freeze $\mathbf{z}_H$ → measure $\mathbf{z}_L$ | $\mathbf{z}_L$: 323 (↓) | 0% |
| Sudoku | freeze $\mathbf{z}_L$ → measure $\mathbf{z}_H$ | $\mathbf{z}_H$: 551 (↑) | 0% |
| Maze | normal | $\mathbf{z}_L$: 1,290 / $\mathbf{z}_H$: 825 | 71.0% |
| Maze | freeze $\mathbf{z}_H$ → measure $\mathbf{z}_L$ | $\mathbf{z}_L$: 2,305 (↑) | 0% |
| Maze | freeze $\mathbf{z}_L$ → measure $\mathbf{z}_H$ | $\mathbf{z}_H$: 2,880 (↑) | 0% |
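A freeze intervention amounts to skipping one state's update while the other keeps running. The sketch below assumes that "total content change" counts, over a rollout, how many decoded cells differ between consecutive sub-steps; that definition, the integer "already decoded" states, and the helper names are assumptions for illustration, and `block` is a toy stand-in for the shared Transformer block.

```python
from typing import Callable, List, Optional, Sequence, Tuple

State = List[int]
Block = Callable[[State], State]


def total_content_change(rollout_states: Sequence[Sequence[int]]) -> int:
    """Cells whose decoded value differs between consecutive sub-steps, summed over the rollout."""
    return sum(
        sum(a != b for a, b in zip(prev, cur))
        for prev, cur in zip(rollout_states, rollout_states[1:])
    )


def rollout(
    z_l: State, z_h: State, x: State, block: Block, steps: int, freeze: Optional[str] = None
) -> Tuple[List[State], List[State]]:
    """Alternate L- and H-updates; freeze='L' or 'H' skips that state's update."""
    if freeze not in (None, "L", "H"):
        raise ValueError("freeze must be None, 'L' or 'H'")
    hist_l, hist_h = [list(z_l)], [list(z_h)]
    for _ in range(steps):
        if freeze != "L":
            z_l = block([zl + zh + xi for zl, zh, xi in zip(z_l, z_h, x)])  # L-update sees x
        if freeze != "H":
            z_h = block([zh + zl for zh, zl in zip(z_h, z_l)])  # H-update does not
        hist_l.append(list(z_l))
        hist_h.append(list(z_h))
    return hist_l, hist_h


toy = lambda v: [x % 7 for x in v]  # toy block on integer "decoded" states
hist_l, hist_h = rollout([0, 0], [0, 0], [1, 2], toy, steps=4, freeze="H")
print(total_content_change(hist_l), total_content_change(hist_h))  # 8 0
```

In the real experiments the measured state is the one that keeps updating, so the interesting quantity is how its content change moves relative to the normal run, as in the arrows in the table above.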

πŸ”¬ Reproducing the headline tables

The experiment folders ship run_all_*.sh scripts that submit the full sweep via sbatch. To run a single variant directly:

# Sudoku: default AIR
bash experiment_input-injection-specialization-sudoku/train_sudoku_Lx_H.sh

# Sudoku: symmetric control
bash experiment_input-injection-specialization-sudoku/train_sudoku_Lx_Hx.sh

# Level-token (prepend-and-strip) on the symmetric base
bash experiment_addition-prepend-strip-no-strip/train_sudoku_L2x_H2x_input_token_prepend.sh

Visual decoded-state and freeze experiments

Both experiment_visual-*-decoded-freeze/ folders contain decode_*_intermediate_first_10.sh, *_freeze_zH_zL_*5runs.sh, and *_freeze_zH_zL_symmetric.sh. These require trained checkpoints; by default they look under checkpoints/ paths configured in each script. Override via AIR_SUDOKU_CKPT_PATH / AIR_MAZE_CKPT_PATH, or edit the path at the top of the script.
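When driving several of these scripts from Python, the checkpoint override can be resolved and passed the same way every time. A minimal sketch, assuming only the two environment variables named above; the `"checkpoints"` fallback and the helper names are hypothetical placeholders, not the scripts' actual defaults.

```python
import os
import subprocess
from pathlib import Path

# Environment variables named in this README; the fallback below is a placeholder.
CKPT_ENV = {"sudoku": "AIR_SUDOKU_CKPT_PATH", "maze": "AIR_MAZE_CKPT_PATH"}


def resolve_checkpoint(task: str, fallback: str = "checkpoints") -> Path:
    """Return the checkpoint path for `task`, preferring the environment override."""
    try:
        var = CKPT_ENV[task]
    except KeyError:
        raise ValueError(f"unknown task {task!r}; expected one of {sorted(CKPT_ENV)}") from None
    return Path(os.environ.get(var, fallback)).expanduser()


def launch_with_checkpoint(script: str, task: str, ckpt: str) -> None:
    """Run one of the experiment shell scripts with the checkpoint override set."""
    env = {**os.environ, CKPT_ENV[task]: ckpt}
    subprocess.run(["bash", script], env=env, check=True)
```

Passing the override through `env` keeps the parent process's environment unchanged, so runs for different checkpoints can be launched back to back without leaking settings between them.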

Attention analysis

# Bar-chart data + multi-layer figure
bash experiment_attention-analysis-sudoku/generate_bar_data.sh
bash experiment_attention-analysis-sudoku/multilayer_figure.sh

# Maze counterparts
bash experiment_attention-analysis-maze/generate_bar_data.sh
bash experiment_attention-analysis-maze/multilayer_figure.sh

generate_bar_data.py captures L/H attention maps over 1,000 test puzzles at sub-steps {2, 4, 6, 8, 10, 12, 14, 15} and writes per-layer JSON into bar_data/.

πŸ“ Citation

@article{shen2026air,
  title   = {One Model, Two Roles: Emergent Specialization in a Shared Recurrent Transformer},
  author  = {Shen, Jucheng and Su, Wenyi and Kyrillidis, Anastasios},
  journal = {arXiv preprint arXiv:2605.17811},
  year    = {2026},
  url     = {https://arxiv.org/abs/2605.17811}
}

πŸ“– Blog

  • 🧠 One Model, Two Roles: a Quanta-style walkthrough on AI-OWLS covering the decoded rollouts, the symmetric control, the injection-asymmetry ablation, the level-token recovery, the freeze experiments, and the attention split, written for readers outside the subfield.

πŸ‘₯ Authors

Rice University, Department of Computer Science. Jucheng Shen and Wenyi (Barbara) Su contributed equally.

πŸ“„ License

Apache License 2.0.

About

Official PyTorch implementation of the paper "One Model, Two Roles: Emergent Specialization in a Shared Recurrent Transformer".
