Training Methods#

Introduction#

Oumi supports several training methods to accommodate different use cases.

Here’s a quick comparison:

| Method | Use Case | Data Required | Compute | Key Features |
|---|---|---|---|---|
| Supervised Fine-Tuning (SFT) | Task adaptation | Input-output pairs | Low | Fine-tunes pre-trained models on specific tasks by providing labeled conversations. |
| Vision-Language SFT | Multimodal tasks | Image-text pairs | Moderate | Extends SFT to handle both images and text, enabling image-understanding tasks. |
| Pretraining | Domain adaptation | Raw text | Very High | Trains a language model from scratch or adapts it to a new domain using large amounts of unlabeled text. |
| Direct Preference Optimization (DPO) | Preference learning | Preference pairs | Low | Trains a model to align with human preferences by providing pairs of preferred and rejected outputs. |
| Group Relative Policy Optimization (GRPO) | Reasoning | Input-output pairs | Moderate | Trains a model to improve reasoning skills by providing training examples with concrete answers. |

Tip

Oumi supports GRPO on Vision-Language Models with the VERL_GRPO trainer.

Supervised Fine-Tuning (SFT)#

Overview#

Supervised Fine-Tuning (SFT) is the most common approach for adapting a pre-trained language model to specific downstream tasks. This involves fine-tuning the model’s parameters on a labeled dataset of input-output pairs, effectively teaching the model to perform the desired task. SFT is effective for a wide range of tasks, including:

  • Question answering: Answering questions based on given context or knowledge. This could be used to build chatbots that can answer questions about a specific domain or provide general knowledge.

  • Agent development: Training language models to act as agents that can interact with their environment and perform tasks autonomously. This involves fine-tuning the model on data that demonstrates how to complete tasks, communicate effectively, and make decisions.

  • Tool use: Fine-tuning models to effectively use external tools (e.g., calculators, APIs, databases) to augment their capabilities. This involves training on data that shows how to call tools, interpret their outputs, and integrate them into problem-solving.

  • Structured data extraction: Training models to extract structured information from unstructured text. This can be used to extract entities, relationships, or key events from documents, enabling automated data analysis and knowledge base construction.

  • Text generation: Generating coherent text, code, scripts, emails, etc., based on a prompt.

Data Format#

SFT uses the Conversation format, which represents a conversation between a user and an assistant. Each turn in the conversation is represented by a message with a role (“user” or “assistant”) and content.

{
    "messages": [
        {
            "role": "user",
            "content": "What is machine learning?"
        },
        {
            "role": "assistant",
            "content": "Machine learning is a type of artificial intelligence that allows software applications to become more accurate in predicting outcomes without being explicitly programmed to do so. Machine learning algorithms use historical data as input to predict new output values."
        }
    ]
}
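
The same conversation can also be constructed programmatically with Oumi's Conversation types (using the same classes shown in the vision-language example later on this page):

from oumi.core.types.conversation import Conversation, Message, Role

Conversation(
    messages=[
        Message(role=Role.USER, content="What is machine learning?"),
        Message(
            role=Role.ASSISTANT,
            content=(
                "Machine learning is a type of artificial intelligence that allows "
                "software applications to become more accurate in predicting "
                "outcomes without being explicitly programmed to do so."
            ),
        ),
    ]
)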

See Supervised Fine-Tuning for available SFT datasets.

Configuration#

The data section in the configuration file specifies the dataset to use for training. The training section defines various training parameters.

data:
  train:
    datasets:
      - dataset_name: "text_sft"
        dataset_path: "/path/to/data"
        split: "train"
    collator_name: "text_with_padding"

training:
  trainer_type: "TRL_SFT"

See the 🔧 Model Finetuning Guide notebook for a complete example.
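
If you are preparing your own dataset, the following is a minimal sketch of writing conversations to a JSONL file that dataset_path can point to (it assumes the loader expects one JSON conversation per line in the format shown above; the file name is hypothetical):

import json

# Illustrative only: one JSON conversation object per line.
conversations = [
    {
        "messages": [
            {"role": "user", "content": "What is machine learning?"},
            {"role": "assistant", "content": "Machine learning is ..."},
        ]
    },
]

with open("train.jsonl", "w") as f:  # hypothetical file name
    for conversation in conversations:
        f.write(json.dumps(conversation) + "\n")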

Vision-Language SFT#

Overview#

Vision-Language SFT extends the concept of Supervised Fine-Tuning to handle both images and text. This enables the model to understand and reason about visual information, opening up a wide range of multimodal applications:

  • Image-based instruction following: Following instructions that involve both text and images. For example, the model could be instructed to analyze and generate a report based on an image of a table.

  • Multimodal Agent Development: Training agents that can perceive and act in the real world through vision and language. This could include tasks like navigating a physical space, interacting with objects, or following complex instructions.

  • Structured Data Extraction from Images: Extracting structured data from images, such as tables, forms, or diagrams. This could be used to automate data entry or to extract information from scanned documents.

Data Format#

Vision-Language SFT uses the Conversation format with additional support for images. Image content items reference the image to include, in this example by URL (the image_url type).

{
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "image_url",
          "content": "https://oumi.ai/the_great_wave_off_kanagawa.jpg"
        },
        {
          "type": "text",
          "content": "What is in this image?"
        }
      ]
    },
    {
      "role": "assistant",
      "content": "The image is a traditional Japanese ukiyo-e print."
    }
  ]
}

The same conversation expressed with Oumi's Python types:

from oumi.core.types.conversation import Conversation, ContentItem, Message, Role, Type

Conversation(
    messages=[
        Message(
            role=Role.USER,
            content=[
                ContentItem(
                    type=Type.IMAGE_URL,
                    content="https://oumi.ai/the_great_wave_off_kanagawa.jpg"
                ),
                ContentItem(type=Type.TEXT, content="What is in this image?"),
            ],
        ),
        Message(
            role=Role.ASSISTANT,
            content="The image is a traditional Japanese ukiyo-e print."
        )
    ]
)

See Vision-Language Datasets for available vision-language datasets.

Configuration#

The configuration for Vision-Language SFT is similar to SFT, but with additional parameters for handling images.

model:
  model_name: "meta-llama/Llama-3.2-11B-Vision-Instruct"
  chat_template: "llama3-instruct"
  freeze_layers: ["vision_encoder"]  # Freeze specific layers to only train the language model, and not the vision encoder

data:
  train:
    collator_name: "vision_language_with_padding" # Visual features collator
    collator_kwargs: {}  # Optional: Additional collator parameters
    datasets:
      - dataset_name: "vl_sft_jsonl"
        dataset_path: "/path/to/data"
        trust_remote_code: False # Set to true if needed for model-specific processors
        dataset_kwargs:
          processor_name: "meta-llama/Llama-3.2-11B-Vision-Instruct" # Feature generator

training:
  trainer_type: "TRL_SFT"

Note: You can use collator_kwargs to customize the vision-language collator behavior. See the configuration guide for more details and examples.

See the 🖼️ Oumi Multimodal notebook for a complete example.

Pretraining#

Overview#

Pretraining is the process of training a language model from scratch, or continuing training of a pre-trained model, on large amounts of unlabeled text data. The most common pretraining objective is Causal Language Modeling (CLM), in which the model predicts the next token in a sequence given the preceding tokens. Pretraining is computationally expensive, but it can produce models with strong general language understanding capabilities.

Pretraining is typically used for:

  • Training models from scratch: This involves training a new language model from scratch on a massive text corpus.

  • Continuing training on new data: This involves taking a pre-trained model and continuing to train it on additional data, to improve its performance on existing tasks.

  • Domain adaptation: This involves adapting a pre-trained model to a specific domain, such as scientific literature or legal documents.

Data Format#

Pretraining uses the BasePretrainingDataset format, which simply contains the text to be used for training.

{
    "text": "Document text for pretraining..."
}

See the Pre-training section for available pretraining datasets.

Configuration#

The configuration for pretraining specifies the dataset and the pretraining approach to use.

data:
  train:
    datasets:
      - dataset_name: "text"
        dataset_path: "/path/to/corpus"
        streaming: true  # Stream data from disk and/or network
    pack: true  # Pack multiple documents into a single sequence
    max_length: 2048  # Maximum sequence length

training:
  trainer_type: "OUMI"

Explanation of Configuration Parameters:

  • streaming: If set to true, the data will be streamed from disk, which is useful for large datasets that don’t fit in memory.

  • pack: If set to true, multiple documents will be packed into a single sequence, which can improve efficiency.
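
To give intuition for what packing does, here is a minimal, framework-agnostic sketch (illustrative only, not Oumi's actual implementation):

# Minimal sketch of sequence packing: concatenate tokenized documents,
# separated by an EOS token, then slice the stream into fixed-length chunks.
def pack_documents(tokenized_docs, max_length, eos_token_id):
    stream = []
    for doc in tokenized_docs:
        stream.extend(doc + [eos_token_id])
    # Keep only full-length sequences; the trailing remainder is dropped.
    return [
        stream[i : i + max_length]
        for i in range(0, len(stream) - max_length + 1, max_length)
    ]

# Example: three short "documents" packed into a single sequence of length 8.
docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(pack_documents(docs, max_length=8, eos_token_id=0))
# -> [[1, 2, 3, 0, 4, 5, 0, 6]]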

Direct Preference Optimization (DPO)#

Overview#

Direct Preference Optimization (DPO) is a technique for training language models to align with human preferences. It involves presenting the model with pairs of outputs (e.g., two different responses to the same prompt) and training it to prefer the output that humans prefer. DPO offers several advantages:

  • Training with human preferences: DPO allows you to directly incorporate human feedback into the training process, leading to models that generate more desirable outputs.

  • Improving output quality without reward models: Unlike reinforcement learning methods, DPO doesn’t require a separate reward model to evaluate the quality of outputs.
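
For intuition, the core of the DPO objective compares the policy's log-probabilities of the chosen and rejected responses against a frozen reference model. A minimal PyTorch sketch of the loss (illustrative only, not Oumi's or TRL's implementation):

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps,    # log-prob of the chosen response under the policy
    policy_rejected_logps,  # log-prob of the rejected response under the policy
    ref_chosen_logps,       # same quantities under the frozen reference model
    ref_rejected_logps,
    beta=0.1,               # strength of the preference margin scaling
):
    # Implicit rewards: how much more likely each response is under the
    # policy than under the reference model.
    chosen_rewards = policy_chosen_logps - ref_chosen_logps
    rejected_rewards = policy_rejected_logps - ref_rejected_logps
    # Maximize the probability that the chosen response outranks the rejected one.
    return -F.logsigmoid(beta * (chosen_rewards - rejected_rewards)).mean()

# Example with dummy per-example log-probabilities for a batch of 2.
print(dpo_loss(
    torch.tensor([-5.0, -6.0]), torch.tensor([-7.0, -6.5]),
    torch.tensor([-5.5, -6.2]), torch.tensor([-6.8, -6.4]),
))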

Data Format#

DPO uses the BaseExperimentalDpoDataset format, which includes the prompt, the chosen output, and the rejected output.

{
    "messages": [
        {
            "role": "user",
            "content": "Write a story about a robot"
        }
    ],
    "chosen": {
        "messages": [
            {
                "role": "assistant",
                "content": "In the year 2045, a robot named..."
            }
        ]
    },
    "rejected": {
        "messages": [
            {
                "role": "assistant",
                "content": "There was this robot who..."
            }
        ]
    }
}

See the Preference Tuning section for available preference datasets.

Configuration#

The configuration for DPO specifies the training parameters and the DPO settings.

data:
  train:
    datasets:
      - dataset_name: "preference_pairs_jsonl"
        dataset_path: "/path/to/data"
    collator_name: "dpo_with_padding"

training:
  trainer_type: "TRL_DPO"  # Use the TRL DPO trainer

Group Relative Policy Optimization (GRPO)#

Overview#

Group Relative Policy Optimization (GRPO) is a technique for training language models using reinforcement learning. A common use case is training reasoning models on verifiable rewards, i.e., rewards calculated by functions rather than by a reward model. An example is math problems, where there is a correct answer and correctly formatted incorrect answers can be given partial credit. While GRPO can be used with reward models, we primarily consider the case of reward functions here.

Some advantages of GRPO include:

  • No value model: Unlike PPO, where a value (critic) model must be trained alongside the actor model to estimate long-term reward, GRPO estimates the baseline from group scores, obviating the need for this model. This reduces training complexity and memory usage.

  • Training on verifiable rewards: By having reward functions, a separate reward model doesn’t have to be trained, reducing complexity and memory usage.

  • Does not require labeled preference data: Unlike other algorithms like DPO, GRPO doesn’t require labeled pairwise preference data. Instead, advantages are calculated by comparing multiple generations for a single prompt.
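
To make "estimates the baseline from group scores" concrete, here is a minimal sketch of the group-relative advantage computation (illustrative only; real trainers add PPO-style clipping, KL regularization, and token-level weighting):

import torch

def group_relative_advantages(rewards, eps=1e-6):
    # `rewards` holds the scores of several completions sampled for the
    # SAME prompt. Each completion's advantage is its reward standardized
    # against the rest of the group, so no separate value/critic model is
    # needed to estimate a baseline.
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# Example: four completions for one prompt, scored by a reward function.
rewards = torch.tensor([1.0, 0.0, 0.5, 1.0])
print(group_relative_advantages(rewards))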

Data Format#

GRPO datasets should inherit from the BaseExperimentalGrpoDataset dataset class. In your subclass, you can implement any custom transformation logic you need.

TRL_GRPO#

For the TRL_GRPO trainer, the only requirement is that the dataset includes a "prompt" column containing either the plaintext prompt or messages in conversational format. Other fields, such as metadata, are optional, but are passed into the custom reward function if present. The following is a single example from LetterCountGrpoDataset, which has prompts asking models to count letters in words:

{
    "prompt": [
        {
            "content": 'Your final answer should be an integer written as digits and formatted as "\\boxed{your_answer}". For example, if the answer is 42, you should output "\\boxed{42}".',
            "role": "system",
        },
        {
            "content": "Could you determine the count of 'l's in 'substantial'?",
            "role": "user",
        },
    ],
    "letter_count": 1,
}

VERL_GRPO#

The VERL_GRPO trainer requires a specific format for its input dataset; read the verl documentation for more information. An example from the Countdown dataset is shown below:

{
    "ability": "math",
    "data_source": "countdown",
    "extra_info": {"split": "train"},
    "prompt": [
        {
            "content": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [79, 8], create an equation that equals 87. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>",
            "role": "user",
        }
    ],
    "reward_model": {
        "ground_truth": {"numbers": [79, 8], "target": 87},
        "style": "rule",
    },
}

Tip

verl requires paths to Parquet files for the training and validation data. Oumi allows you to use HuggingFace Datasets instead by automatically creating the necessary Parquet files before training.

Reward function#

Instead of training a separate reward model that estimates the reward value of a completion, it is common to use reward functions. Both the trl and verl frameworks have specific interfaces for the reward functions they use; these are documented in the trl documentation and verl documentation, respectively.
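
As a rough illustration of what such a function looks like in the trl style (a sketch under assumptions, not Oumi's registered implementation: it assumes plaintext completions and that dataset columns such as letter_count are forwarded as keyword arguments; check the trl documentation for the exact signature your version expects):

import re

def letter_count_reward(completions, letter_count, **kwargs):
    # Returns one float reward per completion. Full credit for the correct
    # boxed answer, partial credit for a well-formatted wrong answer,
    # nothing otherwise. (Hypothetical example; assumes plaintext completions.)
    rewards = []
    for completion, target in zip(completions, letter_count):
        match = re.search(r"\\boxed\{(\d+)\}", completion)
        if match is None:
            rewards.append(0.0)
        elif int(match.group(1)) == target:
            rewards.append(1.0)
        else:
            rewards.append(0.1)
    return rewards

# Example usage with two completions for prompts whose answer is 1.
print(letter_count_reward(
    completions=["The count is \\boxed{1}.", "I think there are two."],
    letter_count=[1, 1],
))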

Configuration#

TRL_GRPO#

Configuring the TRL_GRPO trainer is similar to most other trl-based trainers in Oumi, like TRL_SFT. Most Oumi config fields will be used, as trl’s GRPO config is built on top of HF’s config. The following configuration highlights some relevant fields for GRPO:

model:
  model_name: "Qwen/Qwen2-0.5B-Instruct"

data:
  train:
    datasets:
      - dataset_name: "trl-lib/tldr"
        split: "train"

training:
  trainer_type: "TRL_GRPO"

  # Specifies the name of a reward function in our reward function registry.
  reward_functions: ["soft_20tokens_completions"]

  grpo:
    use_vllm: True

VERL_GRPO#

verl is an RL training framework created by ByteDance. Many Oumi config fields generally correspond to HF config fields and thus are not consumed by verl. The following table shows all Oumi config fields used by the verl trainer and the verl fields they map to. An overview of the fields in the verl config can be found in their documentation.

| Oumi | verl |
|---|---|
| model.model_name | actor_rollout_ref.model.path |
| data.train.datasets | data.train_files |
| data.validation.datasets | data.val_files |
| training.grpo.max_completion_length | data.max_response_length |
| training.grpo.use_vllm | actor_rollout_ref.rollout.name |
| training.grpo.temperature | actor_rollout_ref.rollout.temperature |
| training.grpo.vllm_gpu_memory_utilization | actor_rollout_ref.rollout.gpu_memory_utilization |
| training.enable_gradient_checkpointing | actor_rollout_ref.model.enable_gradient_checkpointing |
| training.learning_rate | actor_rollout_ref.actor.optim.lr |
| training.num_train_epochs | trainer.total_epochs |
| training.max_steps | trainer.total_training_steps |
| training.eval_steps / training.eval_strategy | trainer.test_freq |
| training.save_steps / training.save_epoch | trainer.save_freq |
| training.resume_from_checkpoint | trainer.resume_mode / trainer.resume_from_path |
| training.try_resume_from_last_checkpoint | trainer.resume_mode |
| training.logging_strategy / training.enable_wandb | trainer.logger |
| training.run_name | trainer.experiment_name |
| training.output_dir | trainer.default_local_dir |

Tip

The training.verl_config_overrides field can be used to specify any field in the verl config. The values specified in this field will override any values set by the Oumi -> verl mapping above. For example, if you already have your own training/validation Parquet files you want to use, you can directly set data.train_files in the override.

The following shows a bare-bones Oumi VERL_GRPO config.

model:
  model_name: "Qwen/Qwen2-0.5B-Instruct"

data:
  train:
    datasets:
      - dataset_name: "Jiayi-Pan/Countdown-Tasks-3to4"
        split: "train"
  # verl requires a validation set.
  validation:
    datasets:
      - dataset_name: "Jiayi-Pan/Countdown-Tasks-3to4"
        split: "test"

training:
  trainer_type: "VERL_GRPO"
  reward_functions: ["countdown"]

  grpo:
    use_vllm: True

  verl_config_overrides:
    # This sets `data.train_batch_size` to 128 in the verl config.
    data:
      train_batch_size: 128

Next Steps#