Source code for oumi.evaluation.registry.count_letters_task
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any, Optional

from oumi.core.configs.params.evaluation_params import EvaluationTaskParams
from oumi.core.inference.base_inference_engine import BaseInferenceEngine
from oumi.core.registry import register_evaluation_function
from oumi.datasets.grpo.letter_count import LetterCountGrpoDataset
from oumi.utils.logging import logger


def _extract_prediction(response: str) -> Optional[int]:
    r"""Returns the numeric answer extracted from `\boxed{...}`, or None otherwise."""
    regex_result = re.findall(r"\\boxed\{([-+]?\d+)\}", response)
    if not regex_result or len(regex_result) != 1:
        return None
    number_str = regex_result[0]
    # Except clause shouldn't trigger because the regex should only find ints.
    try:
        return int(number_str)
    except ValueError:
        return None
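
# Illustrative examples (not part of the module) of how `_extract_prediction`
# behaves: it returns an int only when the response contains exactly one
# `\boxed{...}` expression wrapping an integer.
#
#   _extract_prediction(r"The letter appears \boxed{3} times.")  # -> 3
#   _extract_prediction(r"Either \boxed{2} or \boxed{3}.")       # -> None (two matches)
#   _extract_prediction("The letter appears 3 times.")           # -> None (no \boxed{})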


@register_evaluation_function("count_letters")
def count_letters(
    task_params: EvaluationTaskParams,
    inference_engine: BaseInferenceEngine,
) -> dict[str, Any]:
    """Custom evaluation function registered as `count_letters`."""
    dataset = LetterCountGrpoDataset(
        dataset="oumi-ai/oumi-letter-count-clean", split="test"
    )
    # TODO: OPE-1155: Add support for using Oumi dataset code to create the dataset.
    # dataset = build_dataset("oumi-ai/oumi-letter-count", tokenizer=None, sample_count=10) # noqa: E501
    num_samples = task_params.num_samples
    if num_samples is None:
        num_samples = len(dataset)
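    # Note: `task_params.num_samples` (when set) caps how many test examples are
    # scored; otherwise the entire "test" split is evaluated.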
    input_conversations = [dataset.conversation(i) for i in range(num_samples)]

    conversations = inference_engine.infer(input_conversations)
    logger.info(f"Finished inference on {len(conversations)} conversations!")
    if len(conversations) > 0:
        logger.info(f"Sample conversation: {conversations[0]}")

    count = 0  # The number of examples with correct answers extracted.
    total = 0  # All examples.
    valid_count = 0  # The number of examples with valid answers extracted.
    for i, conversation in enumerate(conversations):
        total += 1

        # Grab the model's response.
        response = conversation.last_message()

        # Ignore cases where the model didn't respond or it's a multimodal response.
        # For now, we focus on text-only responses.
        if not response or not isinstance(response.content, str):
            continue

        # Count the example as correct if the extracted prediction matches the label.
        prediction = _extract_prediction(response.content)
        if prediction is None:
            continue
        valid_count += 1
        if prediction == conversation.metadata["letter_count_integer"]:
            count += 1

    return {
        # Accuracy across all examples.
        "accuracy": count / total if total > 0 else 0,
        # Accuracy when only counting examples with properly extracted answers.
        "properly_extracted_accuracy": count / valid_count if valid_count > 0 else 0,
        "num_samples": num_samples,
        # These three values sum up to num_samples.
        "num_correct_answers": count,
        "num_incorrect_answers": valid_count - count,
        "num_invalid_answers": total - valid_count,
    }
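

# A minimal sketch (hypothetical numbers, not produced by this module) of how the
# returned metrics relate: `num_correct_answers`, `num_incorrect_answers`, and
# `num_invalid_answers` always sum to the number of scored examples.
#
#   count, valid_count, total = 7, 8, 10
#   metrics = {
#       "accuracy": count / total,                            # 0.7
#       "properly_extracted_accuracy": count / valid_count,   # 0.875
#       "num_samples": total,                                 # 10
#       "num_correct_answers": count,                         # 7
#       "num_incorrect_answers": valid_count - count,         # 1
#       "num_invalid_answers": total - valid_count,           # 2
#   }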