Source code for oumi.launcher.clouds.frontier_cloud

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from dataclasses import dataclass
from typing import Optional

from oumi.core.configs import JobConfig
from oumi.core.launcher import BaseCloud, BaseCluster, JobStatus
from oumi.core.registry import register_cloud_builder
from oumi.launcher.clients.slurm_client import SlurmClient
from oumi.launcher.clusters.frontier_cluster import FrontierCluster
from oumi.utils.logging import logger


@dataclass
class _ClusterInfo:
    """Dataclass to hold information about a cluster."""

    queue: str
    user: str

    @property
    def name(self):
        return f"{self.queue}.{self.user}"


_FRONTIER_HOSTNAME = "frontier.olcf.ornl.gov"


[docs] class FrontierCloud(BaseCloud): """A resource pool for managing the OLCF Frontier job queues.""" def __init__(self): """Initializes a new instance of the FrontierCloud class.""" # A mapping from user names to Frontier Clients. self._clients = collections.OrderedDict() # A mapping from cluster names to Frontier Cluster instances. self._clusters = collections.OrderedDict() # Check if any users have open SSH tunnels to Frontier. for user in SlurmClient.get_active_users(_FRONTIER_HOSTNAME): self.initialize_clusters(user) def _parse_cluster_name(self, name: str) -> _ClusterInfo: """Parses the cluster name into queue and user components. Args: name: The name of the cluster. Returns: _ClusterInfo: The parsed cluster information. """ name_splits = name.split(".") if len(name_splits) != 2: raise ValueError( f"Invalid cluster name: {name}. Must be in the format 'queue.user'." ) queue, user = name_splits return _ClusterInfo(queue, user) def _get_or_create_client(self, cluster_info: _ClusterInfo) -> SlurmClient: """Gets the client for the specified user, or creates one if it doesn't exist. Args: cluster_info: The cluster information. Returns: SlurmClient: The client instance. """ if cluster_info.user not in self._clients: self._clients[cluster_info.user] = SlurmClient( cluster_info.user, _FRONTIER_HOSTNAME, cluster_info.name ) return self._clients[cluster_info.user] def _get_or_create_cluster(self, name: str) -> FrontierCluster: """Gets the cluster with the specified name, or creates one if it doesn't exist. Args: name: The name of the cluster. Returns: FrontierCluster: The cluster instance. """ if name not in self._clusters: cluster_info = self._parse_cluster_name(name) self._clusters[name] = FrontierCluster( name, self._get_or_create_client(cluster_info) ) return self._clusters[name]
[docs] def initialize_clusters(self, user) -> list[BaseCluster]: """Initializes clusters for the specified user for all Frontier queues. Args: user: The user to initialize clusters for. Returns: List[FrontierCluster]: The list of initialized clusters. """ clusters = [] for q in sorted({q.value for q in FrontierCluster.SupportedQueues}): name = f"{q}.{user}" cluster = self._get_or_create_cluster(name) clusters.append(cluster) return clusters
[docs] def up_cluster(self, job: JobConfig, name: Optional[str], **kwargs) -> JobStatus: """Creates a cluster and starts the provided Job.""" if not job.user: raise ValueError("User must be provided in the job config.") # The default queue is BATCH. cluster_info = _ClusterInfo( FrontierCluster.SupportedQueues.BATCH.value, job.user ) if name: cluster_info = self._parse_cluster_name(name) if cluster_info.user != job.user: raise ValueError( f"Invalid cluster name: {name}. " "User must match the provided job user." ) else: logger.warning( "No cluster name provided. Using default queue: " f"{FrontierCluster.SupportedQueues.BATCH.value}." ) cluster = self._get_or_create_cluster(cluster_info.name) job_status = cluster.run_job(job) if not job_status: raise RuntimeError("Failed to start job.") return job_status
[docs] def get_cluster(self, name) -> Optional[BaseCluster]: """Gets the cluster with the specified name, or None if not found.""" clusters = self.list_clusters() for cluster in clusters: if cluster.name() == name: return cluster return None
[docs] def list_clusters(self) -> list[BaseCluster]: """Lists the active clusters on this cloud.""" return list(self._clusters.values())
@register_cloud_builder("frontier") def frontier_cloud_builder() -> FrontierCloud: """Builds a FrontierCloud instance.""" return FrontierCloud()