Source code for oumi.core.synthesis.attribute_formatter

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from oumi.core.configs.params.synthesis_params import GeneralSynthesisParams
from oumi.utils.placeholders import resolve_placeholders


class _AttributeValueInfo:
    """Information about a value of a permutable attribute.

    Used to format the string for a sample.
    """

    def __init__(self, value_name: str, value_description: str):
        """Initialize the attribute value info."""
        self._value_name = value_name
        self.description = value_description

    def __str__(self) -> str:
        return self._value_name


class _AttributeInfo:
    """Information about a permutable attribute.

    Used to format the string for a sample.
    """

    def __init__(
        self,
        attribute_id: str,
        attribute_name: str,
        attribute_description: str,
        value_name: str,
        value_description: str,
    ):
        """Initialize the attribute value info."""
        self.attribute_id = attribute_id
        self._attribute_name = attribute_name
        self.description = attribute_description
        self.value = _AttributeValueInfo(value_name, value_description)

    def __str__(self) -> str:
        return self._attribute_name


[docs] class AttributeFormatter: """Formats a sample using a format string. Integrates information from permutable attributes to support formatting of placeholders in the format string (i.e. {attribute_id.value}). """ def __init__(self, params: GeneralSynthesisParams): """Initialize the formatter.""" self._params = params self._permutable_attribute_map = ( {perm_attr.id: perm_attr for perm_attr in params.permutable_attributes} if params.permutable_attributes else {} ) self._permutable_attribute_info = {} # Pre-compute the attribute info for each possible value for attribute_id, attribute in self._permutable_attribute_map.items(): for value in attribute.possible_values: key = (attribute_id, value.id) self._permutable_attribute_info[key] = _AttributeInfo( attribute_id=attribute_id, attribute_name=attribute.attribute, attribute_description=attribute.description, value_name=value.value, value_description=value.description, )
[docs] def format( self, sample: dict[str, str], format_string: str, missing_values_allowed: bool = False, ) -> str: """Format a sample using a format string. Args: sample: The sample to format. format_string: The format string to use. missing_values_allowed: If True, missing values are allowed in the sample. Returns: The formatted string. """ attr_values = {} for attribute_id, attribute_value in sample.items(): if self._is_permutable_attribute(attribute_id): value_id = attribute_value attr_values[attribute_id] = self._get_permutable_attribute_value_info( attribute_id, value_id ) else: attr_values[attribute_id] = attribute_value formatted_string = resolve_placeholders( format_string, attr_values, missing_values_allowed=missing_values_allowed, ) return formatted_string
def _is_permutable_attribute(self, attribute_id: str) -> bool: """Check if the attribute is a permutable attribute.""" return attribute_id in self._permutable_attribute_map def _get_permutable_attribute_value_info( self, attribute_id: str, attribute_value_id: str ) -> _AttributeInfo: """Get the string representation information for a permutable attribute.""" key = (attribute_id, attribute_value_id) if key in self._permutable_attribute_info: return self._permutable_attribute_info[key] raise ValueError( f"Attribute value {attribute_value_id} not found for " f"attribute {attribute_id}" )