Source code for oumi.core.callbacks.bitnet_callback
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple Bitnet model saving callback."""
from pathlib import Path
from typing import Optional, Union
import transformers
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from oumi.core.callbacks.base_trainer_callback import BaseTrainerCallback
from oumi.core.configs import TrainingParams
# Import `onebitllms` utils methods
try:
import onebitllms # type: ignore
from onebitllms import quantize_to_1bit # type: ignore
except ImportError:
onebitllms = None
[docs]
class BitNetCallback(BaseTrainerCallback):
"""BitNet model saving callback.
Simple callback that saves the model into BitNet quantized
format during training.
"""
[docs]
def on_save(
self,
args: Union[transformers.TrainingArguments, TrainingParams],
state: Optional[transformers.TrainerState] = None,
control: Optional[transformers.TrainerControl] = None,
**kwargs,
):
"""Saving callback.
Gets triggered at each saving step to quantize trained models
in 1bit precision.
"""
if onebitllms is None:
raise ValueError(
"""You need `onebitllms` to be installed in order to save
correctly BitNet models - `pip install onebitllms`"""
)
output_dir = Path(args.output_dir) # type: ignore
quantized_subdir = Path(
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}-quantized" # type: ignore
)
output_subdir = Path(f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") # type: ignore
checkpoint_folder = output_dir / output_subdir
quantized_checkpoint_folder = output_dir / quantized_subdir
quantize_to_1bit(str(checkpoint_folder), str(quantized_checkpoint_folder))