Source code for oumi.core.trainers.verl_grpo_trainer

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Volcano Engine Reinforcement Learning (verl) GRPO Trainer."""

import copy
import os
from pathlib import Path
from pprint import pformat
from typing import Callable, Optional, Union, cast

from datasets import Dataset
from omegaconf import DictConfig, OmegaConf

from oumi.core.types.conversation import Conversation
from oumi.utils.grpo_utils import (
    extract_prompt_images_completion_from_single_turn_conversation,
)

try:
    import ray  # pyright: ignore[reportMissingImports]
    import verl  # pyright: ignore[reportMissingImports]
    from verl.trainer.ppo.ray_trainer import (  # pyright: ignore[reportMissingImports]
        RayPPOTrainer,
        ResourcePoolManager,
        Role,
    )
    from verl.workers.fsdp_workers import (  # pyright: ignore[reportMissingImports]
        ActorRolloutRefWorker,
        CriticWorker,
    )
    from verl.workers.reward_manager import (  # pyright: ignore[reportMissingImports]
        NaiveRewardManager,
    )
except ModuleNotFoundError:
    verl = None
    ray = None


from oumi.core.configs import DatasetSplitParams, TrainingConfig
from oumi.core.processors.base_processor import BaseProcessor
from oumi.core.tokenizers import BaseTokenizer
from oumi.core.trainers.base_trainer import BaseTrainer
from oumi.utils.logging import logger
from oumi.utils.verl_model_merger import FSDPModelMerger, ModelMergerConfig

# Dataset processing function type. The function takes the following arguments:
# 1. A dataset sample.
# 2. The index of the sample.
# 3. The data source name.
# 4. The split name ("train", "validation", etc.).
# It returns an example converted to verl format.
_DatasetProcessFn = Callable[[dict, int, str, str], dict]
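

# A minimal illustrative sketch of a `_DatasetProcessFn` (not part of the Oumi API).
# It assumes a raw sample with hypothetical "question" and "answer" fields and only
# shows the expected call signature and output shape; the real converter used for
# Conversation datasets is `_create_verl_data_entry_from_single_turn_conversation`
# below.
def _example_process_fn(
    example: dict, idx: int, data_source: str, split: str
) -> dict:
    return {
        "data_source": data_source,
        "prompt": [{"role": "user", "content": example["question"]}],
        "reward_model": {"style": "rule", "ground_truth": example["answer"]},
        "extra_info": {"split": split, "index": idx, "answer": example["answer"]},
    }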


class VerlGrpoTrainer(BaseTrainer):
    """verl GRPO Trainer.

    This class wraps verl's RayPPOTrainer. Despite its name, RayPPOTrainer supports
    other RL algorithms as well, including GRPO, which is what we use here.

    For documentation on the underlying verl RayPPOTrainer, see
    https://verl.readthedocs.io/en/latest/examples/config.html.
    """

    def __init__(
        self,
        processing_class: Optional[BaseTokenizer],
        config: TrainingConfig,
        reward_funcs: list[Callable],
        train_dataset: Dataset,
        eval_dataset: Dataset,
        processor: Optional[BaseProcessor] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        **kwargs,
    ):
        """Initializes the verl trainer.

        Args:
            processing_class: The tokenizer for the model.
            config: Training config.
            reward_funcs: List of reward functions to use.
            train_dataset: Training dataset.
            eval_dataset: Validation dataset. This is required by verl.
            processor: Optional processor for the dataset. Required for VLMs.
            cache_dir: Directory to cache verl Parquet datasets.
            **kwargs: Additional keyword arguments.
        """
        if verl is None:
            raise RuntimeError(
                "verl is not installed. "
                "Please install it with 'pip install `oumi[gpu]`'."
            )
        logger.warning(
            "VerlGrpoTrainer is experimental, and the interface is subject to change."
        )
        self._processing_class = processing_class
        self._oumi_config = copy.deepcopy(config)
        self._final_output_dir: Optional[Path] = (
            Path(self._oumi_config.training.output_dir).absolute().resolve()
            if self._oumi_config.training.output_dir
            else None
        )
        self._temp_output_dir: Optional[Path] = (
            self._final_output_dir / "verl_output" if self._final_output_dir else None
        )
        if not self._final_output_dir and config.training.save_final_model:
            raise ValueError(
                "Output directory must be specified when saving final model is enabled."
            )

        # TODO: OPE-1192 - Support multiple reward functions.
        if len(reward_funcs) > 1:
            raise ValueError("We only support up to one reward function.")
        self._reward_funcs = reward_funcs
        self._cache_dir: Path = (
            Path(cache_dir)
            if cache_dir
            else Path.home() / ".cache" / "oumi" / "verl_datasets"
        )
        self._train_dataset = train_dataset
        self._eval_dataset = eval_dataset
        # verl trainer uses private methods and properties of `transformers`
        # processor, so we need to pass the raw processor here.
        self._processor = processor.raw_processor if processor is not None else None

        # Detect what dataset post-processing function to use (if any).
        process_fn = self._detect_dataset_process_fn()
        # Generate files and set self._train_filepath and self._val_filepath.
        self._create_dataset_files(process_fn)

        self._setup_verl_trainer()

    def _detect_dataset_process_fn(
        self,
    ) -> Optional[_DatasetProcessFn]:
        """Returns a post-processing function to convert data to verl format.

        Examines dataset samples to determine what post-processing function to use.

        Returns:
            A post-processing function to convert data to verl format.
            If no post-processing is needed, returns `None`.
        """
        first_train_sample = next(iter(self._train_dataset))
        first_eval_sample = next(iter(self._eval_dataset))
        if not isinstance(first_train_sample, dict):
            raise ValueError(
                "Element type of training dataset must be a dictionary. "
                f"Got {type(first_train_sample)} instead."
            )
        if not isinstance(first_eval_sample, dict):
            raise ValueError(
                "Element type of validation dataset must be a dictionary. "
                f"Got {type(first_eval_sample)} instead."
            )

        # Detect datasets containing Conversations.
if "conversation_json" in first_train_sample: if "conversation_json" not in first_eval_sample: raise ValueError( "Training and validation datasets must both have the same key: " "'conversation_json'." ) try: # Check if the conversation_json is valid. _ = Conversation.from_json(first_train_sample["conversation_json"]) _ = Conversation.from_json(first_eval_sample["conversation_json"]) except Exception as e: raise ValueError( "Invalid conversation_json in training or validation dataset." ) from e return VerlGrpoTrainer._create_verl_data_entry_from_single_turn_conversation return None @staticmethod def _get_data_source_name(params: DatasetSplitParams) -> str: """Returns the verl data source name.""" dataset_names = list({ds.dataset_name for ds in params.datasets}) if len(dataset_names) != 1: if len(dataset_names) > 1: raise ValueError( f"Multiple dataset names found: {dataset_names}. " f"Please specify a single dataset name." ) else: raise ValueError( "No dataset names found. Please check the dataset split parameters." ) return dataset_names[0] @staticmethod def _extract_question_images_answer_from_single_turn_conversation( example: dict, ) -> tuple[str, list, str]: """Finds question, answer, and optional images in a single-turn conversation. Args: example: A dictionary containing the conversation JSON. Returns: A tuple containing the question, images, and answer. The list of images is empty for text-only conversations. """ prompt, images, answer = ( extract_prompt_images_completion_from_single_turn_conversation(example) ) if len(images) > 0: # TODO: Generalize. This only works for QwenVL 2.5, which is the only # VLM supported by verl as of 2025-05-15. if not prompt.startswith("<image>"): prompt = "<image>" + prompt return (prompt, images, answer) @staticmethod def _create_verl_data_entry_from_single_turn_conversation( example: dict, idx: int, data_source: str, split: str ) -> dict: prompt, images, answer = ( VerlGrpoTrainer._extract_question_images_answer_from_single_turn_conversation( example ) ) data = { "data_source": data_source, "prompt": [ { "role": "user", "content": prompt, } ], "images": images, "ability": "math", "reward_model": {"style": "rule", "ground_truth": answer}, "extra_info": { "split": split, "index": idx, "answer": answer, "question": prompt, # TODO: extract problem }, } return data def _create_dataset_files( self, process_fn: Optional[_DatasetProcessFn] = None ) -> None: """Creates dataset files for verl in Parquet format. The Parquet files are saved to the Oumi cache directory. Args: process_fn: Optional function to convert the dataset samples to verl format. """ train_file = self._cache_dir / "train.parquet" train_dataset = self._train_dataset # Limit the max number of sub-processes to 8 to avoid overloading the system # with too many processes. # TODO: Make this configurable. 
        num_proc = min(8, os.cpu_count() or 1)
        if process_fn is not None:
            train_data_source = self._get_data_source_name(
                self._oumi_config.data.train
            )
            train_dataset = train_dataset.map(
                function=lambda example, idx: process_fn(
                    example,
                    idx,
                    train_data_source,
                    "train",
                ),
                with_indices=True,
                num_proc=num_proc,
            )
        train_dataset.to_parquet(train_file)
        self._train_filepath = str(train_file)

        val_file = self._cache_dir / "val.parquet"
        eval_dataset = self._eval_dataset
        if process_fn is not None:
            validation_data_source = self._get_data_source_name(
                self._oumi_config.data.validation
            )
            eval_dataset = eval_dataset.map(
                function=lambda example, idx: process_fn(
                    example,
                    idx,
                    validation_data_source,
                    "validation",
                ),
                with_indices=True,
                num_proc=num_proc,
            )
        eval_dataset.to_parquet(val_file)
        self._val_filepath = str(val_file)

    def _create_config(self) -> DictConfig:
        """Creates a verl config."""
        model_params = self._oumi_config.model
        model_name = model_params.model_name

        # 1. Read verl default dict config from YAML.
        yaml_path = Path(__file__).parent / "verl_trainer_config.yaml"
        config = OmegaConf.load(yaml_path)
        config = cast(DictConfig, config)

        # 2. Set config values, e.g. from Oumi config values.
        config.algorithm.adv_estimator = "grpo"
        config.data.train_files = self._train_filepath
        config.data.val_files = self._val_filepath
        grpo_params = self._oumi_config.training.grpo
        training_params = self._oumi_config.training
        config.data.max_response_length = grpo_params.max_completion_length
        config.actor_rollout_ref.model.path = model_name
        config.actor_rollout_ref.actor.optim.lr = training_params.learning_rate
        config.actor_rollout_ref.model.enable_gradient_checkpointing = (
            training_params.enable_gradient_checkpointing
        )
        if grpo_params.use_vllm:
            config.actor_rollout_ref.rollout.name = "vllm"
        else:
            config.actor_rollout_ref.rollout.name = "hf"
        config.actor_rollout_ref.rollout.temperature = grpo_params.temperature
        config.actor_rollout_ref.rollout.gpu_memory_utilization = (
            grpo_params.vllm_gpu_memory_utilization
        )

        # Normally, the number of training steps is determined by the number of
        # epochs. If max_steps is set, it will override this.
        config.trainer.total_epochs = training_params.num_train_epochs
        if training_params.max_steps != -1:
            config.trainer.total_training_steps = training_params.max_steps
        if training_params.eval_strategy == "steps":
            config.trainer.test_freq = training_params.eval_steps
        if not training_params.save_epoch:
            config.trainer.save_freq = training_params.save_steps
        # A specific checkpoint to resume from takes precedence over starting
        # from the last checkpoint.
        if training_params.resume_from_checkpoint:
            config.trainer.resume_mode = "resume_path"
            config.trainer.resume_from_path = training_params.resume_from_checkpoint
        elif training_params.try_resume_from_last_checkpoint:
            config.trainer.resume_mode = "auto"
        config.trainer.logger = []
        if training_params.logging_strategy != "no":
            config.trainer.logger.append("console")
        if training_params.enable_wandb:
            config.trainer.logger.append("wandb")
        config.trainer.project_name = os.environ.get("WANDB_PROJECT", "oumi_verl")
        config.trainer.experiment_name = training_params.run_name
        config.trainer.default_local_dir = str(self._temp_output_dir or "")

        # 3. Apply user overrides.
        overrides_config = OmegaConf.create(training_params.verl_config_overrides)
        config = cast(DictConfig, OmegaConf.merge(config, overrides_config))

        # 4. Finalize and validate config.
        # Resolves the value of all interpolation fields in the config,
        # e.g. `prompt_length: ${data.max_prompt_length}`.
        # https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#omegaconf-resolve
        OmegaConf.resolve(config)

        if (
            config.actor_rollout_ref.actor.strategy == "fsdp"
            and config.actor_rollout_ref.actor.strategy != config.critic.strategy
        ):
            raise ValueError(
                "Actor and critic must use the same strategy when using FSDP."
            )
        return config

    def _setup_verl_trainer(self):
        """Sets up verl's RayPPOTrainer."""
        if ray is None:
            raise RuntimeError(
                "ray is not installed. "
                "Please install it with 'pip install `oumi[gpu]`'."
            )
        self._verl_config = self._create_config()
        logger.info(f"verl config: {pformat(self._verl_config)}")

        tokenizer = self._processing_class
        role_worker_mapping = {
            Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
            Role.Critic: ray.remote(CriticWorker),
            Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
        }

        # Create resource pool manager.
        global_pool_id = "global_pool"
        resource_pool_spec = {
            global_pool_id: [self._verl_config.trainer.n_gpus_per_node]
            * self._verl_config.trainer.nnodes,
        }
        mapping = {
            Role.ActorRollout: global_pool_id,
            Role.Critic: global_pool_id,
            Role.RefPolicy: global_pool_id,
        }
        resource_pool_manager = ResourcePoolManager(
            resource_pool_spec=resource_pool_spec, mapping=mapping
        )

        # Create reward function manager.
        compute_score = self._reward_funcs[0] if len(self._reward_funcs) > 0 else None
        reward_fn = NaiveRewardManager(
            tokenizer=tokenizer, num_examine=0, compute_score=compute_score
        )
        # num_examine=1 means to print 1 example per batch for analysis.
        val_reward_fn = NaiveRewardManager(
            tokenizer=tokenizer, num_examine=1, compute_score=compute_score
        )

        self._verl_trainer = RayPPOTrainer(
            config=self._verl_config,
            tokenizer=tokenizer,
            processor=self._processor,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
        )
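
    # Example of `training.verl_config_overrides` handled by `_create_config()` above.
    # Illustrative values only; the keys follow the verl trainer config schema
    # (see https://verl.readthedocs.io/en/latest/examples/config.html), and are merged
    # on top of the defaults from `verl_trainer_config.yaml` via OmegaConf.merge():
    #
    #   verl_config_overrides = {
    #       "data": {"train_batch_size": 64},
    #       "trainer": {"nnodes": 1, "n_gpus_per_node": 4},
    #   }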

    def train(self) -> None:
        """Trains the model using verl's RayPPOTrainer."""
        logger.info("Initializing verl trainer workers...")
        self._verl_trainer.init_workers()
        logger.info("Starting verl training...")
        self._verl_trainer.fit()

    # TODO: OPE-1192 - Implement saving model/trainer state. verl training should
    # already handle saving models, including the final checkpoint.
    def save_state(self) -> None:
        """Saves the training state."""
        pass

    def save_model(self, config: TrainingConfig, final: bool = True) -> None:
        """Saves the model.

        Args:
            config: The Oumi training config.
            final: Whether this is the final model being saved during training.
        """
        if final:
            self._export_hf_model()

    def _export_hf_model(self) -> bool:
        """Exports the tuned model to HF format.

        This method is called after training is complete.

        Returns:
            True if the model is exported successfully, False otherwise.
        """
        if not (self._final_output_dir and self._temp_output_dir):
            return False
        final_dir = Path(self._final_output_dir)
        temp_dir = Path(self._temp_output_dir)

        all_checkpoint_dirs: list[Path] = [
            f.absolute()
            for f in temp_dir.iterdir()
            if f.is_dir()
            and f.name.startswith("global_step_")
            and (f / "actor").exists()
            and (f / "actor").is_dir()
        ]

        # Find the sub-directory named `global_step_NNN` with the largest NNN.
        latest_checkpoint_step = -1
        latest_checkpoint_dir: Optional[Path] = None
        for d in all_checkpoint_dirs:
            step_str = str(d.name.removeprefix("global_step_"))
            try:
                step = int(step_str)
            except Exception as e:
                raise RuntimeError(f"Failed to extract step number from {d}") from e
            if step > latest_checkpoint_step:
                latest_checkpoint_dir = d
                latest_checkpoint_step = step

        if not latest_checkpoint_dir:
            logger.warning(f"No checkpoints found under {temp_dir}")
            return False

        logger.info(
            f"Merging and exporting model from '{latest_checkpoint_dir}' "
            f"to '{final_dir}' ..."
        )
        config = ModelMergerConfig(
            operation="merge",
            backend="fsdp",
            # TODO: Detect if tie-word-embedding is enabled, or add a config parameter.
            tie_word_embedding=False,
            local_dir=str(latest_checkpoint_dir / "actor"),
            hf_model_config_path=str(latest_checkpoint_dir / "actor" / "huggingface"),
            target_dir=str(final_dir),
        )
        merger = FSDPModelMerger(config)
        merger.merge_and_save()
        logger.info(f"Successfully exported model to '{final_dir}'!")
        return True
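

# Usage sketch (illustrative only): assumes an already-populated Oumi `TrainingConfig`
# named `config`, a `tokenizer`, and `datasets.Dataset` objects whose rows either
# contain a "conversation_json" column or are already in verl format. The reward
# signature below is what verl's NaiveRewardManager is expected to call (an
# assumption; adjust to your verl version).
#
#   def compute_score(data_source, solution_str, ground_truth, extra_info=None):
#       return 1.0 if ground_truth in solution_str else 0.0
#
#   trainer = VerlGrpoTrainer(
#       processing_class=tokenizer,
#       config=config,
#       reward_funcs=[compute_score],
#       train_dataset=train_dataset,
#       eval_dataset=eval_dataset,
#   )
#   trainer.train()
#   trainer.save_model(config, final=True)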