Emergent communication framework for 3D structure reconstruction. Two neural agents (Observer and Constructor) learn to communicate about colored 3D LEGO structures through a discrete channel -- no pre-defined language, no supervision on messages.
Paper in preparation.
Observer encodes a 3D structure into a discrete message (L slots, K symbols per slot via Gumbel-Softmax). Constructor reconstructs the structure from the message alone. Entropy sidebar shows per-slot information usage.
3D LEGO structure         Discrete message            Reconstructed structure
┌─────────────┐           ┌───────────────┐           ┌─────────────┐
│  ■ ■        │  Encoder  │ slot 0: sym 7 │  Decoder  │  ■ ■        │
│  ■ ■ ■      │ ────────► │ slot 1: sym 2 │ ────────► │  ■ ■ ■      │
│    ■        │  3D CNN   │ slot 2: sym 14│  ConvT3d  │    ■        │
│             │           │      ...      │           │             │
└─────────────┘           └───────────────┘           └─────────────┘
  4x4x4 grid                L x K logits                 4x4x4 grid
                           Gumbel-Softmax
                       (soft train/hard eval)
Observer (Encoder): 3D CNN with AdaptiveAvgPool -- grid-size agnostic. Outputs message logits (L, K).
Channel: Gumbel-Softmax discretization. Soft (continuous) during training for gradient flow, hard (argmax) during evaluation. Temperature tau controls exploration.
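The channel amounts to a few lines of PyTorch around `torch.nn.functional.gumbel_softmax`; the function and tensor names below are illustrative, not necessarily the repo's API:

```python
import torch
import torch.nn.functional as F

def channel(logits: torch.Tensor, tau: float = 1.0, hard: bool = False) -> torch.Tensor:
    """Discretize (B, L, K) logits into a (B, L, K) message.

    Soft samples (hard=False) keep gradients flowing during training;
    hard=True takes the argmax on the forward pass but backpropagates
    through the soft sample (straight-through estimator), as used at eval.
    """
    return F.gumbel_softmax(logits, tau=tau, hard=hard, dim=-1)

logits = torch.randn(2, 4, 16)       # B=2, L=4 slots, K=16 symbols
soft = channel(logits, tau=1.0)      # each slot is a distribution over K symbols
hard = channel(logits, hard=True)    # each slot is a one-hot symbol
```

Lower `tau` sharpens the soft samples toward one-hot; higher `tau` flattens them, encouraging exploration of the symbol space.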
Constructor (Decoder): Transposed 3D convolutions reconstruct the grid from the discrete message. Outputs per-cell class logits.
Evaluation: F1-score decomposed into position accuracy (block placement) and color accuracy (correct color assignment).
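One way to compute this decomposition (a sketch; the function name and the convention that class 0 means "empty" are assumptions, not the repo's actual scoring code) is to score occupancy with F1 and measure color accuracy only on cells occupied in both grids, so color errors are isolated from placement errors:

```python
import numpy as np

def position_color_scores(pred: np.ndarray, target: np.ndarray, empty: int = 0):
    """pred/target: (S, S, S) integer class grids; class `empty` = no block.

    Returns (position_f1, color_accuracy).
    """
    p_occ, t_occ = pred != empty, target != empty
    tp = np.logical_and(p_occ, t_occ).sum()          # correctly placed blocks
    prec = tp / max(p_occ.sum(), 1)
    rec = tp / max(t_occ.sum(), 1)
    f1 = 2 * prec * rec / max(prec + rec, 1e-9)
    both = np.logical_and(p_occ, t_occ)              # cells occupied in both grids
    color_acc = float((pred[both] == target[both]).mean()) if both.any() else 0.0
    return f1, color_acc

target = np.zeros((4, 4, 4), dtype=int)
target[0, 0, 0] = 3                                  # one block of color 3
f1, acc = position_color_scores(target.copy(), target)   # perfect: (1.0, 1.0)
```

A reconstruction that places every block correctly but recolors them would score position F1 = 1.0 and color accuracy = 0.0, which is exactly the failure mode this decomposition is meant to expose.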
ENCODER (Observer)
──────────────────
Input: (B, C, S, S, S) one-hot grid
          │
Conv3d(C→32, k=3, pad=1) + BN + ReLU
          │
Conv3d(32→64, k=3, pad=1) + BN + ReLU
          │
Conv3d(64→128, k=3, stride=2) + ReLU
          │
AdaptiveAvgPool3d(1)            any grid size → (B, 128)
          │
Linear(128 → H) + ReLU
          │
Linear(H → L·K)
          │
(B, L, K) logits
          │
┌─────────┴─────────┐
│  GUMBEL-SOFTMAX   │
│      CHANNEL      │
│                   │
│  train: soft τ    │
│  eval: hard       │
│ (straight-through │
│    estimator)     │
└─────────┬─────────┘
          │
(B, L, K) message
          │
DECODER (Constructor)
─────────────────────
Linear(K → E)                   per-slot symbol embedding
          │
Flatten → (B, L·E)
          │
Linear(L·E → H) + ReLU
          │
Linear(H → 128) + ReLU → reshape (B, 128, 1, 1, 1)
          │
ConvTranspose3d(128→64, k=S)    learned upsample
  + BN + ReLU → (B, 64, S, S, S)
          │
Conv3d(64→32, k=3, pad=1) + BN + ReLU
          │
Conv3d(32→C, k=3, pad=1)
          │
Output: (B, C, S, S, S) per-cell class logits

C = num_classes    S = grid_size      L = message_length
K = vocab_size     H = hidden_dim     E = embed_dim
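The diagram above can be checked shape-by-shape with a minimal sketch. Layer sizes follow the diagram, but the class names, toy dimensions, and wiring are illustrative assumptions, not the repo's actual modules:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

C, S, L, K, H, E = 5, 4, 4, 16, 128, 32   # toy sizes matching the legend

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(C, 32, 3, padding=1), nn.BatchNorm3d(32), nn.ReLU(),
            nn.Conv3d(32, 64, 3, padding=1), nn.BatchNorm3d(64), nn.ReLU(),
            nn.Conv3d(64, 128, 3, stride=2), nn.ReLU(),
            nn.AdaptiveAvgPool3d(1),                 # grid-size agnostic
        )
        self.head = nn.Sequential(nn.Linear(128, H), nn.ReLU(), nn.Linear(H, L * K))

    def forward(self, x):                            # x: (B, C, S, S, S)
        z = self.conv(x).flatten(1)                  # (B, 128)
        return self.head(z).view(-1, L, K)           # (B, L, K) logits

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(K, E)                 # per-slot symbol embedding
        self.up = nn.Sequential(
            nn.Linear(L * E, H), nn.ReLU(), nn.Linear(H, 128), nn.ReLU())
        self.deconv = nn.Sequential(
            nn.ConvTranspose3d(128, 64, S), nn.BatchNorm3d(64), nn.ReLU(),
            nn.Conv3d(64, 32, 3, padding=1), nn.BatchNorm3d(32), nn.ReLU(),
            nn.Conv3d(32, C, 3, padding=1),
        )

    def forward(self, m):                            # m: (B, L, K) message
        z = self.up(self.embed(m).flatten(1))        # (B, 128)
        return self.deconv(z.view(-1, 128, 1, 1, 1)) # (B, C, S, S, S)

# End-to-end shape check on a random batch of 2 one-hot grids.
x = F.one_hot(torch.randint(C, (2, S, S, S)), C).permute(0, 4, 1, 2, 3).float()
logits = Encoder()(x)
msg = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
out = Decoder()(msg)
```

Note that `ConvTranspose3d(128, 64, k=S)` maps the 1×1×1 bottleneck directly to an S×S×S volume in one step, which is what makes the decoder's output size track `grid_size`.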
python -m venv .venv
source .venv/bin/activate
pip install -e ".[dev]"
python scripts/run_experiment.py configs/smoke.yaml
python -m lego.viz.eval_viz --checkpoint runs/.../checkpoints/best.pt
python -m lego.viz.eval_viz --checkpoint best.pt --mode grid --num 16
python -m lego.viz.eval_viz --checkpoint best.pt --mode animate --output anim.gif
src/lego/
├── data/ # Grid, Block, procedural generation (gravity + connectivity), scoring
├── model/ # Encoder (3D CNN), Channel (Gumbel-Softmax), Decoder (ConvTranspose3d)
├── training/ # Training loop, YAML-driven experiment orchestrator, callbacks
└── viz/ # 3D grid rendering, eval diagnostics
- Python >= 3.10
- PyTorch >= 2.0
- MPS (Apple Silicon), CUDA, or CPU
MIT
Built by @louis49