
louis49/lego-emcom

lego-emcom

Emergent communication framework for 3D structure reconstruction. Two neural agents (Observer and Constructor) learn to communicate about colored 3D LEGO structures through a discrete channel -- no pre-defined language, no supervision on messages.

Paper in preparation.

Evaluation figure: original structure, discrete message, and neural reconstruction.

The Observer encodes a 3D structure into a discrete message (L slots, K symbols per slot, via Gumbel-Softmax). The Constructor reconstructs the structure from the message alone. The entropy sidebar shows per-slot information usage.
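One plausible way to compute per-slot information usage is the empirical entropy of each slot over a set of evaluation messages. The sketch below is an illustration only: the function name `slot_entropies` and the tuple-of-symbol-ids input format are assumptions, not the repository's API. A slot that always emits the same symbol scores 0 bits, and a slot cannot exceed log2(K) bits.

```python
import math
from collections import Counter


def slot_entropies(messages):
    """Return the empirical entropy (in bits) of each message slot.

    `messages` is a list of equal-length tuples of symbol ids, one tuple
    per structure (a hypothetical input format chosen for this sketch).
    """
    num_slots = len(messages[0])
    entropies = []
    for slot in range(num_slots):
        counts = Counter(msg[slot] for msg in messages)
        total = sum(counts.values())
        # Shannon entropy: sum over symbols of p * log2(1 / p).
        h = sum((c / total) * math.log2(total / c) for c in counts.values())
        entropies.append(h)
    return entropies


# Four messages with L = 3 slots: slots 0 and 1 each use two symbols
# equally often, slot 2 always emits symbol 14.
messages = [(7, 2, 14), (7, 5, 14), (3, 2, 14), (3, 5, 14)]
entropies = slot_entropies(messages)
```

Here slots 0 and 1 each carry 1 bit and slot 2 carries none, which is the kind of unused slot such a sidebar would make visible.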


How it works

   3D LEGO structure          Discrete message           Reconstructed structure
   ┌─────────────┐           ┌───────────────┐           ┌─────────────┐
   │  ■ ■        │  Encoder  │ slot 0: sym 7 │  Decoder  │  ■ ■        │
   │  ■ ■ ■      │ ───────►  │ slot 1: sym 2 │ ───────►  │  ■ ■ ■      │
   │  ■          │  3D CNN   │ slot 2: sym 14│  ConvT3d  │  ■          │
   │             │           │ ...           │           │             │
   └─────────────┘           └───────────────┘           └─────────────┘
       4x4x4 grid              L x K logits               4x4x4 grid
                             Gumbel-Softmax
                            (soft train/hard eval)

Observer (Encoder): 3D CNN ending in AdaptiveAvgPool3d(1), which makes the encoder grid-size agnostic. Outputs message logits of shape (L, K).

Channel: Gumbel-Softmax discretization. Soft (continuous) during training for gradient flow, hard (argmax) during evaluation. Temperature tau controls exploration.
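The channel behaviour for a single slot can be sketched in plain Python. This is a minimal illustration, not the repository's code: the function names are hypothetical, and it assumes that evaluation takes a plain argmax over the logits without added noise, which the description above suggests but does not spell out.

```python
import math
import random


def gumbel_softmax_sample(logits, tau=1.0, rng=random):
    """Soft (training-time) sample over K symbols for one slot."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise is -log(-log(U)) with U ~ Uniform(0, 1).
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # avoid log(0)
        noisy.append((logit - math.log(-math.log(u))) / tau)
    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    peak = max(noisy)
    exps = [math.exp(x - peak) for x in noisy]
    norm = sum(exps)
    return [e / norm for e in exps]


def hard_one_hot(logits):
    """Hard (evaluation-time) message: one-hot vector at the argmax."""
    winner = max(range(len(logits)), key=lambda i: logits[i])
    return [1.0 if i == winner else 0.0 for i in range(len(logits))]


rng = random.Random(0)
logits = [2.0, 0.5, -1.0, 0.1]  # K = 4 symbols for one slot
soft = gumbel_softmax_sample(logits, tau=1.0, rng=rng)
one_hot = hard_one_hot(logits)
```

Lowering `tau` pushes the soft sample closer to a one-hot vector, while a higher `tau` spreads probability mass across symbols.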

Constructor (Decoder): Transposed 3D convolutions reconstruct the grid from the discrete message. Outputs per-cell class logits.

Evaluation: F1-score decomposed into position accuracy (block placement) and color accuracy (correct color assignment).
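The exact formulas are not given here, so the following is only one plausible reading of the decomposition: a position F1 computed over occupied cells, and a color accuracy computed over cells that are occupied in both grids. The function name and the dictionary-based grid format are assumptions made for this sketch.

```python
def position_and_color_scores(target, predicted):
    """Score a reconstruction against its target.

    Both arguments map (x, y, z) -> color id for occupied cells only
    (a hypothetical sparse format; empty cells are simply absent).
    """
    target_cells, predicted_cells = set(target), set(predicted)
    overlap = target_cells & predicted_cells

    # Position: precision, recall and F1 on "is this cell occupied?".
    precision = len(overlap) / len(predicted_cells) if predicted_cells else 0.0
    recall = len(overlap) / len(target_cells) if target_cells else 0.0
    denom = precision + recall
    position_f1 = 2 * precision * recall / denom if denom else 0.0

    # Color: among correctly placed blocks, how many have the right color.
    matches = sum(target[c] == predicted[c] for c in overlap)
    color_accuracy = matches / len(overlap) if overlap else 0.0

    return {"position_f1": position_f1, "color_accuracy": color_accuracy}


target = {(0, 0, 0): 1, (1, 0, 0): 1, (0, 1, 0): 2, (0, 0, 1): 3}
predicted = {(0, 0, 0): 1, (1, 0, 0): 2, (0, 1, 0): 2, (1, 1, 0): 1}
scores = position_and_color_scores(target, predicted)
```

In this example three of four predicted blocks sit on target cells (position F1 of 0.75), and two of those three have the correct color.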

Model detail

                         ENCODER (Observer)
                         ─────────────────
                    Input: (B, C, S, S, S)    one-hot grid
                              │
                    Conv3d(C→32, k=3, pad=1) + BN + ReLU
                              │
                    Conv3d(32→64, k=3, pad=1) + BN + ReLU
                              │
                    Conv3d(64→128, k=3, stride=2) + ReLU
                              │
                    AdaptiveAvgPool3d(1)       any grid size → (B, 128)
                              │
                    Linear(128 → H) + ReLU
                              │
                    Linear(H → L·K)
                              │
                         (B, L, K) logits
                              │
                    ┌─────────┴─────────┐
                    │   GUMBEL-SOFTMAX  │
                    │   CHANNEL         │
                    │                   │
                    │  train: soft τ    │
                    │  eval:  hard      │
                    │  (straight-through│
                    │   estimator)      │
                    └─────────┬─────────┘
                              │
                         (B, L, K) message
                              │
                         DECODER (Constructor)
                         ─────────────────────
                    Linear(K → E) per slot    symbol embedding
                              │
                    Flatten → (B, L·E)
                              │
                    Linear(L·E → H) + ReLU
                              │
                    Linear(H → 128) + ReLU → reshape (B, 128, 1, 1, 1)
                              │
                    ConvTranspose3d(128→64, k=S)  learned upsample
                    + BN + ReLU               → (B, 64, S, S, S)
                              │
                    Conv3d(64→32, k=3, pad=1) + BN + ReLU
                              │
                    Conv3d(32→C, k=3, pad=1)
                              │
                   Output: (B, C, S, S, S)    per-cell class logits

     C = num_classes   S = grid_size   L = message_length
     K = vocab_size    H = hidden_dim  E = embed_dim
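The spatial sizes in the diagram can be checked with the standard convolution size formulas. The sketch below traces one spatial axis through both networks. It assumes that the stride-2 convolution uses no padding (none is listed), and that dilation is 1 and output padding is 0 everywhere. The helper names are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0):
    """Output length of a transposed convolution along one spatial axis."""
    return (size - 1) * stride - 2 * padding + kernel


def trace_shapes(S):
    """Trace the spatial size for a grid of side S."""
    s = conv_out(S, kernel=3, padding=1)     # Conv3d(C->32): size preserved
    s = conv_out(s, kernel=3, padding=1)     # Conv3d(32->64): size preserved
    enc = conv_out(s, kernel=3, stride=2)    # Conv3d(64->128), assumed pad=0
    # AdaptiveAvgPool3d(1) then collapses whatever is left to 1x1x1.

    d = conv_transpose_out(1, kernel=S)      # ConvTranspose3d(128->64, k=S)
    d = conv_out(d, kernel=3, padding=1)     # Conv3d(64->32)
    d = conv_out(d, kernel=3, padding=1)     # Conv3d(32->C)
    return {"encoder_before_pool": enc, "decoder_output": d}


shapes_4 = trace_shapes(4)
shapes_8 = trace_shapes(8)
```

The encoder's pre-pool size varies with S and is absorbed by the adaptive pooling, whereas the decoder's first layer uses a kernel of size S, so under these assumptions the decoder is tied to the grid size it was built for.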

Setup

python -m venv .venv
source .venv/bin/activate
pip install -e ".[dev]"

Quick start

Train a single run

python scripts/run_experiment.py configs/smoke.yaml

Visualize results

python -m lego.viz.eval_viz --checkpoint runs/.../checkpoints/best.pt
python -m lego.viz.eval_viz --checkpoint best.pt --mode grid --num 16
python -m lego.viz.eval_viz --checkpoint best.pt --mode animate --output anim.gif

Architecture

src/lego/
├── data/           # Grid, Block, procedural generation (gravity + connectivity), scoring
├── model/          # Encoder (3D CNN), Channel (Gumbel-Softmax), Decoder (ConvTranspose3d)
├── training/       # Training loop, YAML-driven experiment orchestrator, callbacks
└── viz/            # 3D grid rendering, eval diagnostics
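The "gravity + connectivity" constraints in `data/` are not described further, so the sketch below shows only one plausible interpretation: every block rests on the ground or on another block, and all blocks form a single face-connected component. The function names, the choice of z as the vertical axis, and the set-of-coordinates format are all assumptions.

```python
from collections import deque

# Six face-adjacent neighbor offsets in a 3D grid.
NEIGHBORS = ((1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1))


def is_supported(cells):
    """Gravity check: each block is at z == 0 or sits directly on a block."""
    return all(z == 0 or (x, y, z - 1) in cells for (x, y, z) in cells)


def is_connected(cells):
    """Connectivity check: breadth-first search must reach every block."""
    if not cells:
        return True
    start = next(iter(cells))
    seen = {start}
    queue = deque([start])
    while queue:
        x, y, z = queue.popleft()
        for dx, dy, dz in NEIGHBORS:
            neighbor = (x + dx, y + dy, z + dz)
            if neighbor in cells and neighbor not in seen:
                seen.add(neighbor)
                queue.append(neighbor)
    return len(seen) == len(cells)


tower = {(0, 0, 0), (0, 0, 1), (1, 0, 0)}
floating = {(0, 0, 0), (0, 0, 2)}
two_piles = {(0, 0, 0), (3, 3, 0)}
```

A generator built on checks like these could reject or repair candidate structures such as `floating` (unsupported block) or `two_piles` (two separate components).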

Requirements

  • Python >= 3.10
  • PyTorch >= 2.0
  • MPS (Apple Silicon), CUDA, or CPU

License

MIT


Built by @louis49
