
⏱️ Time-Aware World Model 🌎

Time-Aware World Model for Adaptive Prediction and Control
Anh N. Nhu*, 1, Sanghyun Son*, 1, Ming C. Lin1
1 University of Maryland, College Park
International Conference on Machine Learning (ICML) 2025

πŸ“– Introduction

This is the official code for the Time-Aware World Model (TAWM), a model-agnostic and more efficient training method for world models.

🎯 TL;DR

Time-Aware World Model (TAWM) is a model-agnostic training method that improves dynamics learning by explicitly incorporating time step size Ξ”t and sampling observations at varying frequencies. This addresses the real-world constraint of varying observation rates, enabling efficient learning across temporal scales and outperforming fixed-Ξ”t baselines under the same training budget.


πŸ”§ Architecture-Agnostic Design

Since TAWM's core contribution is a time-aware concept and training method that is architecture-agnostic, it can be seamlessly incorporated into any world model training pipeline, including but not limited to TD-MPC2 and Dreamer. In this work, TAWM is built on top of the TD-MPC2 architecture for the experiments.


πŸ’‘ Time-Aware Incorporation

You can directly incorporate the Time-Aware concept into your world model training pipeline even without using our code. To incorporate TAWM into any world model architecture:

  1. Modify the dynamics or temporal state-space model to condition on the time step $\Delta t$
    Example (Euler integration model):
    $$z_{t+\Delta t} = z_t + d_{\theta}(z_t, a_t, \Delta t) \cdot \tau(\Delta t)$$

  2. Train using a mixture of time step sizes:
    $$\Delta t \sim \text{Log-Uniform}(\Delta t_{\min}, \Delta t_{\max}) \quad \text{(or Uniform sampling)}$$
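As a concrete illustration, the two steps above can be sketched in PyTorch as follows. This is a minimal sketch, not code from this repository: the module name, layer sizes, and the identity choice of $\tau(\Delta t)$ are assumptions for illustration only.

```python
import math
import random

import torch
import torch.nn as nn


class TimeAwareDynamics(nn.Module):
    """Delta-t-conditioned latent dynamics in Euler form (illustrative sizes)."""

    def __init__(self, latent_dim: int, action_dim: int, hidden: int = 256):
        super().__init__()
        # d_theta(z_t, a_t, dt) predicts the latent time-derivative.
        self.d_theta = nn.Sequential(
            nn.Linear(latent_dim + action_dim + 1, hidden),
            nn.ELU(),
            nn.Linear(hidden, latent_dim),
        )

    def forward(self, z: torch.Tensor, a: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        # z_{t+dt} = z_t + d_theta(z_t, a_t, dt) * tau(dt), with tau = identity here.
        x = torch.cat([z, a, dt], dim=-1)
        return z + self.d_theta(x) * dt


def sample_log_uniform_dt(dt_min: float, dt_max: float) -> float:
    """Draw dt ~ Log-Uniform(dt_min, dt_max) for each training rollout."""
    return math.exp(random.uniform(math.log(dt_min), math.log(dt_max)))
```

At training time, one would sample a fresh $\Delta t$ per rollout (or per batch), step the simulator at that rate, and feed the same $\Delta t$ to the dynamics model.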

1. Dependency Installation

  1. Installation:
    Create the base Conda environment (Miniconda3 recommended) and install the dependencies:

    conda env create -f environment.yaml
    pip3 install gym==0.21.0
    pip3 install torch==2.3.1
    pip3 install torchvision==0.18.1

    Additional dependencies for control tasks:

    # install metaworld envs
    pip3 install git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
    pip3 install gymnasium
    
    # install controlgym envs (PDE)
    cd tdmpc2/envs
    git clone https://github.com/xiangyuan-zhang/controlgym.git
    rm -r controlgym/.git/

    Additional dependencies for F16 aircraft control tasks (work in progress):

    # source: https://github.com/GongXudong/fly-craft/
    pip install flycraft
    cd tawm/tawm/envs
    git clone https://github.com/GongXudong/fly-craft-examples.git

    IMPORTANT NOTE:
    On some setups, import controlgym causes the following warning to TERMINATE the program: UserWarning: We've integrated functorch into PyTorch. As the final step of the integration, functorch.combine_state_for_ensemble is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3.
    This is only a warning and should not terminate the program.

    The following fix works for us:

    • find the offending file path in the error log on your machine/environment, e.g. ~/miniconda3/envs/tdmpc2/lib/python3.9/site-packages/torch/_functorch/deprecated.py:38, in warn_deprecated(api, new_api)
    • open the file: vim <path to _functorch/deprecated.py>
      (in our case: vim ~/miniconda3/envs/tdmpc2/lib/python3.9/site-packages/torch/_functorch/deprecated.py)
    • comment out line 38: # warnings.warn(warning, stacklevel=2)

2. TAWM and baseline training

  1. Activate conda env

    conda activate tawm
  2. TAWM Training Examples:

    cd tdmpc2
    # TAWM: meta-world environments
    python train.py task=mw-basketball multi_dt=true steps=1500000 seed=3
    # TAWM: pde-control environments: no rendering/video
    python train.py task=pde-burgers save_video=false multi_dt=true steps=1500000 seed=5
    
    # TAWM: F16 aircraft control task (work in progress)
    python train.py task=f16 save_video=false multi_dt=true steps=6000000 eval_freq=100000 lr=1e-3 seed=3

    Baseline Training Examples:

    cd tdmpc2
    # baseline: meta-world environments
    python train.py task=mw-basketball multi_dt=false steps=2000000 seed=3
    # baseline: pde-control environments: no rendering/video
    python train.py task=pde-burgers save_video=false multi_dt=false steps=2000000 seed=5

    Every 50,000 steps, a model checkpoint is saved in <cfg.checkpoint>/<task>/<model_type>-<dt_sampler>-<integrator>/<seed>, where:

    • cfg.checkpoint: saving directory of model checkpoints
    • task: control task (e.g. mw-assembly, mw-basketball, pde-burgers)
    • model_type: multidt (TAWM) or singledt (baseline)
    • dt_sampler: $\Delta t$ sampling method; log-uniform (default) or uniform
    • integrator: integration method; euler (default) or rk4
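For reference, the checkpoint directory layout above can be reproduced with a small helper like the following. This is a hypothetical sketch; the actual codebase may compose the path differently.

```python
from pathlib import Path


def checkpoint_dir(checkpoint_root: str, task: str, multi_dt: bool,
                   dt_sampler: str = "log-uniform", integrator: str = "euler",
                   seed: int = 0) -> Path:
    """Compose <cfg.checkpoint>/<task>/<model_type>-<dt_sampler>-<integrator>/<seed>."""
    model_type = "multidt" if multi_dt else "singledt"  # TAWM vs. baseline
    return Path(checkpoint_root) / task / f"{model_type}-{dt_sampler}-{integrator}" / str(seed)
```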

3. Model Evaluation

3a. Evaluation Scripts

The evaluation scripts are provided as reference scripts for TAWM experiments/deployments.

The evaluation code evaluates the performance of the world model on a specified task with different simulation time steps. For Meta-World environments, it also reports the task success rate (ranging from 0.0 to 1.0).

NOTE: You need to update model_path to your local checkpoint path.

  1. eval_model_multidt.py: test model performance on a task across various inference-time observation rates

    # test TAWM (Euler + Log-Uniform) on `mw-basketball`
    python eval_model_multidt.py task=mw-basketball checkpoint={model_path} seed={seed} dt_sampler=log-uniform multi_dt=true integrator=euler
    # test non-time-aware baseline on `mw-basketball`
    python eval_model_multidt.py task=mw-basketball checkpoint={model_path} seed={seed} multi_dt=false eval_steps_adjusted=true
  2. eval_model_multidt_all: comprehensively test the performance of all models across all tasks on various inference-time observation rates

    • Models:
      (1a) Baseline (non-time-aware) trained on $\Delta t_{default}$
      (1b) Baseline (non-time-aware) trained on $\Delta t \neq \Delta t_{default}$
      (2a) TAWM (RK4 + Log-Uniform Sampling)
      (2b) TAWM (Euler + Log-Uniform Sampling)
    • Use eval_model_multidt.py
  3. eval_model_learning_curve: Evaluate intermediate models and save learning curves on each task across seeds

    • NOTE: You need to have model saved at each step for this evaluation. By default, a model checkpoint is saved every 50,000 steps.
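The per-Δt evaluation that these scripts perform can be sketched as the loop below. Here `make_env(dt)` and `agent.act(obs, dt)` are hypothetical interfaces standing in for the actual environment wrappers and model, and the success flag assumes a Meta-World-style `info["success"]` entry.

```python
def evaluate_across_dts(make_env, agent, dts, n_episodes=10):
    """Evaluate one agent at several observation/control time steps.

    Returns {dt: success_rate}; success_rate is in [0.0, 1.0].
    """
    results = {}
    for dt in dts:
        successes = 0
        for _ in range(n_episodes):
            env = make_env(dt)               # env stepped at simulation step size dt
            obs, done = env.reset(), False
            info = {}
            while not done:
                action = agent.act(obs, dt)  # time-aware policy conditioned on dt
                obs, reward, done, info = env.step(action)
            successes += int(info.get("success", 0))
        results[dt] = successes / n_episodes
    return results
```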

3b. Evaluation Results

The evaluation results are saved in tdmpc2/logs/<task>/<eval_type>.csv.

  • task: the control task evaluated on
  • eval_type: evaluated model type (e.g. baseline, TAWM-RK4, TAWM-Euler, etc.)

(Optional) Experiments with MTS3

The original MTS3 is a prediction-only world model and does not support evaluation on control tasks. If you are interested in experimenting with MTS3 as a comparison to TAWM, please use our modified MTS3+MPC for comparison on control tasks.

  1. Offline data collection
    The MTS3 model is a prediction-only world model, so it does not interact with environments like Meta-World. Therefore, we need to collect an offline dataset before training MTS3.

  2. Offline data collection commands:
    NOTE: to collect data for an individual task only, specify specific_task=<task name>.
    a. Collect offline data for the Time-Aware World Model

    python collect_offline_dataset.py task_set=mt9 specific_task=mw-basketball num_eps=40000 ep_length=100 multitask=false multi_dt=true data_dir=/fs/nexus-scratch/anhu/mt9_multidt_40k

    b. Collect data for baseline world model:

    cd tdmpc2/tdmpc2
    python collect_offline_dataset.py task_set=mt9 specific_task=mw-basketball num_eps=40000 ep_length=100 multitask=false multi_dt=false data_dir=/fs/nexus-scratch/anhu/mt9_singledt_40k
  3. MTS3 settings (H: slow time-scale factor for MTS3):

    • access config file in MTS3/experiments/basketball/conf/model/default_mts3.yaml
    • set time_scale_multiplier = <H>
  4. Example: training MTS3 for mw-basketball (assuming you have collected the offline dataset for the task):

    cd MTS3
    python MTS3/experiments/basketball/mts3_exp.py
    

πŸ“§ Contact

If you have any questions or suggestions about our work, please feel free to open an issue or contact us at anhu@umd.edu.

πŸ“š Citation

If you find the insights and findings in our work useful, please consider citing our paper with one of the following BibTeX entries.

πŸ“„ ArXiv

@misc{nhu2025timeawareworldmodeladaptive,
   title={Time-Aware World Model for Adaptive Prediction and Control}, 
   author={Anh N. Nhu and Sanghyun Son and Ming Lin},
   year={2025},
   eprint={2506.08441},
   archivePrefix={arXiv},
   primaryClass={cs.LG},
   url={https://arxiv.org/abs/2506.08441}, 
}

OR

πŸŽ“ ICML 2025

@InProceedings{pmlr-v267-nhu25a,
  title = 	 {Time-Aware World Model for Adaptive Prediction and Control},
  author =       {Nhu, Anh N and Son, Sanghyun and Lin, Ming},
  booktitle = 	 {Proceedings of the 42nd International Conference on Machine Learning},
  pages = 	 {46265--46287},
  year = 	 {2025},
  editor = 	 {Singh, Aarti and Fazel, Maryam and Hsu, Daniel and Lacoste-Julien, Simon and Berkenkamp, Felix and Maharaj, Tegan and Wagstaff, Kiri and Zhu, Jerry},
  volume = 	 {267},
  series = 	 {Proceedings of Machine Learning Research},
  month = 	 {13--19 Jul},
  publisher =    {PMLR},
  pdf = 	 {https://raw.githubusercontent.com/mlresearch/v267/main/assets/nhu25a/nhu25a.pdf},
  url = 	 {https://proceedings.mlr.press/v267/nhu25a.html},
}

We appreciate your interest in our work and hope that it is useful to your projects!

About

Official Code for Time-Aware World Model (TAWM) with TD-MPC2 as the baseline architecture in the experiments.
