Time-Aware World Model for Adaptive Prediction and Control
Anh N. Nhu*, 1, Sanghyun Son*, 1, Ming C. Lin1
1 University of Maryland, College Park
International Conference on Machine Learning (ICML) 2025
This is the official code for the Time-Aware World Model (TAWM), a model-agnostic and more efficient training method for world models.

The Time-Aware World Model (TAWM) is a model-agnostic training method that improves dynamics learning by explicitly conditioning on the time step size Δt and sampling observations at varying frequencies. This addresses the real-world constraint of varying observation rates, enabling efficient learning across temporal scales and outperforming fixed-Δt baselines under the same training budget.

Since TAWM's core contribution is the time-aware concept and training method, which is architecture-agnostic, it can be seamlessly incorporated into any world model training pipeline, including but not limited to TD-MPC2 and Dreamer. In this work, TAWM is built on top of the TD-MPC2 architecture as the basis for the experiments.
You can directly incorporate the time-aware concept into your world model training pipeline even without using our code. To incorporate TAWM into any world model architecture:

- Modify the dynamics or temporal state-space model to condition on the time step $\Delta t$. Example (Euler integration model):

  $$z_{t+\Delta t} = z_t + d_{\theta}(z_t, a_t, \Delta t) \cdot \tau(\Delta t)$$

- Train using a mixture of time step sizes:

  $$\Delta t \sim \text{Log-Uniform}(\Delta t_{\min}, \Delta t_{\max}) \quad \text{(or Uniform sampling)}$$
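The two steps above can be sketched in plain Python. This is a minimal, hypothetical illustration, not the repo's implementation: the learned dynamics $d_\theta$ is replaced by a toy function, and the scaling $\tau(\Delta t)$ is taken to be $\Delta t$ itself.

```python
import math
import random

def sample_dt(dt_min: float, dt_max: float) -> float:
    """Draw a time step size from a log-uniform distribution."""
    return math.exp(random.uniform(math.log(dt_min), math.log(dt_max)))

def euler_step(z, a, dt, dynamics):
    """Time-aware Euler update: z_{t+dt} = z_t + d(z_t, a_t, dt) * dt."""
    dz = dynamics(z, a, dt)
    return [zi + dzi * dt for zi, dzi in zip(z, dz)]

def toy_dynamics(z, a, dt):
    """Toy stand-in for the learned dynamics d_theta (illustrative only):
    decay toward zero plus a constant action effect."""
    return [-zi + a for zi in z]

# Each training step draws a fresh dt, so the model sees many temporal scales.
random.seed(0)
dt = sample_dt(1e-3, 5e-2)
z_next = euler_step([1.0, -2.0], 0.5, dt, toy_dynamics)
```

Log-uniform sampling spends proportionally more draws on small $\Delta t$ values than uniform sampling does, which is why it is the default in the commands below.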
Installation:

Base Conda env (recommended): install Miniconda3, then the dependencies:

```shell
conda env create -f environment.yaml
pip3 install gym==0.21.0
pip3 install torch==2.3.1
pip3 install torchvision==0.18.1
```

Additional dependencies for control tasks:

```shell
# install metaworld envs
pip3 install git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
pip3 install gymnasium
# install controlgym envs (PDE)
cd tdmpc2/envs
git clone https://github.com/xiangyuan-zhang/controlgym.git
rm -r controlgym/.git/
```

Additional dependencies for F16 aircraft control tasks (work in progress):

```shell
# source: https://github.com/GongXudong/fly-craft/
pip install flycraft
cd tawm/tawm/envs
git clone https://github.com/GongXudong/fly-craft-examples.git
```
IMPORTANT NOTE:

If, for some reason, `import controlgym` causes the following warning to TERMINATE the program (which it shouldn't):

`UserWarning: We've integrated functorch into PyTorch. As the final step of the integration, functorch.combine_state_for_ensemble is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3.`

Normally, this warning MUST NOT terminate the program (it is just a warning). The following fix works for us:

- Find the line in the error log that looks like this on your machine / environment:

  `~/miniconda3/envs/tdmpc2/lib/python3.9/site-packages/torch/_functorch/deprecated.py:38, in warn_deprecated(api, new_api)`

- Execute `vim <path to _functorch/deprecated.py>`; in our case, it was `vim ~/miniconda3/envs/tdmpc2/lib/python3.9/site-packages/torch/_functorch/deprecated.py`

- Comment out line 38:

  `# warnings.warn(warning, stacklevel=2)`
Activate the conda env:

```shell
conda activate tawm
```
TAWM Training Examples:

```shell
cd tdmpc2
# TAWM: meta-world environments
python train.py task=mw-basketball multi_dt=true steps=1500000 seed=3
# TAWM: pde-control environments: no rendering/video
python train.py task=pde-burgers save_video=false multi_dt=true steps=1500000 seed=5
# TAWM: f16 aircraft control
python train.py task=f16 save_video=false multi_dt=true steps=6000000 eval_freq=100000 lr=1e-3 seed=3
```
Baseline Training Examples:

```shell
cd tdmpc2
# baseline: meta-world environments
python train.py task=mw-basketball multi_dt=false steps=2000000 seed=3
# baseline: pde-control environments: no rendering/video
python train.py task=pde-burgers save_video=false multi_dt=false steps=2000000 seed=5
```
Every 50,000 steps, the model checkpoints are saved in `<cfg.checkpoint>/<task>/<model_type>-<dt_sampler>-<integrator>/<seed>`, where:

- `cfg.checkpoint`: saving directory of model checkpoints
- `task`: control task (e.g. `mw-assembly`, `mw-basketball`, `pde-burgers`)
- `model_type`: `multidt` (TAWM) or `singledt` (baseline)
- `dt_sampler`: $\Delta t$ sampling method; `log-uniform` (default) or `uniform`
- `integrator`: integration method; `euler` (default) or `rk4`
The evaluation code evaluates the performance of the world model on a specified task with different simulation time steps. For Meta-World environments, it also reports the task success rate (ranging from 0.0 to 1.0).

NOTE: You need to update your local `model_path`.
- `eval_model_multidt.py`: test model performance on `task` at various inference-time observation rates:

  ```shell
  # test TAWM (Euler + Log-Uniform) on `mw-basketball`
  python eval_model_multidt.py task=mw-basketball checkpoint={model_path} seed={seed} dt_sampler=log-uniform multi_dt=true integrator=euler
  # test non-time-aware baseline on `mw-basketball`
  python eval_model_multidt.py task=mw-basketball checkpoint={model_path} seed={seed} multi_dt=false eval_steps_adjusted=true
  ```
- `eval_model_multidt_all`: comprehensively test the performance of all models across all tasks at various inference-time observation rates (uses `eval_model_multidt.py`). Models:
  - (1a) Baseline (non-time-aware) trained on $\Delta t_{default}$
  - (1b) Baseline (non-time-aware) trained on $\Delta t \neq \Delta t_{default}$
  - (2a) TAWM (RK4 + Log-Uniform sampling)
  - (2b) TAWM (Euler + Log-Uniform sampling)
- `eval_model_learning_curve`: evaluate intermediate models and save learning curves for each task across seeds.
  - NOTE: You need models saved at intermediate steps for this evaluation. By default, a model checkpoint is saved every 50,000 steps.

The evaluation results are saved in `tdmpc2/logs/<task>/<eval-type>.csv`, where:

- `task`: the control task evaluated on
- `eval_type`: evaluated model type (e.g. baseline, TAWM-RK4, TAWM-Euler)
- NOTE 1: for the non-time-aware baseline models trained on the fixed default $\Delta t$, we used the trained weights of the original TD-MPC2 model for each Meta-World control task. The trained weights are available here: https://huggingface.co/nicklashansen/tdmpc2/tree/main/metaworld.

- NOTE 2: the learning curves of the non-time-aware baseline models evaluated at the default $\Delta t$ are taken from the original TD-MPC2 model, whose learning curves (at default $\Delta t$) are publicly available at: https://github.com/nicklashansen/tdmpc2/tree/main/results/tdmpc2.
The original MTS3 is a prediction-only world model and does not support evaluation on control tasks. If you are interested in experimenting with MTS3 as a comparison to our TAWM, please use our modified MTS3+MPC for comparison on control tasks.
Offline data collection:

The MTS3 model is a prediction-only world model, so it does not interact with environments like Meta-World. Therefore, we need to collect an offline dataset before training MTS3.

NOTE: to collect data for an individual task only, specify `specific_task=<task name>`.

a. Collect offline data for the Time-Aware World Model:

```shell
python collect_offline_dataset.py task_set=mt9 specific_task=mw-basketball num_eps=40000 ep_length=100 multitask=false multi_dt=true data_dir=/fs/nexus-scratch/anhu/mt9_multidt_40k
```

b. Collect data for the baseline world model:

```shell
cd tdmpc2/tdmpc2
python collect_offline_dataset.py task_set=mt9 specific_task=mw-basketball num_eps=40000 ep_length=100 multitask=false multi_dt=false data_dir=/fs/nexus-scratch/anhu/mt9_singledt_40k
```
MTS3 settings:

- `H`: slow time scale factor for MTS3. To set it:
  - access the config file `MTS3/experiments/basketball/conf/model/default_mts3.yaml`
  - set `time_scale_multiplier = <H>`
Example: training MTS3 for `mw-basketball` (assuming you have collected the offline dataset for the task):

```shell
cd MTS3
python MTS3/experiments/basketball/mts3_exp.py
```
If you have any questions or suggestions about our work, please feel free to open an issue or contact us at anhu@umd.edu.
If you find the insights and findings in our work useful, please consider citing our paper with the following BibTeX entry.
@misc{nhu2025timeawareworldmodeladaptive,
title={Time-Aware World Model for Adaptive Prediction and Control},
author={Anh N. Nhu and Sanghyun Son and Ming Lin},
year={2025},
eprint={2506.08441},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2506.08441},
}
```

OR
@InProceedings{pmlr-v267-nhu25a,
title = {Time-Aware World Model for Adaptive Prediction and Control},
author = {Nhu, Anh N and Son, Sanghyun and Lin, Ming},
booktitle = {Proceedings of the 42nd International Conference on Machine Learning},
pages = {46265--46287},
year = {2025},
editor = {Singh, Aarti and Fazel, Maryam and Hsu, Daniel and Lacoste-Julien, Simon and Berkenkamp, Felix and Maharaj, Tegan and Wagstaff, Kiri and Zhu, Jerry},
volume = {267},
series = {Proceedings of Machine Learning Research},
month = {13--19 Jul},
publisher = {PMLR},
pdf = {https://raw.githubusercontent.com/mlresearch/v267/main/assets/nhu25a/nhu25a.pdf},
url = {https://proceedings.mlr.press/v267/nhu25a.html},
}
We appreciate your interest in our work and hope that it is useful to your projects!