verl: Flexible and Efficient RL for LLMs

Yuxuan Tong (童雨轩)

ByteDance Seed & Tsinghua University

2025/05/24

1 Background

1.1 Modelling RL as Dataflow Graph

Reinforcement Learning (RL) for Large Language Models (LLMs) can typically be modeled as a dataflow graph, consisting of:

  1. multiple models: actor, critic, reference, reward model, etc.
  2. multiple stages: generating, preparing experiences, training
  3. multiple workloads: generation, inference, training
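As an illustration (structure and names are hypothetical, not verl's API), the PPO dataflow can be written down as a graph of (model, workload) nodes grouped by stage:

```python
# Illustrative only: the PPO dataflow graph as (model, workload) nodes
# grouped into the three stages listed above.
ppo_dataflow = {
    "generate": [("actor", "generation")],
    "prepare_experiences": [
        ("reward", "inference"),
        ("reference", "inference"),
        ("critic", "inference"),
    ],
    "train": [("critic", "training"), ("actor", "training")],
}

# Stages execute in order; within a stage, each node consumes the batch.
stage_order = ["generate", "prepare_experiences", "train"]
```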

1.2 Implementing Dataflow Graph as Execution Pattern

In practice, this dataflow graph must be mapped to an execution pattern on a GPU cluster, subject to the cluster's resource and communication constraints.

2 Paradigm: HybridFlow

2.1 Background: Single-Controller vs. Multi-Controller

  • Single-Controller (MPMD): a centralized controller manages all the workers.
  • Multi-Controller (SPMD): each worker runs its own program on its own data and communicates with the others.

Figure 1: Single-Controller (Multi-Program-Multi-Data, MPMD) vs. Multi-Controller (Single-Program-Multi-Data, SPMD)

  • Single-Controller (MPMD) is flexible but suffers from communication overhead.
  • Multi-Controller (SPMD) is efficient but complex to program.

Paradigm          | Pro       | Con
Single-Controller | Flexible  | Communication Overhead
Multi-Controller  | Efficient | Complex Programming
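A toy sketch of the two orchestration styles (functions hypothetical): a single controller explicitly dispatches work to each worker, while under SPMD every rank runs the same program on its own shard:

```python
def single_controller(workers, batch):
    # MPMD: one central driver shards the data and issues a call per worker.
    shards = [batch[r::len(workers)] for r in range(len(workers))]
    return [worker(shard) for worker, shard in zip(workers, shards)]

def spmd_worker(rank, world_size, batch):
    # SPMD: every rank runs this same program, picking out its own shard.
    return sum(batch[rank::world_size])

# Both produce the same per-rank results; they differ in who orchestrates.
print(single_controller([sum, sum], [1, 2, 3, 4]))          # [4, 6]
print([spmd_worker(r, 2, [1, 2, 3, 4]) for r in range(2)])  # [4, 6]
```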

2.2 Paradigm: Hybrid-Controller

verl introduces hybrid-controller, where a single-controller manages multiple multi-controllers.

2.3 Flexibility: Single-Controller

Single-controller enables verl to implement various RL algorithms by only modifying a few lines, usually only in the fit function.

Listing 1: PPO example code.
for prompts in dataloader:
    # Stage 1: Sampling Trajectories
    batch = actor.generate_sequences(prompts)
    # Stage 2: Preparing Experiences
    batch = reward.compute_reward(batch)
    batch = reference.compute_log_prob(batch)
    batch = critic.compute_values(batch)
    batch = compute_advantage(batch, "gae")
    # Stage 3: Training
    critic.update_critic(batch)
    actor.update_actor(batch)
Listing 2: GRPO example code.
for prompts in dataloader:
    # Stage 1: Sampling Trajectories
    batch = actor.generate_sequences(prompts)
    # Stage 2: Preparing Experiences
    batch = reward.compute_reward(batch)
    batch = reference.compute_log_prob(batch)
    batch = compute_advantage(batch, "grpo")
    # Stage 3: Training (GRPO is critic-free, so only the actor is updated)
    actor.update_actor(batch)

With such flexibility, verl has supported diverse RL algorithms including PPO, GRPO, RLOO, ReMax, REINFORCE++, PRIME, DAPO, Dr. GRPO, etc.

2.4 Efficiency: Multi-Controller

2.4.1 Inter-Stage: Hybrid Engine

The optimal execution patterns for different workloads, e.g., training and generation, usually differ.

  • Instead of splitting the devices to deploy separate engines for different workloads, which leaves many GPU bubbles,

  • verl implements a hybrid engine that switches the same cluster between the different workloads, fully utilizing all the GPUs.

2.4.2 Intra-Stage: Diverse Parallelisms

Hybrid engine allows verl to flexibly switch between parallelism strategies to optimize the performance.

Parallelism Algorithms:

  • Data Parallelism
  • Tensor Parallelism
  • Pipeline Parallelism
  • Context / Sequence Parallelism

Training Backend:

  • FSDP
  • Megatron

Generation Backend:

  • vLLM
  • SGLang

3 Features / Optimizations in verl

3.1 Sequence Packing

  1. Remove padding tokens and pack multiple data sequences into a single row
  2. Tweak the attention mask & position IDs to avoid cross-contamination

To enable this, use use_remove_padding.
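The two steps above can be sketched as follows (a simplified illustration, not verl's implementation; real packing feeds the cumulative sequence lengths to a variable-length attention kernel):

```python
import itertools

def pack_sequences(seqs):
    """Pack already-unpadded token sequences into one row (illustrative)."""
    input_ids = list(itertools.chain.from_iterable(seqs))
    # Position IDs restart at 0 per sequence, so each packed sequence
    # is encoded as if it were alone in the row.
    position_ids = [p for s in seqs for p in range(len(s))]
    # Cumulative sequence lengths let an attention kernel (e.g. a varlen
    # flash-attention call) mask out cross-sequence attention.
    cu_seqlens = [0]
    for s in seqs:
        cu_seqlens.append(cu_seqlens[-1] + len(s))
    return input_ids, position_ids, cu_seqlens

ids, pos, cu = pack_sequences([[5, 6, 7], [8, 9]])
# ids == [5, 6, 7, 8, 9], pos == [0, 1, 2, 0, 1], cu == [0, 3, 5]
```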

3.2 DP Balancing

3.2.1 Load Imbalance in DP

  • Parallelism usually needs synchronization between different ranks.
  • Data Parallelism (DP) like ZeRO is the most commonly used parallelism strategy.
  • However, DP performance might be damaged by load imbalance, which is especially severe in long-context training.

3.2.2 Balancing across DP Ranks

  • balance the valid tokens dispatched to each rank
  • by reordering the samples in each batch

To enable this, use balance_batch.
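A common way to implement this reordering is greedy longest-first assignment; the sketch below is illustrative, not verl's actual balance_batch code:

```python
import heapq

def balance_across_ranks(seq_lens, num_ranks):
    """Assign sample indices to DP ranks to roughly equalize valid tokens."""
    heap = [(0, r) for r in range(num_ranks)]  # (total_tokens, rank)
    heapq.heapify(heap)
    assignment = [[] for _ in range(num_ranks)]
    # Longest sequences first, each placed on the currently lightest rank.
    for idx in sorted(range(len(seq_lens)), key=lambda i: -seq_lens[i]):
        total, rank = heapq.heappop(heap)
        assignment[rank].append(idx)
        heapq.heappush(heap, (total + seq_lens[idx], rank))
    return assignment

print(balance_across_ranks([10, 9, 2, 1], num_ranks=2))  # [[0, 3], [1, 2]]
```

Here both ranks end up with 11 valid tokens, instead of 19 vs. 3 under the original order.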

3.2.3 Balancing across Micro Batches

However, in gradient accumulation,

  • it’s not enough to only balance valid tokens in a batch,
  • since DP syncs in the unit of micro batch.

To resolve this, verl also supports

  • balancing the valid tokens across micro batches
  • by evenly dividing the data sequences in the batch before packing them into micro batches

To enable this, use use_dynamic_bsz.
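A minimal sketch of token-budgeted micro-batching (illustrative; verl's use_dynamic_bsz logic is more involved): sample indices are packed into micro batches so that no micro batch exceeds a token budget:

```python
def split_into_micro_batches(seq_lens, max_tokens):
    """Group sample indices into micro batches under a token budget."""
    micro_batches, current, current_tokens = [], [], 0
    for idx, n in enumerate(seq_lens):
        if current and current_tokens + n > max_tokens:
            micro_batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n
    if current:
        micro_batches.append(current)
    return micro_batches

print(split_into_micro_batches([4, 3, 5, 2], max_tokens=7))  # [[0, 1], [2, 3]]
```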

3.3 Async Engine for Multi-Turn Generation

Multi-turn generation might have different implementations:

  • Naive: wrap the batch generation in a for-loop, synchronizing for each turn, thus causing many bubbles.
  • Efficient: utilize the async engine, managing generation at the request level instead of the batch level.

Specifically, verl integrates:

  • SGLang’s Engine.async_generate (contributed by the SGLang RL team)
  • vLLM-V1’s AsyncLLM (contributed by Xibin Wu from ByteDance)
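A toy asyncio sketch of request-level multi-turn generation (the engine class and its async_generate method are stand-ins, not the SGLang or vLLM APIs):

```python
import asyncio

class EchoEngine:
    """Stand-in for an async inference engine."""
    async def async_generate(self, messages):
        await asyncio.sleep(0)  # yield control, as a real engine would
        return "<done>"

async def run_conversation(engine, prompt, max_turns=3):
    # Each conversation advances at its own pace; no per-turn batch barrier.
    messages = [prompt]
    for _ in range(max_turns):
        reply = await engine.async_generate(messages)
        messages.append(reply)
        if reply == "<done>":
            break
    return messages

async def run_batch(engine, prompts):
    # Requests are scheduled concurrently at the request level.
    return await asyncio.gather(*(run_conversation(engine, p) for p in prompts))

results = asyncio.run(run_batch(EchoEngine(), ["hi", "hello"]))
```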

3.4 Other Features

  1. Multi-Modal LLMs’ RL
  2. Full support for RL on AMD hardware (ROCm kernels)
  3. Gradient Checkpointing (enable_gradient_checkpointing)
  4. Torch Compile (use_torch_compile)
  5. Liger Kernel (use_liger)

4 Roadmap

For the most timely updates of important features, please keep an eye on verl’s README.

4.1 Efficient RL with Huge MoE like DeepSeek-V3-671B (ETA: Late May’25)

verl is working on supporting efficient RL training for huge MoE like DeepSeek-V3-671B, based on the following features:

  1. MoE models with GPTModel class for actor and critic
  2. Multi-node inference
  3. Parameter sharding manager for Megatron-Core V0.12 + latest version of inference engines

For more details, please check our tracker #708.

4.2 Agentic RL with Diverse Environments & Tools (Planned)

  1. Our ongoing RFC
  2. Integrating MDP
  3. Integrating existing implementations, e.g. the Atropos library from Nous Research

4.3 Other Plans

  1. Partial Rollout
  2. Multi-Token-Prediction (MTP)

Welcome to join the verl community to discuss and contribute!

Thanks for Listening!

Repo: https://github.com/volcengine/verl

Contact:

5 Programming Guide

5.1 Customizing the Dataset

A canonical RL dataset in verl has the following fields:

  • prompt: a list of messages {"role": "...", "content": "..."}
  • data_source: used to choose the reward function
  • reward_model: a dict containing
    • "ground_truth"
    • "style" like "model" or "rule"
  • (Optional) extra_info: a dict containing extra information
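For instance, one row of such a dataset might look like this (values hypothetical):

```python
# Hypothetical example of one canonical RL dataset row.
row = {
    "prompt": [{"role": "user", "content": "What is 2 + 2?"}],
    "data_source": "openai/gsm8k",  # selects the reward function
    "reward_model": {
        "style": "rule",            # rule-based scoring
        "ground_truth": "4",
    },
    "extra_info": {"split": "train", "index": 0},  # optional
}
```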

For VLM RL, verl expects the fields "images" and/or "videos".

For examples, please check examples/data_preprocess.

You could also customize the field names via config. Please check the data section in config files like ppo_trainer.yaml for more details.

For further customization, verl provides the data.custom_cls config:

Listing 3: Config for custom dataset class.
data:
  custom_cls:
    path: null # path to the `.py` file containing the `class` definition
    name: null # the `class` name

An example CLI config could be:

Listing 4: Example config for custom dataset class.
--data.custom_cls.path=./examples/dataset/custom_dataset.py \
--data.custom_cls.name=CustomDataset

The custom dataset class defined in the .py file is required to accept the following initialization parameters:

Listing 5: Custom dataset class initialization.
class CustomDataset: # You could also inherit from `RLHFDataset`
  def __init__(
      self,
      data_files: Union[str, List[str]],
      tokenizer: PreTrainedTokenizer,
      config: DictConfig,
      processor: Optional[ProcessorMixin] = None,
  ):
      ...

5.2 Customizing the Reward

verl allows defining a custom reward function via the custom_reward_function config:

Listing 6: Config for custom reward function.
custom_reward_function:
  path: null # path to the `.py` file containing the function definition
  name: compute_score # the function name after `def`
reward_model:
  reward_manager: naive

An example CLI config could be:

Listing 7: Example config for custom reward function.
--custom_reward_function.path=./examples/reward_fn/custom_reward_fn.py \
--custom_reward_function.name=compute_score \
--reward_model.reward_manager=naive

The function defined in the .py file should accept the parameters passed from the reward manager’s __call__ method. Taking NaiveRewardManager as an example:

Listing 8: How a reward function is called in NaiveRewardManager.
class NaiveRewardManager:
    def __call__(self, data: DataProto, return_dict: bool=False):
        # Preprocessing for the input data
        score = self.compute_score(
            data_source=data_source,
            solution_str=solution_str,
            ground_truth=ground_truth,
            extra_info=extra_info,
        )
        # Other processing for the final `reward`

For more complex features, you can also add a new reward manager like PRIMERewardManager or DAPORewardManager.
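A minimal compute_score matching that call site might look like this (a rule-based exact-match sketch, not a recommended scorer):

```python
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Return 1.0 if the solution matches the ground truth, else 0.0."""
    # Real scorers usually parse `solution_str` first (e.g. extract the
    # final answer) and may branch on `data_source`.
    return 1.0 if solution_str.strip() == str(ground_truth).strip() else 0.0

print(compute_score("gsm8k", " 4 ", "4"))  # 1.0
```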

5.3 Customizing the Loss Function

To modify the loss function, the most convenient way is to

  1. search for the .backward() call
  2. modify functions like compute_policy_loss
  3. or add loss terms like entropy_loss

For example, the DataParallelPPOActor.update_policy method defines the loss function as follows:

Listing 9: Simplified loss function definition in DataParallelPPOActor.
class DataParallelPPOActor(BasePPOActor):
    def update_policy(self, data: DataProto):
        pg_loss = compute_policy_loss(
            old_log_prob=old_log_prob, log_prob=log_prob,
            advantages=advantages, # ...
        )
        entropy_loss = agg_loss(loss_mat=entropy)
        policy_loss = pg_loss - entropy_loss * entropy_coeff
        kld = kl_penalty(
            logprob=log_prob, ref_logprob=ref_log_prob, # ...
        )
        kl_loss = agg_loss(loss_mat=kld)
        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
        policy_loss.backward()

5.4 Customizing the Training Logic

As mentioned above, the main training logic is concentrated in the fit function of the trainer classes like RayPPOTrainer.

For example, the DAPORayTrainer class overrides the fit function to implement the “dynamic sampling” feature:


Listing 10: Simplified fit function in DAPORayTrainer, with dynamic sampling highlighted.
class RayDAPOTrainer(RayPPOTrainer):
  def fit(self):
    for epoch in range(self.config.trainer.total_epochs):
      batch, num_prompt_in_batch, num_gen_batches = None, 0, 0
      for batch_dict in self.train_dataloader:
        new_batch = DataProto.from_single_dict(batch_dict)
        num_gen_batches += 1
        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)  # `gen_batch`: the prompt fields extracted from `new_batch`
        new_batch = new_batch.union(gen_batch_output)
        if not self.config.algorithm.filter_groups.enable:
          batch = new_batch
        else:
          # Getting `kept_traj_idxs` and updating `num_prompt_in_batch` ...
          new_batch = new_batch[kept_traj_idxs]
          batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
          prompt_bsz = self.config.data.train_batch_size
          if num_prompt_in_batch < prompt_bsz:
            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
                continue
          else:
            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
            batch = batch[:traj_bsz]
        # ...