Source code for oumi.core.analyze.dataset_analyzer

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from dataclasses import asdict, dataclass
from typing import Any, Optional

import pandas as pd
from tqdm import tqdm

from oumi.core.configs import AnalyzeConfig
from oumi.core.datasets import BaseMapDataset
from oumi.core.registry.registry import REGISTRY
from oumi.utils.analysis_utils import load_dataset_from_config
from oumi.utils.logging import logger


@dataclass
class MessageAnalysisResult:
    """Result of analyzing a single message in a conversation.

    Attributes:
        conversation_id: Unique identifier for the conversation
        conversation_index: Index of the conversation in the dataset
        message_index: Index of the message within the conversation
        role: Role of the message sender (e.g., 'user', 'assistant')
        message_id: Unique identifier for the message
        text_content: The text content of the message
        analyzer_metrics: Dictionary of metrics computed by sample analyzers,
            with keys prefixed by analyzer ID to avoid conflicts
    """

    ANALYZER_METRICS_FIELD = "analyzer_metrics"

    conversation_id: str
    conversation_index: int
    message_index: int
    role: str
    message_id: str
    text_content: str
    analyzer_metrics: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Convert the analysis result to a dictionary with flattened analyzer metrics.

        Returns:
            Dictionary representation of the analysis result with analyzer
            metrics flattened into the main dictionary (prefixed by analyzer ID)
        """
        base_dict = asdict(self)
        # Flatten analyzer_metrics into the main dict
        analyzer_metrics = base_dict.pop(self.ANALYZER_METRICS_FIELD, {})
        base_dict.update(analyzer_metrics)
        return base_dict
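
    # Illustrative sketch of the flattening (the metric name and values are
    # hypothetical): a result whose analyzer_metrics is
    # {"length_word_count": 12} serializes as
    #   {"conversation_id": "conv_0", ..., "length_word_count": 12}
    # with the metric promoted to a top-level key rather than nested under an
    # "analyzer_metrics" key.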


@dataclass
class DatasetAnalysisResult:
    """Complete result of dataset analysis.

    Attributes:
        dataset_name: Name of the analyzed dataset
        total_conversations: Total number of conversations in the dataset
        conversations_analyzed: Number of conversations actually analyzed
        total_messages: Total number of messages analyzed
        messages: List of analysis results for each individual message
    """

    dataset_name: str
    total_conversations: int
    conversations_analyzed: int
    total_messages: int
    messages: list[MessageAnalysisResult]

    def to_dict(self) -> dict[str, Any]:
        """Convert the analysis result to a dictionary.

        Returns:
            Dictionary representation of the analysis result
        """
        return asdict(self)

    def to_dataframe(self) -> pd.DataFrame:
        """Convert the analysis results to a pandas DataFrame.

        Returns:
            DataFrame with flattened analyzer metrics for easy querying.
            Each row represents one message with all its analysis metrics.
        """
        # Convert each message to a dict with flattened metrics
        message_dicts = [msg.to_dict() for msg in self.messages]
        return pd.DataFrame(message_dicts)
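
    # A minimal usage sketch (the "length_word_count" column is an assumption;
    # the actual metric columns depend on which analyzers were configured):
    #
    #   df = result.to_dataframe()
    #   short = df[df["length_word_count"] < 10]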


class DatasetAnalyzer:
    """Orchestrates dataset analysis by creating and managing sample analyzers."""

    def __init__(self, config: AnalyzeConfig):
        """Initialize the dataset analyzer with configuration.

        Args:
            config: AnalyzeConfig object containing all analysis parameters
        """
        self.config = config
        self.dataset_name = config.dataset_name
        self.split = config.split
        self.tokenizer = config.tokenizer
        self.dataset = load_dataset_from_config(config)
        self.sample_analyzers = self._initialize_sample_analyzers()

        # Initialize analysis results as None
        self._analysis_results: Optional[DatasetAnalysisResult] = None
        self._analysis_df: Optional[pd.DataFrame] = None

    def _initialize_sample_analyzers(self):
        """Initialize sample analyzer plugins from configuration."""
        sample_analyzers = {}
        for analyzer_params in self.config.analyzers:
            try:
                # Get the analyzer class from the registry
                analyzer_class = REGISTRY.get_sample_analyzer(analyzer_params.id)
                if analyzer_class is None:
                    raise ValueError(
                        f"Sample analyzer '{analyzer_params.id}' not found in registry"
                    )

                # Prepare parameters for the analyzer constructor
                analyzer_kwargs = dict(analyzer_params.params)
                if self.tokenizer is not None:
                    analyzer_kwargs["tokenizer"] = self.tokenizer

                # Create the analyzer instance with keyword arguments
                sample_analyzer = analyzer_class(**analyzer_kwargs)
                sample_analyzers[analyzer_params.id] = sample_analyzer
                logger.info(f"Initialized sample analyzer: {analyzer_params.id}")
            except Exception as e:
                logger.error(
                    f"Failed to initialize sample analyzer {analyzer_params.id}: {e}"
                )
                logger.error(f"Analyzer configuration: {analyzer_params}")
        return sample_analyzers
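
    # A minimal sketch of the sample-analyzer contract this class relies on.
    # The class name and metric below are hypothetical; only the
    # analyze_message(text, tokenizer) -> dict call and the optional
    # tokenizer constructor keyword are used by the code above.
    #
    #   class WordCountAnalyzer:
    #       def __init__(self, tokenizer=None):
    #           self.tokenizer = tokenizer
    #
    #       def analyze_message(self, text_content, tokenizer=None):
    #           return {"word_count": len(text_content.split())}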

    def analyze_dataset(self) -> None:
        """Analyze the dataset and store results internally.

        This method performs sample-level analysis using the configured sample
        analyzers. Each sample analyzer processes individual messages and
        returns metrics for each message. Results are stored internally and
        can be accessed via the query() method.

        Raises:
            ValueError: If no analyzers are configured for analysis.
        """
        if not self.sample_analyzers:
            raise ValueError(
                "No analyzers configured for analysis. Please add at least one "
                "analyzer to the configuration before calling analyze_dataset()."
            )

        logger.info(f"Starting analysis of dataset: {self.dataset_name}")
        logger.info(
            f"Using {len(self.sample_analyzers)} sample analyzers: "
            f"{list(self.sample_analyzers.keys())}"
        )

        total_conversations = len(self.dataset)
        conversations_to_analyze = min(
            total_conversations, self.config.sample_count or total_conversations
        )
        logger.info(f"Analyzing {conversations_to_analyze} conversations")

        # Step 1: Per-message level analysis
        logger.info("Step 1: Computing message metrics...")
        self._compute_message_metrics()

    @property
    def analysis_results(self) -> Optional[DatasetAnalysisResult]:
        """Get the analysis results if available.

        Returns:
            DatasetAnalysisResult if analysis has been run, None otherwise
        """
        return self._analysis_results

    def _compute_message_metrics(self) -> None:
        """Compute metrics for all messages in the dataset.

        Results are stored in self._analysis_results.
        """
        total_conversations = len(self.dataset)

        # Apply the conversation limit if specified
        max_conversations = self.config.sample_count
        if max_conversations is not None:
            if max_conversations <= 0:
                raise ValueError(
                    f"sample_count must be positive, got {max_conversations}. "
                    "Use None to analyze all conversations."
                )
            conversations_to_analyze = min(total_conversations, max_conversations)
            logger.info(
                f"Limiting analysis to first {max_conversations} "
                f"conversations (dataset has {total_conversations} total)"
            )
        else:
            conversations_to_analyze = total_conversations

        logger.info(
            "Analyzing %d conversations for message-level metrics",
            conversations_to_analyze,
        )

        # Collect all message analysis results
        message_results = []

        # Use tqdm for progress monitoring
        for conv_idx in tqdm(
            range(conversations_to_analyze),
            desc=f"Analyzing {self.dataset_name}",
            unit="conv",
        ):
            conversation = self.dataset.conversation(conv_idx)
            for msg_idx, message in enumerate(conversation.messages):
                message_result = self._compute_per_message_metrics(
                    message, conv_idx, msg_idx, conversation
                )
                message_results.append(message_result)

        self._analysis_results = DatasetAnalysisResult(
            # Config validation ensures dataset_name is not None.
            dataset_name=self.dataset_name or "",
            total_conversations=total_conversations,
            conversations_analyzed=conversations_to_analyze,
            total_messages=len(message_results),
            messages=message_results,
        )

        # Convert to a DataFrame and cache it as a member variable
        self._analysis_df = self._analysis_results.to_dataframe()

    def _compute_per_message_metrics(
        self, message, conv_idx: int, msg_idx: int, conversation
    ) -> MessageAnalysisResult:
        """Compute metrics for a single message.

        Args:
            message: The message object to analyze
            conv_idx: Index of the conversation in the dataset
            msg_idx: Index of the message within the conversation
            conversation: The conversation object containing the message

        Returns:
            MessageAnalysisResult: Structured result containing message metadata
            and analyzer metrics for the individual message.
        """
        # Get the text content
        if isinstance(message.content, str):
            text_content = message.content
        else:
            # For multimodal content, extract text only
            text_content = message.compute_flattened_text_content()

        # Extract basic message information
        conversation_id = conversation.conversation_id or f"conv_{conv_idx}"
        message_id = message.id or f"msg_{conv_idx}_{msg_idx}"
        role = message.role.value

        # Compute metrics using all configured analyzers
        analyzer_metrics: dict[str, Any] = {}
        for analyzer_id, analyzer in self.sample_analyzers.items():
            try:
                analyzer_metrics_raw = analyzer.analyze_message(
                    text_content, self.tokenizer
                )
                # Prefix metrics with the analyzer ID to avoid conflicts
                for key, value in analyzer_metrics_raw.items():
                    analyzer_metrics[f"{analyzer_id}_{key}"] = value
            except Exception as e:
                logger.warning(
                    f"Analyzer {analyzer_id} failed for message "
                    f"{conv_idx}_{msg_idx}: {e}"
                )

        return MessageAnalysisResult(
            conversation_id=conversation_id,
            conversation_index=conv_idx,
            message_index=msg_idx,
            role=role,
            message_id=message_id,
            text_content=text_content,
            **{MessageAnalysisResult.ANALYZER_METRICS_FIELD: analyzer_metrics},
        )
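
    # Note the prefixing scheme above: an analyzer registered as "length" that
    # returns {"word_count": 12} produces the flattened key "length_word_count",
    # which is the column name the query()/filter() examples below refer to.
    # (The "length" analyzer id here is illustrative.)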

    def query(
        self,
        query_expression: str,
    ) -> pd.DataFrame:
        """Query analysis results using a pandas query expression.

        Args:
            query_expression: Pandas query expression to filter analysis
                results. See the pandas DataFrame.query documentation for more
                information:
                https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html

        Returns:
            DataFrame with filtered analysis results

        Examples:
            # Filter for short messages
            short_messages = analyzer.query("length_word_count < 10")

            # Filter for assistant messages
            assistant_messages = analyzer.query("role == 'assistant'")

            # Filter for long user messages
            long_user = analyzer.query("role == 'user' and length_word_count > 100")
        """
        # Run the analysis if not already done
        if self._analysis_df is None:
            logger.info("Analysis not yet run, starting analysis...")
            self.analyze_dataset()

        # After analysis, _analysis_df should be populated
        assert self._analysis_df is not None

        # Apply the query filter
        try:
            filtered_df = self._analysis_df.query(query_expression)
            logger.info(f"Query '{query_expression}' returned {len(filtered_df)} rows")
        except Exception as e:
            logger.error(f"Query failed: {e}")
            raise ValueError(
                f"Invalid query expression '{query_expression}': {e}"
            ) from e

        return filtered_df

    def filter(
        self,
        query_expression: str,
    ) -> BaseMapDataset:
        """Filter the original dataset based on analysis results.

        This method uses analysis results to filter the original dataset,
        returning a new dataset object containing only the conversations that
        match the query.

        Args:
            query_expression: Pandas query expression to filter analysis results

        Returns:
            A new dataset object containing only the filtered conversations

        Examples:
            # Filter for conversations with short messages
            short_dataset = analyzer.filter("length_word_count < 10")

            # Filter for conversations with assistant messages
            assistant_dataset = analyzer.filter("role == 'assistant'")

            # Filter for conversations with long user messages
            long_user_dataset = analyzer.filter(
                "role == 'user' and length_word_count > 100"
            )
        """
        # Get filtered analysis results
        filtered_df = self.query(query_expression)

        # Get the unique conversation indices from the filtered results
        conversation_indices = filtered_df.conversation_index.unique().tolist()

        # Create a new dataset with only the filtered conversations
        filtered_dataset = self._create_filtered_dataset(conversation_indices)

        logger.info(
            f"Filtered dataset: {len(conversation_indices)} conversations "
            f"out of {len(self.dataset)} total"
        )
        return filtered_dataset

    def _create_filtered_dataset(
        self, conversation_indices: list[int]
    ) -> BaseMapDataset:
        """Create a new dataset containing only the specified conversations.

        Args:
            conversation_indices: List of conversation indices to include

        Returns:
            A new dataset object with the same format as the original
        """
        # Deep copy the original dataset to preserve all attributes and methods
        filtered_dataset = copy.deepcopy(self.dataset)

        # Filter the DataFrame to include only the specified conversations
        original_df = self.dataset.data
        filtered_dataset._data = original_df.iloc[conversation_indices].copy()

        # Update the dataset name to indicate it's filtered
        filtered_dataset.dataset_name = f"{self.dataset.dataset_name}_filtered"

        return filtered_dataset
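

# Example end-to-end usage (an illustrative sketch: the AnalyzeConfig
# construction details and the "length_word_count" metric are assumptions;
# see oumi.core.configs.AnalyzeConfig for the actual fields):
#
#   config = AnalyzeConfig(
#       dataset_name="my_dataset",
#       split="train",
#       sample_count=100,
#       analyzers=[...],  # one params entry per registered sample analyzer id
#   )
#   analyzer = DatasetAnalyzer(config)
#   analyzer.analyze_dataset()
#   long_user = analyzer.query("role == 'user' and length_word_count > 100")
#   short_dataset = analyzer.filter("length_word_count < 10")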