# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import typer
from oumi.cli.cli_utils import CONSOLE, CONTEXT_ALLOW_EXTRA_ARGS
from oumi.cli.distributed_run import accelerate, torchrun
from oumi.cli.env import env
from oumi.cli.evaluate import evaluate
from oumi.cli.fetch import fetch
from oumi.cli.infer import infer
from oumi.cli.judge import conversations, dataset, model
from oumi.cli.judge_v2 import judge_file
from oumi.cli.launch import cancel, down, status, stop, up, which
from oumi.cli.launch import run as launcher_run
from oumi.cli.train import train
# ASCII-art "OUMI" banner that `_oumi_welcome` prints (in green) before CLI commands.
# A raw string is used so the backslashes in the art are kept literally.
# NOTE(review): the whitespace on these lines (leading indentation and runs of
# spaces) looks collapsed, which would misalign the letters — verify against the
# upstream logo before relying on this text.
_ASCII_LOGO = r"""
____ _ _ __ __ _____
/ __ \| | | | \/ |_ _|
| | | | | | | \ / | | |
| | | | | | | |\/| | | |
| |__| | |__| | | | |_| |_
\____/ \____/|_| |_|_____|
"""
def experimental_judge_v2_enabled() -> bool:
    """Check if the experimental judge v2 feature is enabled.

    The feature is controlled by the ``OUMI_EXPERIMENTAL_JUDGE_V2`` environment
    variable. It counts as enabled when the variable's value is one of ``1``,
    ``true``, ``yes`` or ``on``. The comparison ignores case and surrounding
    whitespace. An unset variable means the feature is disabled.

    Returns:
        True if the experimental ``judge-v2`` command should be registered,
        False otherwise.
    """
    value = os.environ.get("OUMI_EXPERIMENTAL_JUDGE_V2", "False")
    # Strip so values such as " true" or "1\n" (common in shell/.env files)
    # are not silently treated as "disabled".
    return value.strip().lower() in ("1", "true", "yes", "on")
def _oumi_welcome(ctx: typer.Context):
    """Print the Oumi ASCII logo before a CLI subcommand runs.

    This is registered as the top-level Typer callback in ``get_app``. The logo
    is not printed when:

    * the invoked subcommand is ``distributed``, or
    * the ``RANK`` environment variable is a positive integer. This keeps
      multi-GPU jobs from printing it once per process.

    Args:
        ctx: The Typer context for the current invocation.
    """
    if ctx.invoked_subcommand == "distributed":
        return
    # Skip logo for rank>0 for multi-GPU jobs to reduce noise in logs.
    # A missing, empty or non-integer RANK is treated as rank 0. Otherwise a
    # stray value would crash every CLI command with a ValueError here.
    try:
        rank = int(os.environ.get("RANK", "0"))
    except ValueError:
        rank = 0
    if rank > 0:
        return
    CONSOLE.print(_ASCII_LOGO, style="green", highlight=False)
def _build_judge_app() -> typer.Typer:
    """Build the ``judge`` sub-app (conversations, dataset and model commands)."""
    judge_app = typer.Typer(pretty_exceptions_enable=False)
    judge_app.command(context_settings=CONTEXT_ALLOW_EXTRA_ARGS)(conversations)
    judge_app.command(context_settings=CONTEXT_ALLOW_EXTRA_ARGS)(dataset)
    judge_app.command(context_settings=CONTEXT_ALLOW_EXTRA_ARGS)(model)
    return judge_app


def _build_launch_app() -> typer.Typer:
    """Build the ``launch`` sub-app for managing remote jobs and clusters."""
    launch_app = typer.Typer(pretty_exceptions_enable=False)
    launch_app.command(help="Cancels a job.")(cancel)
    launch_app.command(help="Turns down a cluster.")(down)
    # The function is imported under the alias ``launcher_run``; the explicit
    # name keeps the user-facing command as ``launch run``.
    launch_app.command(
        name="run", context_settings=CONTEXT_ALLOW_EXTRA_ARGS, help="Runs a job."
    )(launcher_run)
    launch_app.command(help="Prints the status of jobs launched from Oumi.")(status)
    launch_app.command(help="Stops a cluster.")(stop)
    launch_app.command(
        context_settings=CONTEXT_ALLOW_EXTRA_ARGS, help="Launches a job."
    )(up)
    launch_app.command(help="Prints the available clouds.")(which)
    return launch_app


def _build_distributed_app() -> typer.Typer:
    """Build the ``distributed`` sub-app wrapping accelerate and torchrun."""
    distributed_app = typer.Typer(pretty_exceptions_enable=False)
    distributed_app.command(context_settings=CONTEXT_ALLOW_EXTRA_ARGS)(accelerate)
    distributed_app.command(context_settings=CONTEXT_ALLOW_EXTRA_ARGS)(torchrun)
    return distributed_app


def get_app() -> typer.Typer:
    """Create the Typer CLI app.

    Registers the following, in this order:

    * the top-level commands ``evaluate`` (plus a hidden ``eval`` alias),
      ``env``, ``infer`` and ``train``;
    * ``judge-v2``, only when ``experimental_judge_v2_enabled()`` returns True;
    * the ``judge``, ``launch`` and ``distributed`` sub-apps;
    * the ``fetch`` command.

    Returns:
        A newly constructed ``typer.Typer`` application. Each call builds a
        fresh app.
    """
    app = typer.Typer(pretty_exceptions_enable=False)
    # The top-level callback prints the welcome logo and enables -h as a
    # short form of --help.
    app.callback(context_settings={"help_option_names": ["-h", "--help"]})(
        _oumi_welcome
    )

    app.command(
        context_settings=CONTEXT_ALLOW_EXTRA_ARGS,
        help="Evaluate a model.",
    )(evaluate)
    app.command()(env)
    # "eval" is a hidden shorthand for "evaluate"; it is omitted from --help.
    app.command(
        name="eval",
        hidden=True,
        context_settings=CONTEXT_ALLOW_EXTRA_ARGS,
        help="Evaluate a model.",
    )(evaluate)
    app.command(
        context_settings=CONTEXT_ALLOW_EXTRA_ARGS,
        help="Run inference on a model.",
    )(infer)
    app.command(
        context_settings=CONTEXT_ALLOW_EXTRA_ARGS,
        help="Train a model.",
    )(train)

    # Opt-in via the OUMI_EXPERIMENTAL_JUDGE_V2 environment variable.
    if experimental_judge_v2_enabled():
        app.command(
            name="judge-v2",
            context_settings=CONTEXT_ALLOW_EXTRA_ARGS,
            help="Judge a dataset.",
        )(judge_file)

    app.add_typer(
        _build_judge_app(),
        name="judge",
        help="Judge datasets, models or conversations.",
    )
    app.add_typer(_build_launch_app(), name="launch", help="Launch jobs remotely.")
    app.add_typer(
        _build_distributed_app(),
        name="distributed",
        help=(
            "A wrapper for torchrun/accelerate "
            "with reasonable default values for distributed training."
        ),
    )

    app.command(
        help="Fetch configuration files from the oumi GitHub repository.",
    )(fetch)
    return app
def run():
    """The entrypoint for the CLI.

    Builds the Typer application with ``get_app`` and invokes it. Whatever the
    app call returns is passed through to the caller.
    """
    app = get_app()
    return app()
# Expose a module-level `app` only while a Sphinx docs build has imported this
# module. That object is what the CLI reference is generated from.
# On ordinary imports no module-level `app` is created here.
if "sphinx" in sys.modules:
    # Create the CLI app when building the docs to auto-generate the CLI reference.
    app = get_app()