Implementation of STORM: Efficient Stochastic Transformer based World Models for Reinforcement Learning
News: This project is no longer maintained, and we recommend using the new OC-STORM repo, which also contains full code for running STORM on Atari games.
This repo contains an implementation of STORM.
Follow the Training and Evaluating Instructions to reproduce the main results presented in our paper. One may also find Additional Useful Information helpful when debugging and observing intermediate results. To reproduce the speed metrics mentioned in the paper, please see Reproducing Speed Metrics.
- Install the necessary dependencies. Note that we conducted our experiments with Python 3.10.

  ```shell
  pip install -r requirements.txt
  ```

  Installing `AutoROM.accept-rom-license` may take several minutes.
- Train the agent.

  ```shell
  chmod +x train.sh
  ./train.sh
  ```

  The `train.sh` file controls the environment and the run name of a training process:

  ```shell
  env_name=MsPacman
  python -u train.py \
      -n "${env_name}-life_done-wm_2L512D8H-100k-seed1" \
      -seed 1 \
      -config_path "config_files/STORM.yaml" \
      -env_name "ALE/${env_name}-v5" \
      -trajectory_path "trajectory/${env_name}.pkl"
  ```

  - The `env_name` on the first line can be any Atari game, which can be found here.
  - The `-n` option sets the name for the TensorBoard logger and the checkpoint folder. You can change it to your preference, but we recommend keeping the environment's name first. The TensorBoard logging folder is `runs`, and the checkpoint folder is `ckpt`.
  - The `-seed` parameter controls the random seed used during training. We evaluated our method with 5 seeds and report the mean return in Table 1.
  - The `-config_path` option points to a YAML file that controls the model's hyperparameters. The configuration in `config_files/STORM.yaml` is the same as in our paper.
  - The `-trajectory_path` option is only used when `UseDemonstration` in the YAML file is set to `True` (it is `False` by default). This corresponds to the ablation studies in Section 5.3. We provide the pre-collected trajectories in the `D_TRAJ.7z` file, which you need to decompress before use.
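To reproduce the 5-seed protocol reported in Table 1, the training command can be wrapped in a shell loop. This is a sketch that only prints each command so you can inspect it first (drop the `echo` to actually launch the runs); the flags mirror `train.sh`:

```shell
env_name=MsPacman
for seed in 1 2 3 4 5; do
    # Print the command for inspection; remove `echo` to run it.
    echo python -u train.py \
        -n "${env_name}-life_done-wm_2L512D8H-100k-seed${seed}" \
        -seed "${seed}" \
        -config_path "config_files/STORM.yaml" \
        -env_name "ALE/${env_name}-v5" \
        -trajectory_path "trajectory/${env_name}.pkl"
done
```

Note that each run occupies the GPU for the full training budget, so sequential execution is the safe default; launch runs in parallel only if your GPU has headroom.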
- Evaluate the agent. The evaluation results will be written to a CSV file in the `eval_result` folder.

  ```shell
  chmod +x eval.sh
  ./eval.sh
  ```

  The `eval.sh` file controls the environment and the run name when testing an agent:

  ```shell
  env_name=MsPacman
  python -u eval.py \
      -env_name "ALE/${env_name}-v5" \
      -run_name "${env_name}-life_done-wm_2L512D8H-100k-seed1" \
      -config_path "config_files/STORM.yaml"
  ```

  - The `-run_name` option corresponds to the `-n` option in `train.sh` and must match the name used during training.
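To aggregate results across runs, the evaluation CSVs can be summarized with a few lines of Python. This is a sketch under assumptions: the column name `episode_return` is hypothetical, so substitute the actual header found in your `eval_result` files:

```python
import csv
import statistics


def mean_return(csv_path: str, column: str = "episode_return") -> float:
    """Average one numeric column of an evaluation CSV.

    The default column name is a guess; check the real header first.
    """
    with open(csv_path, newline="") as f:
        values = [float(row[column]) for row in csv.DictReader(f)]
    return statistics.mean(values)
```

Calling this on each seed's CSV and averaging the five results reproduces the per-game mean-return protocol described above.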
You can use TensorBoard to visualize the training curves and the imagination videos:

```shell
chmod +x TensorBoard.sh
./TensorBoard.sh
```

To reproduce the speed metrics mentioned in the paper, please consider the following:
- Hardware requirements: an NVIDIA GeForce RTX 3090 paired with a high-frequency CPU; we used an `11th Gen Intel(R) Core(TM) i9-11900K` in our experiments. A low-frequency CPU may leave the GPU idle and slow down training. To make full use of a powerful GPU, one can train several agents at the same time on one device.
- Software requirements: `PyTorch>=2.0.0` is required.
- Our experiments used bfloat16 to accelerate training. To train on devices that do not support bfloat16, such as the NVIDIA V100, change `torch.bfloat16` to `torch.float16` in both `agents.py` and `sub_models/world_models.py`. Additionally, change the line `attn = attn.masked_fill(mask == 0, -1e9)` to `attn = attn.masked_fill(mask == 0, -6e4)` to prevent overflow.
- On devices such as the NVIDIA A100, using bfloat16 may slow down training. In this case, you can toggle the `self.use_amp = True` option in both `agents.py` and `sub_models/world_models.py`.
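The `-6e4` substitution is needed because float16 cannot represent magnitudes beyond roughly 65504, so the usual `-1e9` attention-mask fill value overflows to `-inf`, which can corrupt the masked softmax. A quick check with NumPy's float16 (standing in here for `torch.float16`, which has the same numeric range) illustrates the point:

```python
import numpy as np

# -1e9 exceeds float16's representable range (~±65504) and overflows,
# whereas -6e4 fits and still drives the softmax weight to ~0.
print(np.isinf(np.float16(-1e9)))    # True: overflows to -inf
print(np.isfinite(np.float16(-6e4))) # True: safely representable
```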
We've recently observed that if one clones the repo from PowerShell and then calls `train.sh` under a WSL shell, it may throw an argument-parsing error. This may be due to invisible line-ending characters introduced when cloning with Git on Windows. The solution is to download the zip or clone the repo directly inside WSL.
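If you already have a checkout with Windows line endings, stripping the carriage returns in place is usually enough as an alternative to re-cloning (a sketch; assumes GNU `sed` and that you run it from the repository root):

```shell
# Strip trailing carriage returns from the repo's shell scripts in place.
for f in train.sh eval.sh TensorBoard.sh; do
    if [ -f "$f" ]; then
        sed -i 's/\r$//' "$f"
    fi
done
```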
We've referenced several other projects during the development of this code:
- Attention is all you need pytorch, for the Transformer structure, the attention operation, and other building blocks.
- Hugging Face Diffusers, for the trainable positional encoding.
- DreamerV3, for the symlog loss and the layer and kernel configuration of the VAE.
```bibtex
@inproceedings{
zhang2023storm,
title={{STORM}: Efficient Stochastic Transformer based World Models for Reinforcement Learning},
author={Weipu Zhang and Gang Wang and Jian Sun and Yetian Yuan and Gao Huang},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
url={https://openreview.net/forum?id=WxnrX42rnS}
}
```