oumi.datasets.grpo.rewards#
GRPO reward functions module.
- oumi.datasets.grpo.rewards.compute_letter_count_reward(completion: str, target_count: int) → float [source]#
Computes the reward for counting the letters in a string.
- Parameters:
completion – The completion string from the LLM.
target_count – The target count of letters.
- Returns:
The reward value.
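A minimal usage sketch; the completion text below is hypothetical, and the numeric reward depends on how the count stated in the completion compares to target_count:

```python
from oumi.datasets.grpo.rewards import compute_letter_count_reward

# Hypothetical completion answering a letter-counting prompt.
completion = "The letter 'r' appears 3 times in 'strawberry'."

# Score the completion against the true count of 3.
reward = compute_letter_count_reward(completion, target_count=3)
print(reward)
```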
- oumi.datasets.grpo.rewards.compute_sharp_target_token_length_reward(num_tokens: int, *, target_tokens: int)[source]#
Returns the maximum reward for inputs that are exactly target_tokens long.
The reward decreases sharply as the actual number of tokens deviates from target_tokens.
The reward is computed as -|num_tokens - target_tokens|, which penalizes any token count not equal to target_tokens.
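A short sanity check of the formula above; the exact values assume the reward is implemented verbatim as -|num_tokens - target_tokens|:

```python
from oumi.datasets.grpo.rewards import compute_sharp_target_token_length_reward

# Maximum reward (0) when the length matches the target exactly.
print(compute_sharp_target_token_length_reward(100, target_tokens=100))  # 0

# Each token of deviation subtracts 1: -|90 - 100| = -10.
print(compute_sharp_target_token_length_reward(90, target_tokens=100))   # -10
```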
- oumi.datasets.grpo.rewards.compute_soft_target_token_length_reward(num_tokens: int, *, target_tokens: int)[source]#
Returns the maximum reward for inputs that are exactly target_tokens long.
The reward is in the [0, 1] range and decreases smoothly from the maximum value of 1.0 as the actual number of tokens deviates from target_tokens.
The reward is proportional to x*exp(-x), where x := num_tokens/target_tokens. Since x*exp(-x) attains its maximum of 1/e at x = 1, scaling by e yields the stated maximum reward of 1.0 when num_tokens equals target_tokens.
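A sketch of the shape of this reward. The helper below recomputes x*exp(-x) scaled by e for comparison; the scaling by e is an assumption, chosen because it is the only proportionality constant consistent with the stated maximum of 1.0:

```python
import math

from oumi.datasets.grpo.rewards import compute_soft_target_token_length_reward


def expected_soft_reward(num_tokens: int, target_tokens: int) -> float:
    # x * exp(-x) peaks at 1/e when x == 1, so scaling by e gives a
    # maximum of 1.0 when num_tokens == target_tokens (an assumption
    # consistent with the [0, 1] range stated above).
    x = num_tokens / target_tokens
    return x * math.exp(-x) * math.e


for n in (50, 100, 200):
    print(n,
          compute_soft_target_token_length_reward(n, target_tokens=100),
          expected_soft_reward(n, target_tokens=100))
```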
- oumi.datasets.grpo.rewards.countdown_reward(data_source: str, solution_str: str, ground_truth: dict[str, Any], extra_info: dict[str, Any], format_score=0.0, score=1.0) → float [source]#
Custom reward function for the Countdown task.
Currently, this function only works with the VERL_GRPO trainer.
- Parameters:
data_source – The data source.
solution_str – The response from the LLM.
ground_truth – Dictionary containing the target number and the available numbers.
extra_info – Extra information about the sample.
format_score – The score for correct format but wrong answer.
score – The score for the correct answer.
- Returns:
score if the equation is valid and correct; format_score if the answer was parsed properly but the equation is incorrect; 0 if the answer could not be parsed.
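A hedged usage sketch. The ground_truth key names ("target" and "numbers") and the <answer> tag format are illustrative assumptions, not confirmed by this reference:

```python
from oumi.datasets.grpo.rewards import countdown_reward

# Key names and the answer format below are illustrative assumptions.
reward = countdown_reward(
    data_source="countdown",
    solution_str="<answer>(6 * 4) + 1</answer>",
    ground_truth={"target": 25, "numbers": [6, 4, 1]},
    extra_info={},
)
# With the default scores: 1.0 if the equation is valid and correct,
# 0.0 otherwise.
print(reward)
```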
- oumi.datasets.grpo.rewards.gsm8k_reward(data_source, solution_str, ground_truth, extra_info, method='strict', format_score=0.0, score=1.0)[source]#
The scoring function for GSM8K.
Reference: Trung, Luong, et al. “ReFT: Reasoning with Reinforced Fine-Tuning.” Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.
- Parameters:
data_source – The data source.
solution_str – The solution text from the LLM.
ground_truth – The ground-truth answer.
extra_info – Extra information about the sample.
method – The method used to extract the solution; choices are ‘strict’ and ‘flexible’.
format_score – The score for correct format but wrong answer.
score – The score for the correct answer.
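A usage sketch. The “#### <answer>” convention follows the GSM8K dataset format; the exact ground_truth representation shown is an assumption for illustration:

```python
from oumi.datasets.grpo.rewards import gsm8k_reward

# In 'strict' mode, the final answer is conventionally given after "####"
# (GSM8K style); the ground_truth value here is an illustrative assumption.
solution = "Janet earns 16 - 3 - 4 = 9 eggs, so 9 * 2 = 18 dollars. #### 18"
reward = gsm8k_reward(
    data_source="openai/gsm8k",
    solution_str=solution,
    ground_truth="18",
    extra_info={},
    method="strict",
)
print(reward)
```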