# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import warnings
from pathlib import Path
from typing import Optional, Union
[docs]
def get_logger(
    name: str,
    level: str = "info",
    log_dir: Optional[Union[str, Path]] = None,
) -> logging.Logger:
    """Gets a logger instance with the specified name and log level.

    The logger is configured (via ``configure_logger``) only when no real
    ``logging.Logger`` with this name exists yet. If one already exists it is
    returned unchanged, and ``level`` / ``log_dir`` are ignored; use
    ``update_logger_level`` to change the level of an existing logger.

    Args:
        name: The name of the logger.
        level (optional): The log level to set for the logger. Defaults to "info".
        log_dir (optional): Directory to store log files. Defaults to None.

    Returns:
        logging.Logger: The logger instance.
    """
    # ``loggerDict`` may also contain ``logging.PlaceHolder`` entries for names
    # that so far only exist as parents of other loggers (e.g. "oumi" after
    # ``logging.getLogger("oumi.foo")``). Such names have never been configured,
    # so only skip configuration when a real Logger is already registered.
    existing = logging.Logger.manager.loggerDict.get(name)
    if not isinstance(existing, logging.Logger):
        configure_logger(name, level=level, log_dir=log_dir)
    return logging.getLogger(name)
def _detect_rank() -> int:
"""Detects rank.
Reading the rank from the environment variables instead of
get_device_rank_info to avoid circular imports.
"""
for var_name in (
"RANK",
"SKYPILOT_NODE_RANK", # SkyPilot
"PMI_RANK", # HPC
):
rank = os.environ.get(var_name, None)
if rank is not None:
rank = int(rank)
if rank < 0:
raise ValueError(f"Negative rank: {rank} specified in '{var_name}'!")
return rank
return 0
def _should_use_rich_logging() -> bool:
"""Determines if rich logging should be used based on environment variables.
Note: Rich logging is experimental, and may be removed in the future.
Currently it is disabled by default.
"""
# Check if explicitly disabled
if os.environ.get("OUMI_ENABLE_RICH_LOGGING", "").lower() in (
"1",
"yes",
"on",
"true",
"y",
):
return sys.stdout.isatty() # is in a terminal
return False
def _configure_rich_handler(
device_rank: int,
level: str,
) -> logging.Handler:
"""Configures a rich logging handler."""
try:
from rich.console import Console
from rich.logging import RichHandler
from rich.traceback import install
except ImportError:
raise ImportError(
"Rich logging is not installed. Please install it with `pip install rich`."
)
use_detailed_logging = level.upper() == "DEBUG"
if use_detailed_logging:
# Add extra logging for debugging
install(show_locals=True, suppress=[])
console = Console()
console_handler = RichHandler(
console=console,
show_time=True,
show_level=True,
show_path=True,
enable_link_path=True,
markup=False,
rich_tracebacks=use_detailed_logging,
tracebacks_show_locals=use_detailed_logging,
locals_max_length=20,
locals_max_string=80,
)
if use_detailed_logging:
rich_formatter = logging.Formatter(
f"[rank-{device_rank}][pid-%(process)d][%(threadName)s] %(message)s"
)
else:
rich_formatter = logging.Formatter(f"[rank-{device_rank}] %(message)s")
console_handler.setFormatter(rich_formatter)
return console_handler
[docs]
def update_logger_level(name: str, level: str = "info") -> None:
    """Sets the level of a named logger and of every handler attached to it.

    The logger is obtained through ``get_logger``, so a logger that does not
    exist yet is created and configured first.

    Args:
        name (str): Name of the logger to update.
        level (str, optional): The log level to apply (case-insensitive).
            Defaults to "info".
    """
    target = get_logger(name, level=level)
    normalized_level = level.upper()
    target.setLevel(normalized_level)
    for log_handler in target.handlers:
        log_handler.setLevel(normalized_level)
# Default package-wide logger named "oumi". It is obtained via get_logger when
# this module is imported, so it gets configured at that point if no "oumi"
# logger has been registered yet.
logger = get_logger("oumi")