This codebase has only been tested for Atari Breakout. It is very hacky, trained on a small number of samples, and is not at all optimized. Use at your own risk!
This project aims to create a neural, playable version of Atari Breakout by learning purely from videos of the game. It is a small replica of Google's Genie project, which learned interactive, playable world models purely from videos of games.
- Implements a Deep Q-Network (DQN) agent for Atari Breakout.
- Supports two exploration strategies:
- Temperature-based (Boltzmann) exploration with Prioritized Experience Replay (PER).
- Random Network Distillation (RND) for intrinsic motivation and improved exploration.
- Modular, efficient, and test-driven codebase.
- Current progress: the agent achieves scores of up to 20.
- Next steps: Train for longer, increase max steps per episode to 10,000, and target scores of 200+.
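As a concrete illustration of the first exploration strategy, temperature-based (Boltzmann) exploration samples actions with probability proportional to `exp(Q / T)` instead of taking a greedy argmax. The sketch below is a minimal, self-contained version of the idea; the Q-values and temperature are illustrative and not taken from this repo's implementation.

```python
import numpy as np

def boltzmann_action(q_values, temperature=1.0, rng=None):
    """Sample an action with probability proportional to exp(Q / T).

    Higher temperature -> closer to uniform (more exploration);
    lower temperature -> closer to greedy argmax (more exploitation).
    """
    rng = rng or np.random.default_rng()
    logits = np.asarray(q_values, dtype=np.float64) / temperature
    logits -= logits.max()          # numerical stability before exponentiation
    probs = np.exp(logits)
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

# Example: Breakout's standard action set has 4 actions (NOOP, FIRE, RIGHT, LEFT).
q = [0.1, 2.5, 0.3, 0.2]
action = boltzmann_action(q, temperature=0.1)  # low T: the best Q-value dominates
```

Annealing the temperature over training moves the policy smoothly from exploratory to greedy behavior.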
See part1.md for full details, implementation, and next steps.
```bash
pip install -r requirements.txt
python -m AutoROM --accept-license
```

- Default (temperature-based exploration):

```bash
python train_dqn.py
```
- RND-based exploration:

```bash
python train_dqn.py --exploration_mode rnd
```

- Additional arguments (see `python train_dqn.py --help`) allow you to control episodes, buffer size, and more.
- To record videos of a trained (or random) agent:

```bash
python record_videos.py --checkpoint_path <path_to_checkpoint> --output_dir videos/
```

e.g.

```bash
python record_videos.py --bulk --total_videos 100 --percent_random 15 --output_dir bulk_videos
```

- Replace `<path_to_checkpoint>` with the path to your saved model checkpoint (see `checkpoints/`).
- Videos will be saved as MP4 files in the specified output directory.
- Prepare data: ensure gameplay videos are available in `bulk_videos/` (from Part 1).
- Train the VQ-VAE latent action model:

```bash
python train_latent_action.py
```
- Checkpoints and logs:
  - Best model: `checkpoints/latent_action/best.pt`
  - Processed data: see `data/` and `vqvae_recons/`
  - Training logs and metrics: Weights & Biases (wandb)
- Evaluation:
  - Run `test_latent_action_model.py` for automated tests and metrics
  - Visualize reconstructions and codebook usage as described in part2.md
See part2.md for full details, implementation, and next steps.
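The heart of a latent action model is the VQ-VAE's vector quantization step: the encoder's continuous output is snapped to its nearest codebook entry, and that entry's index becomes the discrete "latent action". Below is a minimal sketch of that step with a straight-through gradient, using illustrative sizes rather than this repo's actual configuration.

```python
import torch
import torch.nn as nn

class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook lookup with a straight-through gradient,
    as in VQ-VAE. Sizes here are illustrative, not the repo's config."""

    def __init__(self, num_codes=8, dim=32, beta=0.25):
        super().__init__()
        self.codebook = nn.Embedding(num_codes, dim)
        self.beta = beta  # weight of the commitment loss term

    def forward(self, z):  # z: (batch, dim) continuous encoder output
        # Distance from each latent to each codebook entry
        d = torch.cdist(z, self.codebook.weight)   # (batch, num_codes)
        idx = d.argmin(dim=1)                      # discrete latent action index
        z_q = self.codebook(idx)                   # quantized latents
        # Codebook loss pulls codes toward encoder outputs; commitment loss
        # pulls encoder outputs toward their assigned codes.
        loss = ((z_q - z.detach()) ** 2).mean() \
             + self.beta * ((z_q.detach() - z) ** 2).mean()
        # Straight-through estimator: gradients pass to the encoder as if
        # quantization were the identity function.
        z_q = z + (z_q - z).detach()
        return z_q, idx, loss

vq = VectorQuantizer()
z_q, idx, loss = vq(torch.randn(4, 32))
```

The `idx` tensor is what gets paired with real actions later (Part 4), and `z_q` is what the decoder consumes.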
- Note: The decoder from Part 2 serves as the world model. No separate training is required unless you wish to experiment with alternative architectures.
- Evaluate the world model:
  - Use the decoder to predict the next frame given the current frame and a latent action index
  - For multi-step prediction and rollout analysis, see the evaluation code in `test_latent_action_model.py`
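Multi-step prediction amounts to an autoregressive loop: each predicted frame is fed back in as the next input, so errors compound over the horizon. A minimal sketch, assuming a hypothetical `decoder(frame, latent_idx) -> frame` interface (adapt it to the actual signature of the Part 2 checkpoint):

```python
import torch

@torch.no_grad()
def rollout(decoder, frame, latent_indices):
    """Autoregressive multi-step prediction in the learned world model.

    `decoder` is any callable (frame, latent_idx) -> next frame; this is a
    hypothetical interface, not the repo's exact API.
    """
    frames = [frame]
    for idx in latent_indices:
        frame = decoder(frame, idx)  # predicted frame becomes the next input
        frames.append(frame)
    return torch.stack(frames)

# Shape check with a stand-in identity "decoder":
frames = rollout(lambda f, i: f, torch.zeros(3, 84, 84), [0, 1, 2])
```

Comparing such rollouts against ground-truth frame sequences shows how quickly the model drifts.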
- Extract (action, latent_code) pairs using the trained VQ-VAE:

```bash
python collect_action_latent_pairs.py
```

- Output: `data/actions/action_latent_pairs.json`
- Train the action-to-latent MLP:

```bash
python train_action_to_latent.py
```

- Best checkpoint: `checkpoints/latent_action/action_to_latent_best.pt`
- Evaluate mapping accuracy:
  - Run `test_latent_action_data_collection.py` and `test_latent_action_model.py`
  - Analyze accuracy and code distributions as described in part4.md
See part4.md for full details, implementation, and next steps.
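The action-to-latent mapping is essentially a small classifier: given a discrete game action, predict which latent code the VQ-VAE would have assigned. A minimal sketch, with hypothetical sizes (4 Breakout actions, 8 latent codes; adjust to match the trained codebook):

```python
import torch
import torch.nn as nn

NUM_ACTIONS, NUM_CODES = 4, 8  # hypothetical sizes, not the repo's config

# One-hot action in, logits over latent code indices out. Training would
# minimize cross-entropy on the collected (action, latent_code) pairs.
mlp = nn.Sequential(
    nn.Linear(NUM_ACTIONS, 64),
    nn.ReLU(),
    nn.Linear(64, NUM_CODES),
)

def action_to_latent(actions):
    one_hot = torch.eye(NUM_ACTIONS)[actions]  # (batch, NUM_ACTIONS)
    return mlp(one_hot).argmax(dim=-1)         # most likely code per action
```

At play time (Part 5), this mapping turns a key press into the latent index that the decoder expects.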
- Run a random agent in the neural world model:

```bash
python neural_random_game.py
```

- Output: `data/neural_random_game.gif` (a video of neural gameplay)
- Play Breakout using the neural world model:

```bash
python play_neural_breakout.py
```
- Controls: SPACE (Fire), LEFT/RIGHT ARROW, PERIOD (NOOP), ESC/Q (Quit)
- Requires trained models from Parts 2 and 4
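Conceptually, the play loop maps each documented key to a discrete action and asks the world model for the next frame. The mapping below assumes the standard ALE Breakout action set (NOOP=0, FIRE=1, RIGHT=2, LEFT=3); the script's actual key handling may differ.

```python
# Hypothetical key-to-action mapping for the documented controls.
KEY_TO_ACTION = {
    "PERIOD": 0,  # NOOP
    "SPACE": 1,   # FIRE launches the ball
    "RIGHT": 2,   # move paddle right
    "LEFT": 3,    # move paddle left
}

def step(world_model, frame, key):
    """One loop iteration: key -> discrete action -> predicted next frame.

    `world_model` is any callable (frame, action) -> next frame; this is a
    stand-in interface, not the repo's exact API.
    """
    action = KEY_TO_ACTION.get(key, 0)  # unmapped keys fall back to NOOP
    return world_model(frame, action)
```

ESC/Q would simply break out of the loop rather than produce an action.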
- All inference runs on CUDA if available, otherwise on MPS or CPU.
- For best performance, ensure `torch.compile` is enabled and models are on a CUDA device.
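The device fallback described above can be sketched as a simple preference order, CUDA first, then Apple-silicon MPS, then CPU:

```python
import torch

# Pick the best available accelerator in preference order.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Optionally compile the model for faster inference (most effective on CUDA):
# model = torch.compile(model.to(device))
```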
See part5.md for full details, implementation, and next steps.
- `debug_color_channels.py`: Visualizes and compares color channels between environment frames and PNG files to debug color mismatches.
- `debug_first_step_difference.py`: Compares predicted vs. ground-truth latents and reconstructions for the first step of the neural random game, helping debug model discrepancies.
- `neural_random_game.py`: Runs a random agent in the neural world model and saves a GIF of the generated gameplay for qualitative evaluation.
- Train the DQN agent for longer (increase max steps per episode to 10,000 and target scores of 200+)
- Generate thousands of videos of the agent playing, with a higher percentage of random actions:

```bash
python record_videos.py --bulk --total_videos 1000 --percent_random 20 --output_dir bulk_videos
```
- Train the latent action model for much longer to achieve convergence
- Collect much more data for action → latent code mapping
- Try with different Atari games
