Recent progress in aligning image and video generative models with Group Relative Policy Optimization (GRPO) has improved human preference alignment, yet existing approaches still suffer from high computational cost, due to sequential rollouts and large numbers of SDE sampling steps, and from training instability caused by sparse rewards. In this paper, we present BranchGRPO, a method that restructures the rollout process into a branching tree, in which shared prefixes amortize computation and pruning removes low-value paths and redundant depths. BranchGRPO makes three contributions: (1) a branch sampling scheme that reduces rollout cost by reusing common segments; (2) a tree-based advantage estimator that converts sparse terminal rewards into dense, step-level signals; and (3) pruning strategies that accelerate convergence while preserving exploration. On HPDv2.1 image alignment, BranchGRPO improves alignment scores by up to 16% over strong baselines while reducing per-iteration training time by nearly 55%. On WanX-1.3B video generation, it further achieves higher Video-Align scores with sharper and more temporally consistent frames than DanceGRPO.
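As a rough, hypothetical sketch of the branch sampling idea in contribution (1), the snippet below grows a rollout tree in which trajectories share their early denoising prefix and split into several children at selected steps, so the shared segment is computed once rather than once per final sample. The names `denoise_step`, `split_steps`, and `branch_factor` are illustrative assumptions, not the released implementation.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    latent: object                       # latent state after this denoising step
    step: int                            # denoising step index (T down to 0)
    logprob: float                       # log-prob of the transition into this node
    children: list = field(default_factory=list)

def rollout_tree(x_T, num_steps, split_steps, branch_factor, denoise_step):
    """Grow a branching rollout tree. `denoise_step(latent, step)` is assumed to
    perform one SDE sampling step under the current policy and return
    (next_latent, transition_logprob)."""
    root = Node(latent=x_T, step=num_steps, logprob=0.0)
    frontier = [root]
    for step in range(num_steps - 1, -1, -1):
        new_frontier = []
        for node in frontier:
            # At a split step, duplicate this shared prefix into several children;
            # otherwise continue with a single child.
            k = branch_factor if step in split_steps else 1
            for _ in range(k):
                next_latent, logprob = denoise_step(node.latent, step)
                child = Node(latent=next_latent, step=step, logprob=logprob)
                node.children.append(child)
                new_frontier.append(child)
        frontier = new_frontier
    return root, frontier  # `frontier` holds the leaves, i.e. the final samples
```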
We restructure sequential GRPO rollouts into a branching tree. At selected denoising steps, trajectories split into multiple children that share early prefixes, reducing redundant sampling. Leaf rewards are fused upward using path-probability weighting and normalized per depth to obtain dense, step-wise advantages. Lightweight width and depth pruning limit backpropagation to selected nodes only, preserving forward diversity while cutting compute.
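The reward fusion and depth-wise normalization described above can be sketched as follows, reusing the `Node` structure from the snippet above. The exact weighting is an assumption based on the description: each internal node averages its children's values with weights derived from the children's transition log-probabilities, and fused values are normalized over all nodes at the same depth to give step-level advantages.

```python
import math

def fuse_rewards(node):
    """Fuse leaf rewards upward through the tree. Leaves are assumed to already
    carry a `value` set by the reward model; each internal node receives a
    path-probability-weighted average of its children's values."""
    if not node.children:
        return node.value
    child_values = [fuse_rewards(c) for c in node.children]
    weights = [math.exp(c.logprob) for c in node.children]
    z = sum(weights)
    node.value = sum(w * v for w, v in zip(weights, child_values)) / z
    return node.value

def depthwise_advantages(root):
    """Group nodes by depth (denoising step) and normalize the fused values
    within each depth, storing a dense, step-level advantage on every node.
    Intended to be called after fuse_rewards."""
    by_depth = {}
    stack = [root]
    while stack:
        node = stack.pop()
        by_depth.setdefault(node.step, []).append(node)
        stack.extend(node.children)
    for nodes in by_depth.values():
        values = [n.value for n in nodes]
        mean = sum(values) / len(values)
        std = (sum((v - mean) ** 2 for v in values) / len(values)) ** 0.5 or 1.0
        for n in nodes:
            n.advantage = (n.value - mean) / std
```

Width or depth pruning would then restrict the policy-gradient update to a subset of these nodes, while the full forward tree is kept so that sample diversity is preserved.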
Distribution comparison: branch vs sequential rollouts largely overlap in embedding space (MMD²=0.019).
Pipeline: branching rollouts (left), path-weighted reward fusion (middle), and depth-wise normalization with pruning (right).
Method | NFE (π_θ_old) | NFE (π_θ) | Iteration Time (s) ↓ | HPS-v2.1 ↑ | Pick Score ↑ | ImageReward ↑ |
---|---|---|---|---|---|---|
FLUX | - | - | - | 0.313 | 0.227 | 1.112 |
DanceGRPO (tf=1.0) | 20 | 20 | 698 | 0.360 | 0.229 | 1.189 |
DanceGRPO (tf=0.6) | 20 | 12 | 469 | 0.353 | 0.228 | 1.219 |
MixGRPO (20,5) | 20 | 5 | 289 | 0.359 | 0.228 | 1.211 |
BranchGRPO | 13.68 | 13.68 | 493 | 0.363 | 0.229 | 1.233 |
BranchGRPO-WidPru | 13.68 | 8.625 | 314 | 0.364 | 0.230 | 1.300 |
BranchGRPO-DepPru | 13.68 | 8.625 | 314 | 0.369 | 0.231 | 1.319 |
Note: NFE counts the effective denoising steps on which the policy is updated; for branching methods, shared prefixes are counted once across branches. Speedups quoted in the text are relative to DanceGRPO (tf=1.0).
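For intuition on how shared prefixes reduce the effective NFE per sample, the toy count below assumes a 16-step sampler with binary splits at three depths; this tree shape is purely illustrative and not the configuration behind the numbers in the table.

```python
def effective_nfe_per_sample(num_steps, split_steps, branch_factor):
    """Count the unique denoising transitions in a rollout tree and divide by
    the number of leaves: shared prefixes are computed once, so the per-sample
    cost drops below that of independent sequential rollouts."""
    width = 1
    total_transitions = 0
    for step in range(num_steps - 1, -1, -1):
        if step in split_steps:
            width *= branch_factor
        total_transitions += width  # every frontier node takes one step
    return total_transitions / width  # `width` now equals the number of leaves

# 16 steps with binary splits at steps 12, 8, and 4 yield 8 leaves from
# 67 unique transitions, i.e. 8.375 effective NFE per sample versus 16
# for fully sequential rollouts.
print(effective_nfe_per_sample(16, {12, 8, 4}, 2))  # 8.375
```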
Qualitative comparison of generations from FLUX, DanceGRPO, and BranchGRPO.
Authors: Yuming Li (Peking University), Yikai Wang (Beijing Normal University), Yuying Zhu (ByteDance), Zhongyu Zhao (Peking University), Ming Lu (Peking University), Qi She (ByteDance), Shanghang Zhang (Peking University)
Emails: fredreic1880@gmail.com