Source code for oumi.utils.analysis_utils

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from oumi.core.configs import AnalyzeConfig
from oumi.core.datasets import BaseMapDataset
from oumi.utils.logging import logger


[docs] def load_dataset_from_config(config: AnalyzeConfig) -> BaseMapDataset: """Load dataset based on configuration. This function loads datasets directly from the registry for analysis purposes, avoiding the need for tokenizers and other training infrastructure. """ # Delayed import to avoid circular dependency with registry and dataset modules from oumi.core.registry import REGISTRY dataset_name = config.dataset_name split = config.split subset = config.subset if not dataset_name: raise ValueError("Dataset name is required") try: # Load dataset from the REGISTRY dataset_class = REGISTRY.get_dataset(dataset_name, subset=subset) if dataset_class is not None: # Load registered dataset with basic parameters dataset = dataset_class( dataset_name=dataset_name, dataset_path=None, split=split, subset=subset, trust_remote_code=False, ) # Ensure we return a BaseMapDataset if isinstance(dataset, BaseMapDataset): return dataset else: raise NotImplementedError( f"Dataset type {type(dataset)} is not supported for analysis. " "Please use a dataset that inherits from BaseMapDataset." ) else: # TODO: Implement HuggingFace Hub loading raise NotImplementedError( f"Dataset '{dataset_name}' is not registered in the REGISTRY. " "Loading from HuggingFace Hub is not yet implemented." ) except Exception as e: logger.error(f"Failed to load dataset {dataset_name}: {e}") raise