Implementation of MILO, a model-based, offline imitation learning algorithm.
Paper: https://arxiv.org/abs/2106.03207
After cloning this repository and installing the requirements, install both packages in editable mode:

```bash
cd milo && pip install -e .
cd ../mjrl && pip install -e .
```
The experiments use the MuJoCo physics engine, which requires a license to install. Please follow the instructions on the MuJoCo website.
The milo package contains our imitation learning algorithm, model-based environment stack, and boilerplate code. We modified the mjrl package to interface with our cost functions when doing model-based policy gradient; this modification can be seen in mjrl/mjrl/algos/batch_reinforce.py. Note that we currently only support NPG/TRPO as the policy gradient algorithm; however, in principle one could replace this with other algorithms/repositories.
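To illustrate the kind of hook described above, here is a minimal sketch (not the actual mjrl interface; `returns_with_cost` and the trajectory format are hypothetical) of computing discounted policy-gradient returns from a learned cost function instead of environment reward:

```python
# Illustrative sketch only: the real modification lives in
# mjrl/mjrl/algos/batch_reinforce.py and uses mjrl's own data structures.
def returns_with_cost(trajectories, cost_fn, gamma=0.995):
    """cost_fn maps (observations, actions) -> per-step costs (lower is better)."""
    all_returns = []
    for traj in trajectories:
        costs = cost_fn(traj["observations"], traj["actions"])
        G, returns = 0.0, []
        for c in reversed(list(costs)):
            G = -c + gamma * G  # treat negative cost as reward
            returns.append(G)
        all_returns.append(list(reversed(returns)))
    return all_returns
```

The only conceptual change from standard REINFORCE-style returns is that the per-step signal comes from the learned cost rather than the simulator's reward.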
This repository supports 5 modified MuJoCo environments, found in milo/milo/gym_env:
- Hopper-v4
- Walker2d-v4
- HalfCheetah-v4
- Ant-v4
- Humanoid-v4
If you would like to add an environment, register it in milo/milo/gym_env/__init__.py following the OpenAI Gym instructions.
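A registration sketch using Gym's standard API (the environment id, module path, and class name below are hypothetical examples, not part of this repo):

```python
# Hypothetical registration snippet for milo/milo/gym_env/__init__.py.
from gym.envs.registration import register

register(
    id='MyHopper-v0',                                    # name used by gym.make
    entry_point='milo.gym_env.my_hopper:MyHopperEnv',    # module:class of your env
    max_episode_steps=1000,
)
```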
Please download the datasets from this Google Drive link. Each environment has 2 datasets: [ENV]_expert.pt and [ENV]_offline.pt.
Place the expert and offline datasets in the data/expert_data and data/offline_data directories, respectively.
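As a quick sanity check of the expected layout, a small helper like the following (hypothetical, not part of the repo) can verify the files are in place; filenames follow the [ENV]_expert.pt / [ENV]_offline.pt pattern:

```python
import os

def dataset_paths(env_name, root="data"):
    """Return the expected dataset locations for one environment."""
    return {
        "expert": os.path.join(root, "expert_data", f"{env_name}_expert.pt"),
        "offline": os.path.join(root, "offline_data", f"{env_name}_offline.pt"),
    }

# Report which datasets are present for Hopper.
for kind, path in dataset_paths("Hopper-v4").items():
    print(kind, path, "found" if os.path.exists(path) else "missing")
```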
We provide an example run script for Hopper, example_run.sh, that can be modified for use with any other registered environment. To view all available arguments, see the argparse setup in milo/milo/utils/arguments.py.
To cite this work, please use the following citation. Note that this repository builds upon MJRL, so please also cite the references noted in its README.
```
@misc{chang2021mitigating,
      title={Mitigating Covariate Shift in Imitation Learning via Offline Data Without Great Coverage},
      author={Jonathan D. Chang and Masatoshi Uehara and Dhruv Sreenivas and Rahul Kidambi and Wen Sun},
      year={2021},
      eprint={2106.03207},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```
