Source code for oumi.quantize.bnb_quantizer
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BitsAndBytes quantization implementation."""
import importlib.util
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing_extensions import override
from oumi.core.configs import QuantizationConfig
from oumi.quantize.base import BaseQuantization, QuantizationResult
from oumi.quantize.utils import format_size, get_directory_size
from oumi.utils.logging import logger
class BitsAndBytesQuantization(BaseQuantization):
    """BitsAndBytes quantization implementation.

    This class handles quantization using the BitsAndBytes library,
    supporting both 4-bit and 8-bit quantization methods.
    """

    supported_methods = ["bnb_4bit", "bnb_8bit"]
    supported_formats = ["safetensors"]

    def __init__(self):
        """Initialize BitsAndBytes quantizer."""
        self._bitsandbytes = importlib.util.find_spec("bitsandbytes")
    @override
    def raise_if_requirements_not_met(self) -> None:
        """Check if BitsAndBytes dependencies are available.

        Raises:
            RuntimeError: If BitsAndBytes dependencies are not available.
        """
        if self._bitsandbytes is None:
            raise RuntimeError(
                "BitsAndBytes quantization requires bitsandbytes library.\n"
                "Install with: pip install bitsandbytes"
            )

        # Import to get version info
        try:
            import bitsandbytes  # type: ignore

            logger.info(f"BitsAndBytes library found: {bitsandbytes.__version__}")
        except (ImportError, AttributeError):
            logger.info("BitsAndBytes library found (version unknown)")
    @override
    def quantize(self, config: QuantizationConfig) -> QuantizationResult:
        """Main quantization method for BitsAndBytes.

        Args:
            config: Quantization configuration

        Returns:
            QuantizationResult containing quantization results
        """
        # Validate configuration for this quantizer
        self.validate_config(config)

        # Check requirements
        self.raise_if_requirements_not_met()

        logger.info("Starting BitsAndBytes quantization pipeline...")

        # Perform quantization
        model, tokenizer = self._quantize_model(config)

        # Save model based on output format
        output_path = self._save_model(model, tokenizer, config)

        quantized_size = get_directory_size(output_path)

        logger.info("✅ BitsAndBytes quantization successful!")
        logger.info(f"📊 Quantized size: {format_size(quantized_size)}")
        logger.info(f"💡 Model saved to: {output_path}")

        return QuantizationResult(
            quantization_method=config.method,
            quantized_size_bytes=quantized_size,
            output_path=output_path,
            format_type=config.output_format,
        )
    def _quantize_model(self, config: QuantizationConfig):
        """Quantize model using BitsAndBytes."""
        logger.info(
            f"Loading model for BitsAndBytes quantization: {config.model.model_name}"
        )
        logger.info("📥 Loading base model...")

        # Configure quantization based on method
        quantization_config = self._get_quantization_config(config.method)
        logger.info(f"🔧 Using {config.method} quantization")

        # Load and quantize model
        # Note: quantized compute runs in fp16 here, so keeping the remaining
        # weights in fp32 would only waste memory; downcast them to fp16.
        torch_dtype = config.model.torch_dtype
        if torch_dtype == torch.float32:
            torch_dtype = torch.float16

        model = AutoModelForCausalLM.from_pretrained(
            config.model.model_name,
            quantization_config=quantization_config,
            device_map=config.model.device_map,
            torch_dtype=torch_dtype,
            trust_remote_code=config.model.trust_remote_code,
            **(config.model.model_kwargs or {}),
        )

        tokenizer = AutoTokenizer.from_pretrained(
            config.model.tokenizer_name or config.model.model_name,
            trust_remote_code=config.model.trust_remote_code,
            **(config.model.tokenizer_kwargs or {}),
        )

        return model, tokenizer
    def _get_quantization_config(self, method: str):
        """Get BitsAndBytes quantization config based on method."""
        from transformers import BitsAndBytesConfig

        if method == "bnb_4bit":
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        elif method == "bnb_8bit":
            return BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
            )
        else:
            raise ValueError(f"Unsupported BitsAndBytes method: {method}")
    def _save_model(self, model, tokenizer, config: QuantizationConfig) -> str:
        """Save quantized model based on output format."""
        # Ensure output directory exists
        output_path = Path(config.output_path)
        if output_path.suffix:
            # If output_path has an extension, treat parent as directory
            output_dir = output_path.parent
        else:
            # If no extension, treat as directory
            output_dir = output_path
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save based on format
        logger.info(f"Saving quantized model to: {output_dir}")
        model.save_pretrained(
            str(output_dir),
            safe_serialization=True,  # use safetensors
        )
        tokenizer.save_pretrained(str(output_dir))

        return str(output_dir)
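
For reference, a minimal usage sketch of the class above (not part of the module source). It assumes `QuantizationConfig` takes a `ModelParams`-style `model` field plus the `method` and `output_path` fields that `quantize()` reads; the exact constructor arguments may differ in the installed oumi version.

# Usage sketch under the assumptions stated above; field names are inferred
# from how quantize() and _quantize_model() read the config.
from oumi.core.configs import ModelParams, QuantizationConfig
from oumi.quantize.bnb_quantizer import BitsAndBytesQuantization

config = QuantizationConfig(
    model=ModelParams(model_name="HuggingFaceTB/SmolLM2-135M-Instruct"),
    method="bnb_4bit",            # or "bnb_8bit"
    output_path="smollm2_bnb4",   # directory for the safetensors output
)

quantizer = BitsAndBytesQuantization()
result = quantizer.quantize(config)
print(result.output_path, result.quantized_size_bytes)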