Astrolabe: Steering Forward-Process Reinforcement Learning for Distilled Autoregressive Video Models
Songchun Zhang1, Zeyue Xue2,3, Siming Fu2, Jie Huang2, Xianghao Kong1, Yue Ma1, Haoyang Huang2, Nan Duan2✉, Anyi Rao1✉
1HKUST 2JD Explore Academy 3HKU ✉Corresponding Authors
Astrolabe is an efficient online Reinforcement Learning (RL) framework designed to align distilled autoregressive (AR) streaming video models with human visual preferences. Without sacrificing real-time inference speed, Astrolabe consistently and robustly improves visual aesthetics and temporal consistency across various baseline models for both short and long video generation.
- 2026-03-23 — Code released!
- 2026-03-18 — Paper released on arXiv!
| Model | Config File |
|---|---|
| LongLive | `configs/nft_longlive.py` |
| Self-Forcing | `configs/nft_self_forcing.py` |
| Causal Forcing | `configs/nft_casual_forcing.py` |
| Krea 14B | `configs/nft_krea14b.py` |
| Reward | Key in Config | What It Measures | Source |
|---|---|---|---|
| HPSv3 | `video_hpsv3_local` | Frame-level aesthetic & visual quality | HPSv3 |
| VideoAlign – VQ | `videoalign_vq_score` | Per-frame visual fidelity scored by VideoReward | VideoReward |
| VideoAlign – MQ | `videoalign_mq_score` | Temporal smoothness & motion naturalness | VideoReward |
| VideoAlign – TA | `videoalign_ta_score` | Prompt–video semantic alignment | VideoReward |
Rewards can be freely combined with per-reward weights, e.g. `reward_fn={"video_hpsv3_local": 1.0, "videoalign_mq_score": 1.0}`.
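A weighted combination like this reduces to a per-sample weighted sum of the individual scores. The sketch below is purely illustrative — the function name and score layout are assumptions, and the repo's actual aggregation lives in its `multi_score()`:

```python
# Hypothetical sketch of how per-reward weights combine into one scalar reward
# per sample; the repo's real aggregation is implemented in multi_score().
def combine_rewards(per_reward_scores, weights):
    """per_reward_scores: {reward_name: [B] floats}; weights: {reward_name: float}."""
    batch = len(next(iter(per_reward_scores.values())))
    totals = [0.0] * batch
    for name, weight in weights.items():
        for i, score in enumerate(per_reward_scores[name]):
            totals[i] += weight * score
    return totals

weights = {"video_hpsv3_local": 1.0, "videoalign_mq_score": 1.0}
scores = {"video_hpsv3_local": [0.8, 0.6], "videoalign_mq_score": [0.5, 0.7]}
combined = combine_rewards(scores, weights)  # ≈ [1.3, 1.3]
```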
Tested configuration: Python 3.10.16, CUDA 12.8, NVIDIA H200 GPUs
```shell
# Clone the repository
git clone https://github.com/franklinz233/Astrolabe.git
cd Astrolabe

# Create and activate conda environment
conda create -n astrolabe python=3.10.16
conda activate astrolabe

# Install PyTorch (CUDA 12.4 wheels)
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124

# Install other dependencies
pip install -r requirements.txt

# Install Flash Attention (pre-built wheel for CUDA 12 + PyTorch 2.6)
pip install flash-attn==2.7.4.post1 --no-build-isolation

# Alternatively, download the pre-built wheel directly
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
```

We support four distilled AR video model baselines. Download the base Wan2.1 model and the desired distilled checkpoint(s):
```shell
huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
```

**Self-Forcing**

```shell
huggingface-cli download gdhe17/Self-Forcing checkpoints/self_forcing_dmd.pt --local-dir .
```

**Causal Forcing**

```shell
huggingface-cli download zhuhz22/Causal-Forcing chunkwise/causal_forcing.pt --local-dir checkpoints/casualforcing
huggingface-cli download zhuhz22/Causal-Forcing framewise/causal_forcing.pt --local-dir checkpoints/casualforcing
```

**LongLive**

```shell
huggingface-cli download Efficient-Large-Model/LongLive-1.3B --include "models/*" --local-dir checkpoints/longlive_models
```

**Krea 14B**

```shell
huggingface-cli download krea/krea-realtime-video \
  krea-realtime-video-14b.safetensors \
  --local-dir checkpoints
```

Expected checkpoint layout:

```
checkpoints/
├── casualforcing/
│   ├── chunkwise/
│   │   └── causal_forcing.pt
│   └── framewise/
│       └── causal_forcing.pt
├── krea-realtime-video-14b.safetensors
├── longlive_models/
│   └── models/
│       ├── longlive_base.pt
│       └── lora.pt
└── self_forcing_dmd.pt
```
Download reward model checkpoints:
```shell
mkdir -p reward_ckpts && cd reward_ckpts

# CLIP backbone (required by HPSv2/v3)
wget https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin

# HPSv2.1 checkpoint
wget https://huggingface.co/xswu/HPSv2/resolve/main/HPS_v2.1_compressed.pt

# HPSv3 checkpoint
wget https://huggingface.co/MizzenAI/HPSv3/resolve/main/HPSv3.safetensors

# VideoReward checkpoint
huggingface-cli download KlingTeam/VideoReward --local-dir ./Videoreward
```

Download the filtered VidProM prompt subset used for training:

```shell
huggingface-cli download Franklinzhang/stream_align \
  --include "vidprom/*" \
  --local-dir ./dataset
```

Set up Weights & Biases logging:

```shell
export WANDB_API_KEY=<your_key>
export WANDB_ENTITY=<your_entity>
```

**GPU presets:** Training parameters like `num_image_per_prompt`, `num_groups`, and `test_batch_size` are auto-configured per GPU scale in `GPU_CONFIGS` inside `configs/_base_clean.py`. Add or edit entries there for custom GPU counts.
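For orientation, a preset entry presumably groups those parameters per (model family, GPU count). The schema below is a guess — field names follow the parameters referenced above, but the keying scheme and numbers are placeholders, so check `configs/_base_clean.py` for the real definition:

```python
# Hypothetical GPU_CONFIGS entry; the keying scheme and values are placeholders
# for illustration — the authoritative table lives in configs/_base_clean.py.
GPU_CONFIGS = {
    ("wan", 8): {
        "num_image_per_prompt": 8,   # rollouts sampled per prompt
        "num_groups": 2,             # reward-normalization groups
        "test_batch_size": 4,        # eval batch size per GPU
        "bsz": 1,                    # training batch size per GPU
    },
}

def get_gpu_config(model_family, n_gpus):
    try:
        return GPU_CONFIGS[(model_family, n_gpus)]
    except KeyError:
        raise ValueError(f"no preset for {model_family} at {n_gpus} GPUs; add one to GPU_CONFIGS")
```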
**Multi-node setup:** If your cluster cannot resolve `MASTER_ADDR` automatically, use a shared filesystem for node discovery. Run on each node, setting `RANK=0` on the master and `RANK=1,2,...` on workers:

```shell
export RANK=<node_rank>        # 0 for master; 1, 2, ... for workers
export WORLD_SIZE=<num_nodes>  # 2 → 16 GPUs (2×8), 3 → 24 GPUs (3×8), 6 → 48 GPUs (6×8)
export MASTER_PORT=29500
CONFIG_NAME=<config_name>

torchrun --nproc_per_node=8 --nnodes=$WORLD_SIZE \
  --node_rank=$RANK \
  --master_addr=$MASTER_ADDR \
  --master_port=$MASTER_PORT \
  scripts/train_nft_wan.py \
  --config configs/nft_<model>.py:${CONFIG_NAME}
```

**Multi-reward:** Replace `hpsv3` with `multi_reward` in any config name to enable the full multi-reward objective (HPSv3 + Motion Quality).
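One common way to implement the shared-filesystem node discovery mentioned in the multi-node note: rank 0 publishes its hostname to a file every node can see, and workers poll until it appears, then export it as `MASTER_ADDR`. This helper is an assumption (file path, polling interval), not part of the repo:

```python
# Hypothetical rendezvous helper for clusters without DNS-based MASTER_ADDR
# resolution: the master (RANK=0) writes its hostname to a shared file;
# workers poll until the file is non-empty, then use its contents.
import os
import socket
import time

def resolve_master_addr(rank, shared_file, timeout_s=300, poll_s=1.0):
    if rank == 0:
        with open(shared_file, "w") as f:
            f.write(socket.gethostname())
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        if os.path.exists(shared_file) and os.path.getsize(shared_file) > 0:
            with open(shared_file) as f:
                return f.read().strip()
        time.sleep(poll_s)
    raise TimeoutError(f"master address never appeared at {shared_file}")
```

Each node would call this before launching `torchrun`, passing its own rank and a path on the shared filesystem.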
💡 **LoRA Initialization (recommended for LongLive):** LongLive ships a pretrained LoRA adapter (`checkpoints/longlive_models/models/lora.pt`) that can be used to warm-start training. Simply append `_with_lora_init` to any LongLive config name to enable it — the adapter is loaded before RL training begins and typically leads to faster convergence.
**LongLive**

Single Node (8× GPU):

```shell
# HPSv3 reward
torchrun --nproc_per_node=8 scripts/train_nft_wan.py \
  --config configs/nft_longlive.py:longlive_video_hpsv3

# HPSv3 reward — with LoRA init
torchrun --nproc_per_node=8 scripts/train_nft_wan.py \
  --config configs/nft_longlive.py:longlive_video_hpsv3_with_lora_init

# Multi-reward (HPSv3 + Motion Quality)
torchrun --nproc_per_node=8 scripts/train_nft_wan.py \
  --config configs/nft_longlive.py:longlive_video_multi_reward
```

Multi-Node (16× / 24× / 48× GPU):
| Scale | HPSv3 Config | HPSv3 + LoRA Init | Multi-Reward Config |
|---|---|---|---|
| 16× GPU | `longlive_video_hpsv3_16gpu` | `longlive_video_hpsv3_with_lora_init_16gpu` | `longlive_video_multi_reward_16gpu` |
| 24× GPU | `longlive_video_hpsv3_24gpu` | `longlive_video_hpsv3_with_lora_init_24gpu` | `longlive_video_multi_reward_24gpu` |
| 48× GPU | `longlive_video_hpsv3_48gpu` | `longlive_video_hpsv3_with_lora_init_48gpu` | `longlive_video_multi_reward_48gpu` |
**Self-Forcing**

Single Node (8× GPU):

```shell
# HPSv3 reward
torchrun --nproc_per_node=8 scripts/train_nft_wan.py \
  --config configs/nft_self_forcing.py:self_forcing_video_hpsv3

# Multi-reward (HPSv3 + Motion Quality)
torchrun --nproc_per_node=8 scripts/train_nft_wan.py \
  --config configs/nft_self_forcing.py:self_forcing_video_multi_reward
```

Multi-Node (16× / 24× / 48× GPU):
| Scale | HPSv3 Config | Multi-Reward Config |
|---|---|---|
| 16× GPU | `self_forcing_video_hpsv3_16gpu` | `self_forcing_video_multi_reward_16gpu` |
| 24× GPU | `self_forcing_video_hpsv3_24gpu` | `self_forcing_video_multi_reward_24gpu` |
| 48× GPU | `self_forcing_video_hpsv3_48gpu` | `self_forcing_video_multi_reward_48gpu` |
**Causal Forcing**

Single Node (8× GPU):

```shell
# HPSv3 reward
torchrun --nproc_per_node=8 scripts/train_nft_wan.py \
  --config configs/nft_casual_forcing.py:casual_forcing_video_hpsv3

# Multi-reward (HPSv3 + Motion Quality)
torchrun --nproc_per_node=8 scripts/train_nft_wan.py \
  --config configs/nft_casual_forcing.py:casual_forcing_video_multi_reward
```

Multi-Node (16× / 24× / 48× GPU):
| Scale | HPSv3 Config | Multi-Reward Config |
|---|---|---|
| 16× GPU | `casual_forcing_video_hpsv3_16gpu` | `casual_forcing_video_multi_reward_16gpu` |
| 24× GPU | `casual_forcing_video_hpsv3_24gpu` | — |
| 48× GPU | `casual_forcing_video_hpsv3_48gpu` | `casual_forcing_video_multi_reward_48gpu` |
**Krea 14B**

Single Node (8× GPU):

```shell
# HPSv3 reward
torchrun --nproc_per_node=8 scripts/train_nft_wan.py \
  --config configs/nft_krea14b.py:krea14b_video_hpsv3

# Multi-reward (HPSv3 + Motion Quality)
torchrun --nproc_per_node=8 scripts/train_nft_wan.py \
  --config configs/nft_krea14b.py:krea14b_video_multi_reward
```

Multi-Node (16× / 24× / 48× GPU):
| Scale | HPSv3 Config | Multi-Reward Config |
|---|---|---|
| 16× GPU | `krea14b_video_hpsv3_16gpu` | `krea14b_video_multi_reward_16gpu` |
| 24× GPU | `krea14b_video_hpsv3_24gpu` | `krea14b_video_multi_reward_24gpu` |
| 48× GPU | `krea14b_video_hpsv3_48gpu` | `krea14b_video_multi_reward_48gpu` |
Training checkpoints are saved under `logs/nft/<base_model>/<run_name>_<timestamp>/`:

```
logs/nft/wan_self/<run_name>_<timestamp>/
├── checkpoints/
│   ├── checkpoint-<train_step>/
│   │   └── lora/
│   │       ├── adapter_model.bin   # LoRA weights
│   │       └── adapter_config.json # LoRA config
│   └── ...
└── eval_videos/                    # Evaluation videos
```
Run name convention: `nft_<base_model>_<config_name>`
| Base Model | Directory | Run Name Prefix |
|---|---|---|
| LongLive | `wan_longlive` | `nft_wan_longlive` |
| Self-Forcing | `wan_self` | `nft_wan_self` |
| Causal Forcing | `wan_casual_chunk` | `nft_wan_casual_chunk` |
| Krea 14B | `wan_krea14b` | `nft_wan_krea14b` |
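The naming convention can be reconstructed with two small helpers. These are hypothetical (the repo derives the names internally); the save-dir pattern follows the worked example in this section:

```python
# Hypothetical helpers mirroring the run-name and save-dir conventions above.
def run_name(base_model_dir, config_name):
    return f"nft_{base_model_dir}_{config_name}"

def save_dir(base_model_dir, config_name, timestamp):
    return f"logs/nft/{base_model_dir}/{config_name}_{timestamp}"

run_name("wan_self", "self_forcing_video_hpsv3")
# → "nft_wan_self_self_forcing_video_hpsv3"
```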
Example:

```shell
# After training with config: self_forcing_video_hpsv3
# run_name = nft_wan_self_self_forcing_video_hpsv3
# save_dir = logs/nft/wan_self/self_forcing_video_hpsv3_<timestamp>
torchrun --nproc_per_node=1 scripts/inference_wan.py \
  --base_model checkpoints/self_forcing_dmd.pt \
  --lora_path logs/nft/wan_self/self_forcing_video_hpsv3_<timestamp>/checkpoints/checkpoint-30 \
  --prompt "A cat running in the park" \
  --output_dir outputs/test
```

💡 The inference script auto-detects the `lora/` subdirectory, so you can pass either the checkpoint directory or the `lora/` path directly.
Use `scripts/inference_wan.py` to generate videos with a trained LoRA checkpoint.

Basic inference:

```shell
torchrun --nproc_per_node=1 scripts/inference_wan.py \
  --base_model <base_model_path> \
  --lora_path <checkpoint_dir> \
  --prompt "A cat running in the park" \
  --output_dir outputs/test
```

Long video generation:

```shell
torchrun --nproc_per_node=1 scripts/inference_wan.py \
  --base_model <base_model_path> \
  --lora_path <checkpoint_dir> \
  --prompt "A cat running in the park" \
  --num_frames 240 \
  --output_dir outputs/long_video
```

Batch inference — provide a plain text file with one prompt per line:

```shell
torchrun --nproc_per_node=8 scripts/inference_wan.py \
  --base_model <base_model_path> \
  --lora_path <checkpoint_dir> \
  --prompt_file prompts/test.txt \
  --output_dir outputs/batch
```

💡 `--lora_path` accepts either a checkpoint directory (e.g., `checkpoint-300`) or the `lora/` subdirectory directly — the script auto-detects both.
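The auto-detection described above likely amounts to checking for a `lora/` subdirectory; a minimal sketch, not the script's actual code:

```python
import os

# Hypothetical sketch of the --lora_path auto-detection: if the given path
# contains a lora/ subdirectory, use that; otherwise assume the path already
# points at the LoRA weights directory.
def resolve_lora_path(path):
    lora_subdir = os.path.join(path, "lora")
    return lora_subdir if os.path.isdir(lora_subdir) else path
```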
For long videos, you can switch the text prompt mid-generation to create seamless scene transitions — no extra flags needed. Simply separate scene prompts with `|` in a single prompt string. The script automatically activates `SceneCausalInferencePipeline` when it detects the `|` separator.
Prompt syntax:

| Syntax | Meaning |
|---|---|
| `prompt1 \| prompt2` | Two scenes, each using the default duration (~10.5 s) |
| `prompt1[5s] \| prompt2[15s]` | Explicit duration per scene in seconds |
| `prompt1[10s#] \| prompt2` | `#` marks a hard scene cut at the boundary (KV-cache is flushed) |
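The syntax above can be parsed with a few lines of Python. This is an illustrative re-implementation — the real parsing lives inside `SceneCausalInferencePipeline`, and the output schema here is an assumption:

```python
import re

# Hypothetical parser for the scene-prompt syntax in the table above.
# Default duration of 10.5 s matches the documented default.
def parse_scene_prompt(prompt, default_duration=10.5):
    scenes = []
    for part in prompt.split("|"):
        part = part.strip()
        m = re.search(r"\[(\d+(?:\.\d+)?)s(#?)\]$", part)
        if m:
            text = part[:m.start()].strip()
            duration = float(m.group(1))
            hard_cut = m.group(2) == "#"
        else:
            text, duration, hard_cut = part, default_duration, False
        scenes.append({"text": text, "duration_s": duration, "hard_cut": hard_cut})
    return scenes

parse_scene_prompt("A cat playing with yarn[10s#] | A bird flying over a lake")
# → first scene: 10 s, hard cut; second scene: default 10.5 s, soft transition
```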
Examples:

```shell
# Two scenes with default duration
torchrun --nproc_per_node=1 scripts/inference_wan.py \
  --base_model <base_model_path> \
  --prompt "A cat running in the park | A dog sleeping on the sofa" \
  --num_frames 240 \
  --output_dir outputs/scene_switch

# Explicit scene durations
torchrun --nproc_per_node=1 scripts/inference_wan.py \
  --base_model <base_model_path> \
  --prompt "A sunrise over mountains[8s] | A busy city street at noon[12s] | A quiet beach at sunset[10s]" \
  --num_frames 480 \
  --output_dir outputs/three_scenes

# Hard scene cut (abrupt transition, KV-cache flushed)
torchrun --nproc_per_node=1 scripts/inference_wan.py \
  --base_model <base_model_path> \
  --prompt "A cat playing with yarn[10s#] | A bird flying over a lake" \
  --num_frames 240 \
  --output_dir outputs/hard_cut

# Batch: put one scene-switched prompt per line in the file
torchrun --nproc_per_node=8 scripts/inference_wan.py \
  --base_model <base_model_path> \
  --lora_path <checkpoint_dir> \
  --prompt_file prompts/scene_prompts.txt \
  --num_frames 240 \
  --output_dir outputs/batch_scenes
```

💡 **Soft vs. hard transitions:** Without `#`, the KV-cache is rolled forward so the model retains memory of the previous scene, producing a smooth transition. Adding `#` flushes the cache for a sharp cut, similar to a film edit.
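Conceptually, the two modes differ only in what happens to the cache at the scene boundary. A toy illustration (the list-of-frames cache here is a stand-in for the real KV tensors):

```python
# Toy illustration of scene transitions: a soft transition keeps a rolling
# window of the KV-cache so the new scene conditions on recent frames;
# a hard cut (#) clears the cache so the new scene starts from scratch.
def transition_cache(kv_cache, hard_cut, window=3):
    if hard_cut:
        return []              # flush: new scene has no memory of the old one
    return kv_cache[-window:]  # roll forward: keep only the recent window

cache = ["f1", "f2", "f3", "f4", "f5"]
transition_cache(cache, hard_cut=False)  # → ["f3", "f4", "f5"]
transition_cache(cache, hard_cut=True)   # → []
```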
1. Implement the scorer in `astrolabe/rewards.py`:

```python
def my_reward(device):
    scorer = MyScorer(device=device)  # load once at init
    def _fn(videos, prompts, metadata=None):
        # videos: Tensor [B, F, C, H, W] in [0,1] (video) or [B, C, H, W] (image)
        scores = scorer(videos, prompts)
        return scores, {}  # scores: list or np.array of shape [B]
    return _fn
```

2. Register it in the `score_functions` dict inside `multi_score()`:

```python
score_functions = {
    ...
    "my_reward": my_reward,
}
```

3. Use it in a config: `reward_fn={"my_reward": 1.0}`.
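As a concrete, runnable instance of the template in step 1, here is a toy reward that needs no learned model — it scores each video by mean brightness. Purely illustrative; real rewards wrap a scorer such as HPSv3:

```python
import numpy as np

# Toy reward following the template above. No learned scorer is loaded —
# the point is only to show the expected input/output shapes.
def brightness_reward(device):
    # device is unused here; a real scorer would move its model to it
    def _fn(videos, prompts, metadata=None):
        # videos: [B, F, C, H, W] array with values in [0, 1]
        scores = videos.reshape(videos.shape[0], -1).mean(axis=1)
        return scores, {}
    return _fn

fn = brightness_reward("cpu")
videos = np.random.rand(2, 4, 3, 8, 8)  # B=2 videos, F=4 frames each
scores, extras = fn(videos, ["prompt a", "prompt b"])
# scores has shape [2]; register it as "brightness_reward" and set
# reward_fn={"brightness_reward": 1.0} in a config to use it
```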
Create `configs/nft_mymodel.py` following `configs/nft_longlive.py`:

```python
import os, imp

_base = imp.load_source("_base_clean", os.path.join(os.path.dirname(__file__), "_base_clean.py"))

def get_config(name):
    return globals()[name]()

def _get_config(n_gpus=8, gradient_step_per_epoch=1, dataset="vidprom", reward_fn={}, name=""):
    config = _base._make_base_config(dataset)
    config.base_model = "wan_longlive"  # or wan_self, wan_casual_chunk, wan_casual_frame
    _base._apply_wan_common(config)
    config.pretrained.model = "./checkpoints/my_model.pt"
    config.model_kwargs = {}
    gpu_config = _base._get_gpu_config("wan", n_gpus)  # or "sd3" / "krea14b"
    config.sample.num_image_per_prompt = gpu_config["num_image_per_prompt"]
    config.sample.test_batch_size = gpu_config["test_batch_size"]
    _base._apply_batch_config(config, n_gpus, gpu_config["bsz"], gpu_config["num_groups"], gradient_step_per_epoch)
    _base._apply_common_fields(config, "wan_longlive", dataset, reward_fn, name)
    return config

def my_config_hpsv3():
    config = _get_config(n_gpus=8, reward_fn={"video_hpsv3_local": 1.0}, name="my_config_hpsv3")
    config.beta = 0.1
    config.train.learning_rate = 1e-5
    return config
```

Then train with:

```shell
torchrun --nproc_per_node=8 scripts/train_nft_wan.py --config configs/nft_mymodel.py:my_config_hpsv3
```

Key parameters to tune: `config.beta` (NFT beta, default 1.0), `config.train.beta` (KL weight, default 0.0001), `config.train.learning_rate` (1e-5 for Wan), `config.sample.noise_level` (default 0.7).
1. Add a wrapper class in `utils/` following `utils/wan_wrapper.py`. You need a `WanTextEncoder`-style text encoder, a `WanVAEWrapper`-style VAE, and a `WanDiffusionWrapper`-style denoiser.
2. Register the new `base_model` key in the model-loading block of the training script (e.g. `scripts/train_nft_wan.py`), alongside existing entries like `wan_longlive`, `wan_self`, etc.
3. Add a GPU preset in `configs/_base_clean.py` under `GPU_CONFIGS` if the model needs different batch sizes than the existing `"wan"` preset.
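Step 2 presumably amounts to extending a `base_model` → constructor mapping. A hedged sketch — the actual dispatch in `scripts/train_nft_wan.py` may be an if/elif chain rather than a dict, and the string "constructors" here are placeholders for the wrapper classes from step 1:

```python
# Hypothetical registry sketch for step 2. The lambdas returning strings are
# placeholders for real wrapper-class constructors from step 1.
MODEL_REGISTRY = {
    "wan_longlive": lambda: "LongLiveWrapper()",
    "wan_self": lambda: "SelfForcingWrapper()",
    "my_model": lambda: "MyModelWrapper()",  # your new entry
}

def build_base_model(base_model):
    if base_model not in MODEL_REGISTRY:
        raise ValueError(f"unknown base_model: {base_model!r}")
    return MODEL_REGISTRY[base_model]()
```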
All inference pipelines inherit from `CausalInferencePipeline` (`pipeline/causal_inference.py`), which implements block-wise autoregressive generation with KV-cache management. Existing pipelines include `SceneCausalInferencePipeline` (multi-scene prompt switching) and `InteractiveCausalInferencePipeline` (interactive prompt switching with KV re-caching).
1. Subclass `CausalInferencePipeline` in `pipeline/` and override `inference()`:

```python
# pipeline/my_pipeline.py
from pipeline.causal_inference import CausalInferencePipeline

class MyCausalInferencePipeline(CausalInferencePipeline):
    def inference(self, noise, text_prompts, **kwargs):
        # Inherited: self.generator, self.text_encoder, self.vae,
        # self.kv_cache1, self.denoising_step_list, etc.
        ...
```

2. Register in `pipeline/__init__.py` and import in your inference script.
If you find our work useful, please consider citing:
```bibtex
@article{zhang2026astrolabe,
  title   = {Astrolabe: Steering Forward-Process Reinforcement Learning for Distilled Autoregressive Video Models},
  author  = {Zhang, Songchun and Xue, Zeyue and Fu, Siming and Huang, Jie and Kong, Xianghao and Ma, Yue and Huang, Haoyang and Duan, Nan and Rao, Anyi},
  journal = {arXiv preprint arXiv:2603.17051},
  year    = {2026}
}
```

This project builds upon the following outstanding open-source works:
- Self-Forcing — Self-forcing training for autoregressive video diffusion
- LongLive — Long-form video generation via generative extrapolation
- CausVid / Causal Forcing — AR teacher distillation with causal attention
- Wan — Base video diffusion transformer
- DiffusionNFT — Negative-aware fine-tuning for forward-process RL
⭐ If you find Astrolabe useful, please give us a star!



