# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AWQ (Activation-aware Weight Quantization) quantizer implementation."""
import importlib
import importlib.util
import torch
from typing_extensions import override
from oumi.core.configs import QuantizationConfig
from oumi.quantize.base import BaseQuantization, QuantizationResult
from oumi.quantize.utils import format_size, get_directory_size
from oumi.utils.logging import logger
# AWQ configuration defaults
AWQ_DEFAULTS = {
"calibration_dataset": "pileval",
"calibration_split": "train",
"calibration_text_column": "text",
"max_calibration_seq_len": 512,
"duo_scaling": True,
"apply_clip": True,
"n_parallel_calib_samples": None,
}
class AwqQuantization(BaseQuantization):
"""AWQ (Activation-aware Weight Quantization) implementation.
This class handles AWQ quantization with support for simulation mode
when AWQ libraries are not available.
"""
supported_methods = ["awq_q4_0", "awq_q4_1", "awq_q8_0", "awq_f16"]
supported_formats = ["safetensors"]
def __init__(self):
"""Initialize AWQ quantizer."""
if importlib.util.find_spec("awq") is not None:
self._awq = importlib.import_module("awq")
else:
self._awq = None
@override
def raise_if_requirements_not_met(self):
"""Check if AWQ dependencies are available."""
if self._awq is None:
raise RuntimeError(
"AWQ quantization requires autoawq library.\n"
"Install with: `pip install oumi[quantization]`\n"
)
if not torch.cuda.is_available():
raise RuntimeError(
"AWQ quantization requires a GPU. "
"Please use a machine with at least 1 GPU."
)
@override
def quantize(self, config: QuantizationConfig) -> QuantizationResult:
"""Main quantization method for AWQ.
Args:
config: Quantization configuration
Returns:
Dictionary containing quantization results
"""
self.validate_config(config)
logger.info("Starting AWQ quantization pipeline...")
# Step 1: AWQ quantization
model, tokenizer = self._quantize(config)
        # Step 2: Save the quantized model and tokenizer
        logger.info("Saving AWQ quantized model...")
model.save_quantized(config.output_path)
tokenizer.save_pretrained(config.output_path)
awq_size = get_directory_size(config.output_path)
logger.info("✅ AWQ quantization successful! Saved as PyTorch format.")
logger.info(f"📊 Quantized size: {format_size(awq_size)}")
logger.info(
f"💡 Use this model with: "
f"AutoAWQForCausalLM.from_quantized('{config.output_path}')"
)
quantization_result = QuantizationResult(
quantization_method=config.method,
quantized_size_bytes=awq_size,
output_path=config.output_path,
format_type=config.output_format,
)
return quantization_result
def _quantize(self, config: QuantizationConfig):
"""Quantize model using AWQ algorithm with calibration."""
from transformers import AutoTokenizer
logger.info(f"Loading model for AWQ quantization: {config.model.model_name}")
# 1. Load model and tokenizer
logger.info("📥 Loading base model...")
model_kwargs = {
"safetensors": True,
"trust_remote_code": config.model.trust_remote_code,
**(config.model.model_kwargs or {}),
}
model = self._awq.AutoAWQForCausalLM.from_pretrained( # type: ignore
config.model.model_name, **model_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(
config.model.tokenizer_name or config.model.model_name,
trust_remote_code=config.model.trust_remote_code,
**(config.model.tokenizer_kwargs or {}),
)
logger.info("🔧 Configuring AWQ quantization parameters...")
# 2. Prepare quantization config
w_bit_dict = {
"awq_q4_0": 4,
"awq_q4_1": 4,
"awq_q8_0": 8,
"awq_f16": 16,
}
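        # The weight bit width is encoded in the method name:
        # awq_q4_* -> 4-bit, awq_q8_0 -> 8-bit, awq_f16 -> 16-bit.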
w_bit = w_bit_dict[config.method]
quant_config = {
"zero_point": config.awq_zero_point,
"q_group_size": config.awq_group_size,
"w_bit": w_bit,
"version": config.awq_version,
}
logger.info(f"⚙️ AWQ config: {quant_config}")
logger.info(f"📊 Using {config.calibration_samples} calibration samples")
logger.info("🧮 Starting AWQ calibration and quantization...")
# 3. Perform AWQ quantization with calibration
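        # autoawq runs the calibration data through the model to collect
        # activation statistics, derives per-channel scales that protect the
        # most salient weights, then packs the weights to the target bit width.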
model.quantize(
tokenizer,
quant_config=quant_config,
calib_data=AWQ_DEFAULTS["calibration_dataset"],
split=AWQ_DEFAULTS["calibration_split"],
text_column=AWQ_DEFAULTS["calibration_text_column"],
max_calib_samples=config.calibration_samples,
max_calib_seq_len=AWQ_DEFAULTS["max_calibration_seq_len"],
duo_scaling=AWQ_DEFAULTS["duo_scaling"],
apply_clip=AWQ_DEFAULTS["apply_clip"],
n_parallel_calib_samples=AWQ_DEFAULTS["n_parallel_calib_samples"],
)
return model, tokenizer
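

# Illustrative usage (a minimal sketch, not part of the module API). It assumes
# autoawq is installed, a CUDA GPU is available, and a `QuantizationConfig`
# with an AWQ method such as "awq_q4_0" has been built elsewhere:
#
#   quantizer = AwqQuantization()
#   quantizer.raise_if_requirements_not_met()
#   result = quantizer.quantize(config)
#   print(result.output_path, result.quantized_size_bytes)
#
# The output directory can then be reloaded with autoawq, as the log message
# in `quantize` suggests:
#
#   from awq import AutoAWQForCausalLM
#   model = AutoAWQForCausalLM.from_quantized(result.output_path)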