# Source code for oumi.core.configs.analyze_config
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional
from omegaconf import MISSING
from oumi.core.configs.base_config import BaseConfig
from oumi.core.configs.params.base_params import BaseParams
[docs]
@dataclass
class SampleAnalyzerParams(BaseParams):
    """Settings describing one sample analyzer plugin."""

    id: str = MISSING
    """Identifier of the analyzer; expected to be unique among analyzers."""

    config: dict[str, Any] = field(default_factory=dict)
    """Free-form parameters specific to this analyzer (empty by default)."""
[docs]
@dataclass
class AnalyzeConfig(BaseConfig):
    """Configuration for dataset analysis and aggregation.

    The configuration is validated in ``__post_init__``: a dataset name is
    required, ``sample_count`` (when given) must be non-negative, and every
    analyzer must carry a unique, non-empty ``id``.
    """

    # Simple fields for common use cases
    dataset_name: Optional[str] = None
    """Dataset name. Validation fails if this is None or empty."""

    split: str = "train"
    """The split of the dataset to load.

    This is typically one of "train", "test", or "validation". Defaults to "train".
    """

    subset: Optional[str] = None
    """The subset of the dataset to load. If None, uses the base dataset."""

    sample_count: Optional[int] = None
    """The number of examples to sample from the dataset.

    If None, uses the full dataset. If specified, must be non-negative.
    """

    output_path: str = "."
    """Directory path where output files will be saved.

    Defaults to current directory ('.').
    """

    analyzers: list[SampleAnalyzerParams] = field(default_factory=list)
    """List of analyzer configurations (plugin-style)."""

    def __post_init__(self) -> None:
        """Validates the configuration parameters.

        Raises:
            ValueError: If ``dataset_name`` is None or empty, if
                ``sample_count`` is negative, if an analyzer has no ``id``
                (empty or still the ``MISSING`` placeholder), or if two
                analyzers share the same ``id``.
        """
        if not self.dataset_name:
            raise ValueError("'dataset_name' must be provided")

        # Enforce the constraint documented on `sample_count`; None means
        # "use the full dataset" and is always allowed.
        if self.sample_count is not None and self.sample_count < 0:
            raise ValueError(
                f"'sample_count' must be non-negative, got {self.sample_count}"
            )

        # Validate analyzer configurations
        seen_ids: set[str] = set()
        for analyzer in self.analyzers:
            # `MISSING` is a truthy placeholder string, so a plain falsiness
            # check would not detect an analyzer whose id was never set.
            if not analyzer.id or analyzer.id == MISSING:
                raise ValueError("Analyzer 'id' must be provided")
            if analyzer.id in seen_ids:
                raise ValueError(f"Duplicate analyzer ID found: '{analyzer.id}'")
            seen_ids.add(analyzer.id)