BranchGRPO: Stable and Efficient GRPO with Structured Branching in Diffusion Models

Yuming Li1*, Yikai Wang2*, Yuying Zhu3, Zhongyu Zhao1, Ming Lu1, Qi She3, Shanghang Zhang1

1Peking University    2Beijing Normal University    3ByteDance


Abstract

Teaser: Comparison of BranchGRPO and DanceGRPO

Group Relative Policy Optimization (GRPO) has recently improved human-preference alignment of image and video generative models, yet existing approaches still suffer from high computational cost, due to sequential rollouts and large numbers of SDE sampling steps, and from training instability caused by sparse rewards. In this paper, we present BranchGRPO, a method that restructures the rollout process into a branching tree, where shared prefixes amortize computation and pruning removes low-value paths and redundant depths. BranchGRPO introduces three contributions: (1) a branch sampling scheme that reduces rollout cost by reusing common segments; (2) a tree-based advantage estimator that converts sparse terminal rewards into dense, step-level signals; and (3) pruning strategies that accelerate convergence while preserving exploration. On HPDv2.1 image alignment, BranchGRPO improves alignment scores by up to 16% over strong baselines while reducing per-iteration training time by nearly 55%. On WanX-1.3B video generation, it further achieves higher Video-Align scores, with sharper and more temporally consistent frames than DanceGRPO.

Method

We restructure sequential GRPO rollouts into a branching tree. At selected denoising steps, trajectories split into multiple children that share early prefixes, reducing redundant sampling. Leaf rewards are fused upward using path-probability weighting and normalized per depth to obtain dense, step-wise advantages. Lightweight width and depth pruning limit backpropagation to selected nodes only, preserving forward diversity while cutting compute.
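The sketch below illustrates this branching rollout in simplified form: trajectories advance one SDE step at a time and split into several children at selected depths, so siblings reuse the prefix computed for their parent. The denoiser, the step counts, and names such as denoise_step, branch_steps, and Node are illustrative assumptions, not BranchGRPO's actual interfaces.

```python
# Minimal sketch of a branching rollout (illustrative only; not the actual
# BranchGRPO implementation or interfaces).
import random
from dataclasses import dataclass, field

@dataclass
class Node:
    state: float                     # stand-in for a latent at some denoising depth
    depth: int
    logprob: float = 0.0             # cumulative log-probability of the path so far
    children: list = field(default_factory=list)
    reward: float = 0.0              # fused reward, filled in bottom-up later

def denoise_step(state: float) -> tuple[float, float]:
    """Dummy stochastic SDE step: returns (next_state, step_logprob)."""
    noise = random.gauss(0.0, 1.0)
    return 0.9 * state + 0.1 * noise, -0.5 * noise ** 2

def branch_rollout(init_state: float, num_steps: int,
                   branch_steps: set, width: int) -> Node:
    """Grow a tree of trajectories; children share their parent's prefix."""
    root = Node(state=init_state, depth=0)
    frontier = [root]
    for t in range(1, num_steps + 1):
        k = width if t in branch_steps else 1    # split only at selected depths
        next_frontier = []
        for node in frontier:
            for _ in range(k):
                s, lp = denoise_step(node.state)
                child = Node(state=s, depth=t, logprob=node.logprob + lp)
                node.children.append(child)
                next_frontier.append(child)
        frontier = next_frontier
    return root
```

Because children only pay for the steps below their split point, the total number of denoising steps per leaf is substantially smaller than running each leaf as an independent sequential rollout.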


Distribution comparison: branch vs sequential rollouts largely overlap in embedding space (MMD²=0.019).
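For reference, a squared MMD between two sets of rollout embeddings can be estimated with a standard RBF-kernel estimator; the sketch below shows one such computation on synthetic data. The kernel choice, bandwidth, and embedding source are assumptions and need not match how the reported MMD² = 0.019 was measured.

```python
# Standard biased (V-statistic) MMD^2 estimate with an RBF kernel; the kernel,
# bandwidth, and synthetic data here are assumptions for illustration.
import numpy as np

def mmd2_rbf(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Squared MMD between samples x (n, d) and y (m, d)."""
    def k(a, b):
        d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
        return np.exp(-d2 / (2.0 * sigma ** 2))
    return float(k(x, x).mean() + k(y, y).mean() - 2.0 * k(x, y).mean())

# Two nearly overlapping point clouds give a small MMD^2, as in the figure.
rng = np.random.default_rng(0)
seq_emb = rng.normal(size=(256, 8))             # embeddings of sequential rollouts
branch_emb = rng.normal(size=(256, 8)) + 0.05   # embeddings of branch rollouts
print(mmd2_rbf(seq_emb, branch_emb))
```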


Pipeline: branching rollouts (left), path-weighted reward fusion (middle), and depth-wise normalization with pruning (right).
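A minimal sketch of the middle and right stages follows, operating on the tree built in the earlier rollout sketch: leaf rewards are fused bottom-up with weights proportional to each child's path probability, and fused rewards are then normalized within each depth to yield step-level advantages. The exact weighting and normalization used by BranchGRPO may differ; this only illustrates the structure.

```python
# Hedged sketch of path-weighted reward fusion and depth-wise normalization,
# operating on the Node tree from the previous sketch (exact formulas in
# BranchGRPO may differ).
import math
from collections import defaultdict

def fuse_rewards(node, leaf_reward) -> float:
    """Bottom-up: leaves keep their terminal reward; internal nodes take a
    path-probability-weighted average of their children's fused rewards."""
    if not node.children:
        node.reward = leaf_reward(node)
        return node.reward
    child_rewards = [fuse_rewards(c, leaf_reward) for c in node.children]
    weights = [math.exp(c.logprob) for c in node.children]
    z = sum(weights) or 1.0
    node.reward = sum(w * r for w, r in zip(weights, child_rewards)) / z
    return node.reward

def depthwise_advantages(root) -> dict:
    """Normalize fused rewards within each depth to get step-level advantages."""
    by_depth = defaultdict(list)
    stack = [root]
    while stack:
        n = stack.pop()
        by_depth[n.depth].append(n)
        stack.extend(n.children)
    adv = {}
    for nodes in by_depth.values():
        rewards = [n.reward for n in nodes]
        mu = sum(rewards) / len(rewards)
        std = math.sqrt(sum((r - mu) ** 2 for r in rewards) / len(rewards)) + 1e-8
        for n in nodes:
            adv[id(n)] = (n.reward - mu) / std
    return adv
```

Pruning then amounts to restricting which nodes' advantages enter the policy-gradient update, for example dropping low-value siblings (width) or late depths (depth), while the forward rollouts keep their diversity.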

Experiments

Efficiency–Quality Comparison

| Method | NFE (πθ_old) | NFE (πθ) | Iteration Time (s) ↓ | HPS-v2.1 ↑ | Pick Score ↑ | ImageReward ↑ |
|---|---|---|---|---|---|---|
| FLUX | – | – | – | 0.313 | 0.227 | 1.112 |
| DanceGRPO (tf=1.0) | 20 | 20 | 698 | 0.360 | 0.229 | 1.189 |
| DanceGRPO (tf=0.6) | 20 | 12 | 469 | 0.353 | 0.228 | 1.219 |
| MixGRPO (20,5) | 20 | 5 | 289 | 0.359 | 0.228 | 1.211 |
| BranchGRPO | 13.68 | 13.68 | 493 | 0.363 | 0.229 | 1.233 |
| BranchGRPO-WidPru | 13.68 | 8.625 | 314 | 0.364 | 0.230 | 1.300 |
| BranchGRPO-DepPru | 13.68 | 8.625 | 314 | 0.369 | 0.231 | 1.319 |

Note: NFE counts the effective denoising steps at which the policy is updated; for branching methods, shared prefixes are amortized across children. Speedups are measured relative to DanceGRPO (tf=1.0).
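For illustration, one way shared prefixes can be amortized is to count every denoising step executed anywhere in the tree and divide by the number of leaf trajectories; the sketch below uses a hypothetical branching schedule. Both this accounting and the effective_nfe helper are assumptions and are not claimed to reproduce the 13.68 figure in the table.

```python
# One possible accounting of per-leaf NFE with shared prefixes: count every
# denoising step executed anywhere in the tree, then divide by the number of
# leaf trajectories. The schedule below is hypothetical.
def effective_nfe(num_steps: int, branch_steps: set, width: int) -> float:
    leaves, total_steps = 1, 0
    for t in range(1, num_steps + 1):
        if t in branch_steps:
            leaves *= width              # each live branch splits at this depth
        total_steps += leaves            # every live branch runs one step
    return total_steps / leaves

# Hypothetical schedule: 20 steps, binary splits at four early depths.
print(effective_nfe(20, branch_steps={2, 5, 8, 11}, width=2))   # ~12.7
```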

FLUX Text2Image Results


Qualitative comparison of generations from Flux, DanceGRPO, and BranchGRPO.

Contact

Authors: Yuming Li (Peking University), Yikai Wang (Beijing Normal University), Yuying Zhu (ByteDance), Zhongyu Zhao (Peking University), Ming Lu (Peking University), Qi She (ByteDance), Shanghang Zhang (Peking University)

Email: fredreic1880@gmail.com