📄 Paper | 🌐 Project Page | 💻 Code | 📚 BibTeX
Geometry-aware imitation: GPI treats demonstrations as geometric curves, inducing distance and flow fields that make imitation simple, efficient, flexible, and interpretable.
GPI delivers multimodal behaviors with higher success rates, 20–100× faster inference than diffusion policies, and orders-of-magnitude lower memory use.
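At its core, the policy is nearest-neighbour geometry. The sketch below is illustrative Python, not the repository's actual API (`gpi_action`, `demos`, and the default weights are assumptions); the λ weights mirror the `--fixed-lambda1`/`--fixed-lambda2` flags documented later. It shows how demonstration curves induce a distance field and two flows: a progression flow along the nearest curve and an attraction flow back toward it.

```python
import numpy as np

def gpi_action(query, demos, lam1=0.7, lam2=0.3):
    """Blend a progression flow (along the nearest demo curve) with an
    attraction flow (back toward the curve). `demos` is a list of (T, D)
    state arrays; `query` is the current state. Illustrative sketch only."""
    best = None
    for demo in demos:
        d = np.linalg.norm(demo - query, axis=1)  # distance to every waypoint
        i = int(np.argmin(d))                     # nearest waypoint on this curve
        if best is None or d[i] < best[0]:
            best = (d[i], demo, i)
    _, demo, i = best
    nxt = demo[min(i + 1, len(demo) - 1)]
    progression = nxt - demo[i]                   # tangent: advance along the curve
    attraction = demo[i] - query                  # normal: pull back onto the curve
    return lam1 * progression + lam2 * attraction

# Example: one straight-line demo in 2D, query slightly off the curve.
demo = np.stack([np.linspace(0, 1, 50), np.zeros(50)], axis=1)
print(gpi_action(np.array([0.3, 0.1]), [demo]))
```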
```
gpi/              GPI database, planner, and policy abstractions
pusht/            PushT datasets, environments, download helpers, evaluation utils
scripts/          Entry points for training, policy rollouts, and evaluation
models/           Default location for datasets and checkpoints (auto-populated)
results/          Logs and rollout videos produced by evaluation scripts
environment.yml   Python dependencies
```
Create and activate the recommended conda environment:
```bash
conda env create -f environment.yml
conda activate gpi
```

Run the pretrained state-based policy:

```bash
python scripts/run_state_policy.py --seed 500 --max-steps 200
```

Run the pretrained vision policy:

```bash
python scripts/run_vision_policy.py --seed 500 --max-steps 200
```

Both scripts include extensive CLI options; append `--help` to inspect defaults and descriptions:
- `--k-neighbors`: number of demonstrations blended per query.
- `--action-horizon`: trajectory horizon fetched from the database.
- `--subset-size`: random subset size for approximate nearest neighbours.
- `--batch-size`: PyTorch batch size for scoring.
- `--device`: override automatic `cuda`/`cpu` selection.
- `--obs-noise-std`, `--noise-decay`, `--min-noise-std`, `--disable-noise`: latent exploration controls.
- `--random-seed`: random seed for sampling and policy noise.
- `--memory-length`: cap the loop-avoidance buffer.
- `--fixed-lambda1`, `--fixed-lambda2`: manually weight progression vs. attraction flows.
- `--action-smoothing`: exponential smoothing factor for actions.
- `--no-relative-action`: operate in absolute action space.
- `--video-path`, `--no-save-video`: manually set or disable rollout video export.
- `--no-live-render`: disable the interactive window for headless runs.
- `--quiet`: silence tqdm progress output.
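As a concrete example, a headless, quiet rollout that blends more neighbours and fixes the flow weights might look like the following (the flag values are illustrative, not tuned recommendations):

```bash
python scripts/run_state_policy.py --seed 500 --max-steps 200 \
    --k-neighbors 8 --fixed-lambda1 0.7 --fixed-lambda2 0.3 \
    --no-live-render --quiet
```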
Automate sweeps across random seeds or parameter grids using:
```bash
python scripts/run_state_evaluation.py --count 20 --max-steps 200
python scripts/run_vision_evaluation.py --count 20 --max-steps 200
```

Key flags:
- `--dataset`: input PushT replay archive (auto-downloaded if missing).
- `--checkpoint`: ResNet18 vision checkpoint path (vision evaluation only).
- `--count`: number of runs to generate.
- `--random-seed`: deterministic configuration sampling.
- `--video-dir` / `--no-save-video`: control mp4 exports (defaults to `results/`).
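For instance, to sweep the vision policy against the checkpoint produced by the training command below, skipping video export (flag values illustrative):

```bash
python scripts/run_vision_evaluation.py --count 20 --max-steps 200 \
    --checkpoint models/vision_state_predictor_epoch_200.ckpt \
    --no-save-video
```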
Each evaluation logs Reward, Inference Time, and Memory to stdout and `results/logs/`.
Finetune the ResNet18 state predictor to refresh the vision policy backbone:
```bash
conda activate gpi
python scripts/train_vision_features.py \
    --dataset models/pusht_cchi_v7_replay.zarr.zip \
    --output-dataset models/pusht_cchi_v7_replay_imgs_feature_epoch_200.zarr \
    --checkpoint-path models/vision_state_predictor_epoch_200.ckpt
```

The training script attaches a lightweight task-specific head that regresses object pose directly from each image frame. During inference, this predicted pose is reused for distance computation, so the vision policy queries the same geometry-aware metric as the state-based planner. Other visual encoders that produce latent embeddings, such as VAEs or pretrained models, can also be used.
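Conceptually, the predictor looks like a torchvision ResNet18 whose classification layer is swapped for a small pose-regression head. The sketch below is a hedged approximation, not the repository's exact module; the head width and `pose_dim` are assumptions.

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

class VisionStatePredictor(nn.Module):
    """ResNet18 backbone with a lightweight head that regresses the
    object pose from a single image frame (sketch; dimensions assumed)."""
    def __init__(self, pose_dim: int = 5):
        super().__init__()
        backbone = resnet18(weights=None)
        feat_dim = backbone.fc.in_features       # 512 for ResNet18
        backbone.fc = nn.Identity()              # expose the 512-d features
        self.backbone = backbone
        self.head = nn.Sequential(               # task-specific pose head
            nn.Linear(feat_dim, 128), nn.ReLU(),
            nn.Linear(128, pose_dim),
        )

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(img))     # (B, 3, H, W) -> (B, pose_dim)

pred = VisionStatePredictor().eval()
print(pred(torch.randn(1, 3, 96, 96)).shape)     # torch.Size([1, 5])
```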
Datasets and checkpoints are auto-downloaded to `models/` when absent. Adjust the output names to avoid overwriting existing artifacts.
- Dependency mismatch: version conflicts among `pygame`, `pymunk`, `zarr`, or `gym` can break the simulator. Verify the environment was created from `environment.yml`.
- Missing assets: if downloads fail, clear the partial files in `models/` and re-run the command; the script retries automatically.
- Headless rendering: use `--no-live-render` to avoid opening a window on remote servers. Videos will still be written to `results/`.
```bibtex
@misc{GPI,
  author = {Yiming Li and Nael Darwiche and Amirreza Razmjoo and Sichao Liu and Yilun Du and Auke Ijspeert and Sylvain Calinon},
  title  = {Geometry-aware Policy Imitation},
  year   = {2025},
  eprint = {arXiv:2510.08787},
}
```



