Shengjun Zhang, Zhang Zhang, Chensheng Dai, Yueqi Duan
- Release the source code and checkpoints of E-GRPO
- Release the source code and checkpoints of other supported methods
- Release training scripts of other supported methods
E-GRPO (Entropy-Guided GRPO) is a novel reinforcement learning approach for flow-based diffusion models. Our key insight is that high-entropy denoising steps are more critical for policy optimization, and we propose a merging-step strategy that focuses training on these important steps.
- E-GRPO Algorithm: Novel merging-step strategy focusing on high-entropy timesteps
- Multi-Granularity Rewards: Support for computing rewards at multiple sampling granularities
- Flexible Architecture: Modular design supporting multiple RL algorithms and reward models
- Distributed Training: Full support for FSDP and sequence parallelism
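As a rough illustration of the merging-step idea (not the exact algorithm): estimate a per-step entropy over the SDE denoising trajectory, keep the highest-entropy steps as individual policy-optimization steps, and merge the remaining low-entropy steps into larger segments so training concentrates where exploration actually happens. The snippet below only sketches that selection logic under assumed names and shapes (per_step_entropy, num_kept); the actual E-GRPO procedure is defined in the paper and implemented in this repository.

```python
import torch

def select_and_merge_steps(per_step_entropy: torch.Tensor, num_kept: int):
    """Illustrative sketch (not the exact E-GRPO code): keep the num_kept
    highest-entropy denoising steps for individual policy updates and group
    the remaining low-entropy steps into merged segments."""
    num_steps = per_step_entropy.numel()
    kept = torch.topk(per_step_entropy, k=num_kept).indices.sort().values

    merged_segments = []
    prev = 0
    for idx in kept.tolist():
        if prev < idx:  # low-entropy steps before this kept step
            merged_segments.append(list(range(prev, idx)))
        prev = idx + 1
    if prev < num_steps:  # trailing low-entropy steps
        merged_segments.append(list(range(prev, num_steps)))
    return kept.tolist(), merged_segments

# Example: 10 denoising steps, train on the 3 highest-entropy ones individually.
entropy = torch.tensor([2.1, 1.9, 1.2, 0.8, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
kept_steps, merged = select_and_merge_steps(entropy, num_kept=3)
print(kept_steps, merged)
```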
This repository not only implements E-GRPO but also provides a unified framework for various GRPO variants:
| Algorithm | Description | Config |
|---|---|---|
| E-GRPO (grpo_merge) | Our method - merging-step strategy | algorithm=grpo_merge |
| DanceGRPO | Basic full-step SDE sampling | algorithm=dance_grpo |
| DanceGRPO LoRA | LoRA fine-tuning variant | algorithm=dance_grpo_lora |
| MixGRPO | Mixed SDE-ODE sampling with sliding window | algorithm=mix_grpo |
| BranchGRPO | Tree-based sampling with split and pruning | algorithm=branch_grpo |

Supported reward models:

| Model | Description | Config |
|---|---|---|
| HPS v2 | Human Preference Score | reward=hps |
| CLIP Score | Text-image alignment | reward=clip |
| ImageReward | Learned human preference | reward=image_reward |
| PickScore | Pick-a-Pic preference model | reward=pick_score |
| Multi-Reward | Combination of multiple rewards (see the sketch below) | reward=multi_reward |
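The individual reward models above return scores on very different scales, so a multi-reward setup typically normalizes each score within a group of generations before taking a weighted sum. The snippet below is only a hedged sketch of that pattern; the reward names, weights, and normalization are illustrative placeholders, not the repo's actual multi-reward configuration.

```python
import torch

def combine_rewards(rewards: dict[str, torch.Tensor], weights: dict[str, float]) -> torch.Tensor:
    """Illustrative sketch: z-normalize each reward across a group of
    generations, then take a weighted sum. Names and weights are placeholders."""
    total = torch.zeros_like(next(iter(rewards.values())))
    for name, score in rewards.items():
        normalized = (score - score.mean()) / (score.std() + 1e-8)
        total = total + weights.get(name, 1.0) * normalized
    return total

# Hypothetical scores for 4 generations of one prompt.
scores = {
    "hps": torch.tensor([0.27, 0.31, 0.25, 0.29]),
    "clip": torch.tensor([0.62, 0.58, 0.65, 0.60]),
}
combined = combine_rewards(scores, weights={"hps": 0.7, "clip": 0.3})
```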
```bash
# Clone the repository
git clone https://github.com/your-repo/VisualRL.git
cd VisualRL

# Create conda environment
conda create -n e-grpo python=3.10 -y
conda activate e-grpo

# Install dependencies
pip install -e .
```

The environment dependencies are compatible with DanceGRPO.
```bash
# Download FLUX model from Hugging Face
mkdir -p ckpt/flux
huggingface-cli download black-forest-labs/FLUX.1-schnell --local-dir ckpt/flux

# Clone HPSv2 repository
git clone https://github.com/tgxs002/HPSv2.git

# Download HPS checkpoint
mkdir -p ckpt/hps
wget -O ckpt/hps/HPS_v2_compressed.pt https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt

# Download CLIP model
mkdir -p ckpt/CLIP-ViT-H-14-laion2B-s32B-b79K
huggingface-cli download laion/CLIP-ViT-H-14-laion2B-s32B-b79K --local-dir ckpt/CLIP-ViT-H-14-laion2B-s32B-b79K
```

Preprocess the training prompts and build the embedding index:

```bash
bash scripts/preprocess/preprocess_prompts.sh

python scripts/preprocess/generate_json_index.py \
  --embedding_dir data/rl_embeddings \
  --output_path data/rl_embeddings/videos2caption.json
```
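For reference, the snippet below is a rough, hypothetical sketch of what an index-building step like this generally does (scan a directory of cached prompt embeddings and write a JSON listing); the file extension, field names, and JSON schema here are assumptions for illustration and may differ from the actual generate_json_index.py.

```python
import json
from pathlib import Path

def build_index(embedding_dir: str, output_path: str) -> None:
    """Hypothetical sketch: list cached prompt embeddings (assumed one .pt
    file per prompt) and write them into a JSON index. The real
    generate_json_index.py may use a different layout and schema."""
    entries = []
    for emb_file in sorted(Path(embedding_dir).glob("*.pt")):
        entries.append({
            "prompt_id": emb_file.stem,       # hypothetical field name
            "embedding_path": str(emb_file),  # hypothetical field name
        })
    with open(output_path, "w") as f:
        json.dump(entries, f, indent=2)

if __name__ == "__main__":
    build_index("data/rl_embeddings", "data/rl_embeddings/videos2caption.json")
```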
```bash
# Using HPS reward
bash scripts/finetune/train_grpo_hps.sh

# Or with custom settings
torchrun --nproc_per_node=8 fastvideo/train.py \
algorithm=grpo_merge \
reward=hps \
model.pretrained_model_name_or_path=./ckpt/flux \
data.json_path=./data/rl_embeddings/videos2caption.json \
training.max_train_steps=300
```

To train the other supported algorithms:

```bash
# DanceGRPO (basic GRPO)
bash scripts/finetune/train_dance_grpo.sh
# MixGRPO (mixed SDE-ODE)
bash scripts/finetune/train_mix_grpo.sh
# BranchGRPO (tree-based sampling)
bash scripts/finetune/train_branch_grpo.sh
# Multi-reward training
bash scripts/finetune/train_multi_reward.sh
```

All configurations are managed by Hydra, so you can override any config value from the command line:

```bash
torchrun --nproc_per_node=8 fastvideo/train.py \
algorithm=grpo_merge \
reward=hps \
training.learning_rate=1e-6 \
training.max_train_steps=500 \
grpo.num_generations=16 \
sampling.height=512 \
sampling.width=512
```
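These override groups (algorithm=..., reward=..., and dotted keys such as training.learning_rate) are resolved by Hydra before training starts. The snippet below is a minimal sketch of how a Hydra entry point typically consumes them, assuming a conf/ directory with a top-level config.yaml; the actual config layout behind fastvideo/train.py may differ.

```python
import hydra
from omegaconf import DictConfig, OmegaConf

# Minimal sketch of a Hydra entry point; config_path/config_name are
# illustrative assumptions, not necessarily this repo's actual layout.
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    # Command-line overrides such as algorithm=grpo_merge reward=hps
    # training.learning_rate=1e-6 are already merged into cfg here.
    print(OmegaConf.to_yaml(cfg))
    # train(cfg)  # hand the resolved config to the training loop

if __name__ == "__main__":
    main()
```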
If you find our work helpful for your research, please consider giving a star ⭐ and a citation 📝:

```bibtex
@article{zhang2025egrpo,
title={E-GRPO: High Entropy Steps Drive Effective Reinforcement Learning for Flow Models},
author={Zhang, Shengjun and Zhang, Zhang and Dai, Chensheng and Duan, Yueqi},
journal={arXiv preprint},
year={2025}
}
```

This codebase is built upon the following excellent repositories:
- DanceGRPO - Basic GRPO implementation
- Flow-GRPO - Flow-based GRPO
- MixGRPO - Mixed sampling strategy
- Granular-GRPO - Multi-granularity rewards
- BranchGRPO - Tree-based sampling
- FastVideo - Distributed training framework
- DDPO - Diffusion policy optimization
This project is licensed under the Apache License 2.0.