Source code for oumi.builders.training
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from pprint import pformat
from typing import Callable, Optional, cast

import transformers
import trl

from oumi.core.configs import TrainerType, TrainingParams
from oumi.core.distributed import is_world_process_zero
from oumi.core.processors.base_processor import BaseProcessor
from oumi.core.trainers import BaseTrainer, HuggingFaceTrainer, VerlGrpoTrainer
from oumi.core.trainers import Trainer as OumiTrainer
from oumi.utils.logging import logger


def build_trainer(
trainer_type: TrainerType, processor: Optional[BaseProcessor]
) -> Callable[..., BaseTrainer]:
"""Builds a trainer creator functor based on the provided configuration.
Args:
trainer_type (TrainerType): Enum indicating the type of training.
processor: An optional processor.
Returns:
A builder function that can create an appropriate trainer based on the trainer
type specified in the configuration. All function arguments supplied by caller
are forwarded to the trainer's constructor.
Raises:
NotImplementedError: If the trainer type specified in the
configuration is not supported.
"""

    def _create_hf_builder_fn(
cls: type[transformers.Trainer],
) -> Callable[..., BaseTrainer]:
def _init_hf_trainer(*args, **kwargs) -> BaseTrainer:
training_args = kwargs.pop("args", None)
callbacks = kwargs.pop("callbacks", [])
if training_args is not None:
# if set, convert to HuggingFace Trainer args format
training_args = cast(TrainingParams, training_args)
training_args.finalize_and_validate()
hf_args = training_args.to_hf()
if is_world_process_zero():
logger.info(pformat(hf_args))
trainer = HuggingFaceTrainer(cls(*args, **kwargs, args=hf_args), processor)
if callbacks:
# TODO(OPE-250): Define generalizable callback abstraction
# Incredibly ugly, but this is the only way to add callbacks that add
# metrics to wandb. Transformers trainer has no public method of
# allowing us to control the order callbacks are called.
training_callbacks = (
[transformers.trainer_callback.DefaultFlowCallback]
+ callbacks
# Skip the first callback, which is the DefaultFlowCallback above.
+ trainer._hf_trainer.callback_handler.callbacks[1:]
)
trainer._hf_trainer.callback_handler.callbacks = []
for c in training_callbacks:
trainer._hf_trainer.add_callback(c)
            return trainer

        return _init_hf_trainer

    def _create_oumi_builder_fn() -> Callable[..., BaseTrainer]:
def _init_oumi_trainer(*args, **kwargs) -> BaseTrainer:
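            # Inject the processor captured by build_trainer() unless the caller
            # already supplied the very same instance; a different instance is
            # rejected below to avoid training with an inconsistent processor.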
kwargs_processor = kwargs.get("processor", None)
if processor is not None:
if kwargs_processor is None:
kwargs["processor"] = processor
elif id(kwargs_processor) != id(processor):
                    raise ValueError(
                        "Conflicting processor instances were passed to the Oumi "
                        "trainer and to build_trainer()."
                    )
            return OumiTrainer(*args, **kwargs)

        return _init_oumi_trainer

    def _create_verl_grpo_builder_fn() -> Callable[..., BaseTrainer]:
def _init_verl_grpo_trainer(*args, **kwargs) -> BaseTrainer:
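            # verl's GRPO trainer receives its configuration directly through
            # its constructor, so this builder is a thin pass-through.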
return VerlGrpoTrainer(*args, **kwargs)
        return _init_verl_grpo_trainer

    if trainer_type == TrainerType.TRL_SFT:
return _create_hf_builder_fn(trl.SFTTrainer)
elif trainer_type == TrainerType.TRL_DPO:
return _create_hf_builder_fn(trl.DPOTrainer)
elif trainer_type == TrainerType.TRL_GRPO:
return _create_hf_builder_fn(trl.GRPOTrainer)
elif trainer_type == TrainerType.HF:
return _create_hf_builder_fn(transformers.Trainer)
elif trainer_type == TrainerType.OUMI:
        warnings.warn(
            "The OUMI trainer is still in alpha. "
            "Prefer the HF trainer when possible."
        )
return _create_oumi_builder_fn()
elif trainer_type == TrainerType.VERL_GRPO:
return _create_verl_grpo_builder_fn()
raise NotImplementedError(f"Trainer type {trainer_type} not supported.")
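

# Illustrative usage (a sketch; `model`, `train_ds`, and the finalized
# TrainingParams instance `params` are assumed to be built elsewhere,
# e.g. via the other oumi builders):
#
#   create_trainer = build_trainer(TrainerType.HF, processor=None)
#   trainer = create_trainer(model=model, args=params, train_dataset=train_ds)
#   trainer.train()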