Source code for oumi.cli.main

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import typer

from oumi.cli.cli_utils import CONSOLE, CONTEXT_ALLOW_EXTRA_ARGS
from oumi.cli.distributed_run import accelerate, torchrun
from oumi.cli.env import env
from oumi.cli.evaluate import evaluate
from oumi.cli.fetch import fetch
from oumi.cli.infer import infer
from oumi.cli.judge import conversations, dataset, model
from oumi.cli.judge_v2 import judge_file
from oumi.cli.launch import cancel, down, status, stop, up, which
from oumi.cli.launch import run as launcher_run
from oumi.cli.train import train

_ASCII_LOGO = r"""
   ____  _    _ __  __ _____
  / __ \| |  | |  \/  |_   _|
 | |  | | |  | | \  / | | |
 | |  | | |  | | |\/| | | |
 | |__| | |__| | |  | |_| |_
  \____/ \____/|_|  |_|_____|
"""


def experimental_judge_v2_enabled():
    """Check if the experimental judge v2 feature is enabled."""
    is_enabled = os.environ.get("OUMI_EXPERIMENTAL_JUDGE_V2", "False")
    return is_enabled.lower() in ("1", "true", "yes", "on")
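

# Usage sketch (assumption, illustrative only; not part of the module): the flag
# above is read from the environment, so setting it before building the app
# exposes the hidden "judge-v2" command registered in get_app():
#
#     os.environ["OUMI_EXPERIMENTAL_JUDGE_V2"] = "true"
#     assert experimental_judge_v2_enabled()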


def _oumi_welcome(ctx: typer.Context):
    if ctx.invoked_subcommand == "distributed":
        return
    # Skip logo for rank>0 for multi-GPU jobs to reduce noise in logs.
    if int(os.environ.get("RANK", 0)) > 0:
        return
    CONSOLE.print(_ASCII_LOGO, style="green", highlight=False)


def get_app() -> typer.Typer:
    """Create the Typer CLI app."""
    app = typer.Typer(pretty_exceptions_enable=False)
    app.callback(context_settings={"help_option_names": ["-h", "--help"]})(
        _oumi_welcome
    )
    app.command(
        context_settings=CONTEXT_ALLOW_EXTRA_ARGS,
        help="Evaluate a model.",
    )(evaluate)
    app.command()(env)
    app.command(
        # Alias for evaluate
        name="eval",
        hidden=True,
        context_settings=CONTEXT_ALLOW_EXTRA_ARGS,
        help="Evaluate a model.",
    )(evaluate)
    app.command(
        context_settings=CONTEXT_ALLOW_EXTRA_ARGS,
        help="Run inference on a model.",
    )(infer)
    app.command(
        context_settings=CONTEXT_ALLOW_EXTRA_ARGS,
        help="Train a model.",
    )(train)
    if experimental_judge_v2_enabled():
        app.command(
            name="judge-v2",
            context_settings=CONTEXT_ALLOW_EXTRA_ARGS,
            help="Judge a dataset.",
        )(judge_file)
    judge_app = typer.Typer(pretty_exceptions_enable=False)
    judge_app.command(context_settings=CONTEXT_ALLOW_EXTRA_ARGS)(conversations)
    judge_app.command(context_settings=CONTEXT_ALLOW_EXTRA_ARGS)(dataset)
    judge_app.command(context_settings=CONTEXT_ALLOW_EXTRA_ARGS)(model)
    app.add_typer(
        judge_app, name="judge", help="Judge datasets, models or conversations."
    )
    launch_app = typer.Typer(pretty_exceptions_enable=False)
    launch_app.command(help="Cancels a job.")(cancel)
    launch_app.command(help="Turns down a cluster.")(down)
    launch_app.command(
        name="run", context_settings=CONTEXT_ALLOW_EXTRA_ARGS, help="Runs a job."
    )(launcher_run)
    launch_app.command(help="Prints the status of jobs launched from Oumi.")(status)
    launch_app.command(help="Stops a cluster.")(stop)
    launch_app.command(
        context_settings=CONTEXT_ALLOW_EXTRA_ARGS, help="Launches a job."
    )(up)
    launch_app.command(help="Prints the available clouds.")(which)
    app.add_typer(launch_app, name="launch", help="Launch jobs remotely.")
    distributed_app = typer.Typer(pretty_exceptions_enable=False)
    distributed_app.command(context_settings=CONTEXT_ALLOW_EXTRA_ARGS)(accelerate)
    distributed_app.command(context_settings=CONTEXT_ALLOW_EXTRA_ARGS)(torchrun)
    app.add_typer(
        distributed_app,
        name="distributed",
        help=(
            "A wrapper for torchrun/accelerate "
            "with reasonable default values for distributed training."
        ),
    )
    app.command(
        help="Fetch configuration files from the oumi GitHub repository.",
    )(fetch)
    return app
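

# Resulting command tree (sketch, derived from the registrations above):
#
#   oumi env | evaluate (alias: eval) | infer | train | fetch
#   oumi judge {conversations, dataset, model}
#   oumi launch {cancel, down, run, status, stop, up, which}
#   oumi distributed {accelerate, torchrun}
#   oumi judge-v2        (only when OUMI_EXPERIMENTAL_JUDGE_V2 is truthy)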


def run():
    """The entrypoint for the CLI."""
    app = get_app()
    return app()
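

# Usage sketch (assumption): `run` is the function a console-script entry point
# (e.g. `oumi = "oumi.cli.main:run"` in the package metadata) would target, so
# executing the module directly is equivalent to:
#
#     if __name__ == "__main__":
#         run()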
if "sphinx" in sys.modules: # Create the CLI app when building the docs to auto-generate the CLI reference. app = get_app()