Disclaimer

This codebase has only been tested for Atari Breakout. It is very hacky, trained on a small number of samples, and is not at all optimized. Use at your own risk!


Atari Pixels Project

This project aims to create a neural, playable version of Atari Breakout learned purely from videos of the game. It is a small replica of Google's Genie project, which learned interactive, playable world models purely from videos of games.

Watch the video

Part 1: DQN Training & Exploration

  • Implements a Deep Q-Network (DQN) agent for Atari Breakout.
  • Supports two exploration strategies (both sketched below):
    • Temperature-based (Boltzmann) exploration with Prioritized Experience Replay (PER).
    • Random Network Distillation (RND) for intrinsic motivation and improved exploration.
  • Modular, efficient, and test-driven codebase.
  • Current progress: the agent achieves scores of up to 20.
  • Next steps: Train for longer, increase max steps per episode to 10,000, and target scores of 200+.

See part1.md for full details, implementation, and next steps.
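
For intuition, here is a minimal sketch of temperature-based (Boltzmann) action selection, assuming a 1-D tensor of Q-values; the actual logic lives in train_dqn.py and may differ in detail:

    import torch

    def boltzmann_action(q_values: torch.Tensor, temperature: float = 1.0) -> int:
        # Higher temperature flattens the distribution (more exploration);
        # as temperature -> 0 this approaches greedy selection.
        probs = torch.softmax(q_values / temperature, dim=-1)
        return int(torch.multinomial(probs, num_samples=1).item())

RND works differently: a predictor network is trained to match a frozen, randomly initialised target network, and the prediction error serves as an intrinsic reward for novel states. The sketch below uses plain linear layers on a flattened observation purely for illustration; the networks in this repo are not necessarily shaped this way:

    import torch
    import torch.nn as nn

    class RNDBonus(nn.Module):
        def __init__(self, obs_dim: int, feat_dim: int = 64):
            super().__init__()
            self.target = nn.Linear(obs_dim, feat_dim)
            for p in self.target.parameters():
                p.requires_grad_(False)  # target stays fixed forever
            self.predictor = nn.Linear(obs_dim, feat_dim)

        def forward(self, obs: torch.Tensor) -> torch.Tensor:
            # Large prediction error = unfamiliar state = large intrinsic reward.
            return (self.predictor(obs) - self.target(obs)).pow(2).mean(dim=-1)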


Getting Started

1. Install Dependencies and Atari ROMs

pip install -r requirements.txt
python -m AutoROM --accept-license

2. Train the DQN Agent

  • Default (temperature-based exploration):
    python train_dqn.py
  • RND-based exploration:
    python train_dqn.py --exploration_mode rnd
  • Additional arguments (see python train_dqn.py --help) allow you to control episodes, buffer size, and more.

3. Generate Gameplay Videos

  • To record videos of a trained (or random) agent:
    python record_videos.py --checkpoint_path <path_to_checkpoint> --output_dir videos/
    e.g. python record_videos.py --bulk --total_videos 100 --percent_random 15 --output_dir bulk_videos
  • Replace <path_to_checkpoint> with the path to your saved model checkpoint (see checkpoints/).
  • Videos will be saved as MP4 files in the specified output directory.

Part 2: Train Latent Action Prediction Model

  • Prepare data: Ensure gameplay videos are available in bulk_videos/ (from Part 1).
  • Train VQ-VAE latent action model:
    python train_latent_action.py
  • Checkpoints and logs:
    • Best model: checkpoints/latent_action/best.pt
    • Processed data: see data/ and vqvae_recons/
    • Training logs and metrics: Weights & Biases (wandb)
  • Evaluation:
    • Run test_latent_action_model.py for automated tests and metrics
    • Visualize reconstructions and codebook usage as described in part2.md

See part2.md for full details, implementation, and next steps.
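
At its core, the VQ-VAE's latent action bottleneck is a nearest-neighbour codebook lookup with a straight-through gradient. A minimal sketch, where the codebook size and embedding dimension are assumptions rather than the values used in train_latent_action.py:

    import torch
    import torch.nn as nn

    class VectorQuantizer(nn.Module):
        def __init__(self, num_codes: int = 256, dim: int = 64):
            super().__init__()
            self.codebook = nn.Embedding(num_codes, dim)

        def forward(self, z: torch.Tensor):
            # z: (batch, dim) continuous latents from the encoder
            dists = torch.cdist(z, self.codebook.weight)  # (batch, num_codes)
            indices = dists.argmin(dim=-1)                # nearest code per latent
            quantized = self.codebook(indices)
            # Straight-through estimator: gradients flow back to z
            # as if quantization were the identity.
            quantized = z + (quantized - z).detach()
            return quantized, indices

The returned indices are the discrete "latent actions" that Parts 3-5 build on.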


Part 3: Next Frame Prediction Model (World Model)

  • Note: The decoder from Part 2 serves as the world model. No separate training is required unless you wish to experiment with alternative architectures.
  • Evaluate world model:
    • Use the decoder to predict the next frame given the current frame and a latent action index
    • For multi-step prediction and rollout analysis, see the evaluation code in test_latent_action_model.py (and the sketch below)
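
Multi-step prediction is just the decoder applied autoregressively. A minimal sketch, assuming a hypothetical decoder(frame, latent_index) -> next_frame interface (the real evaluation code is in test_latent_action_model.py):

    import torch

    @torch.no_grad()
    def rollout(decoder, first_frame: torch.Tensor, latent_indices) -> torch.Tensor:
        # Feed each predicted frame back in as the next input.
        frames = [first_frame]
        for idx in latent_indices:
            frames.append(decoder(frames[-1], idx))
        return torch.stack(frames)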

Part 4: Train Action-to-Latent Mapping Model

  • Extract (action, latent_code) pairs using trained VQ-VAE:
    python collect_action_latent_pairs.py
    • Output: data/actions/action_latent_pairs.json
  • Train the action-to-latent MLP:
    python train_action_to_latent.py
    • Best checkpoint: checkpoints/latent_action/action_to_latent_best.pt
  • Evaluate mapping accuracy:
    • Run test_latent_action_data_collection.py and test_latent_action_model.py
    • Analyze accuracy and code distributions as described in part4.md

See part4.md for full details, implementation, and next steps.
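
The mapping itself amounts to a small classifier from game actions to codebook indices. A sketch under assumed sizes (the real architecture is defined in train_action_to_latent.py):

    import torch.nn as nn

    class ActionToLatent(nn.Module):
        def __init__(self, num_actions: int = 4, num_codes: int = 256, hidden: int = 128):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(num_actions, hidden),
                nn.ReLU(),
                nn.Linear(hidden, num_codes),  # logits over latent codes
            )

        def forward(self, action_onehot):
            return self.net(action_onehot)

Training then minimises cross-entropy between these logits and the latent codes collected by collect_action_latent_pairs.py.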


Part 5: Playable Neural Breakout (World Model Demo)

  • Run a random agent in the neural world model:
    python neural_random_game.py
    • Output: data/neural_random_game.gif (video of neural gameplay)
  • Play Breakout using the neural world model:
    python play_neural_breakout.py
    • Controls: SPACE (Fire), LEFT/RIGHT ARROW, PERIOD (NOOP), ESC/Q (Quit)
    • Requires trained models from Parts 2 and 4
  • All inference runs on CUDA if available, otherwise on MPS or CPU (see the device-selection sketch below).
  • For best performance, ensure torch.compile is enabled and models are on CUDA.

See part5.md for full details, implementation, and next steps.
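
The device fallback mentioned above follows the standard PyTorch pattern; a minimal sketch:

    import torch

    def pick_device() -> torch.device:
        # Prefer CUDA, then Apple's MPS backend, then CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")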


Utility & Debug Scripts

  • debug_color_channels.py: Visualizes and compares color channels between environment frames and PNG files to debug color mismatches.
  • debug_first_step_difference.py: Compares predicted vs. ground truth latents and reconstructions for the first step of the neural random game, helping debug model discrepancies.
  • neural_random_game.py: Runs a random agent in the neural world model and saves a GIF of the generated gameplay for qualitative evaluation.

Suggested Improvements

  • Train the DQN agent for longer (increase max steps per episode to 10,000 and target scores of 200+)
  • Generate thousands of videos of the agent playing, with a higher percentage of random actions:
    python record_videos.py --bulk --total_videos 1000 --percent_random 20 --output_dir bulk_videos
  • Train the latent action model for much longer to achieve convergence
  • Collect much more data for action → latent code mapping
  • Try with different Atari games
