Llama-4-Scout-17B-16E with MXFP4 weights and activations
MXFP4 Quantization
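MXFP4 is a microscaling (MX) format: each element is stored as 4-bit FP4 (E2M1) and every group of 32 elements shares one power-of-two (E8M0) scale, which is what the group_size=32 and scale_format="e8m0" arguments in the script below select. As a rough illustration of the idea only (not Quark's actual kernels or its "even" scale-calculation mode), a fake-quantization sketch for a single group might look like this:

import torch

# Representable magnitudes of FP4 E2M1: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
FP4_E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quant_mxfp4_group(block: torch.Tensor) -> torch.Tensor:
    """Fake-quantize one group of 32 values to MXFP4 (illustrative only)."""
    assert block.numel() == 32
    amax = block.abs().max().clamp(min=1e-12)
    # Shared E8M0 scale: a power of two chosen here so the largest magnitude
    # lands near the top of the E2M1 range (6.0).
    scale = torch.exp2(torch.floor(torch.log2(amax / 6.0)))
    scaled = block / scale
    # Round each magnitude to the nearest representable E2M1 value, keep the sign.
    idx = (scaled.abs().unsqueeze(-1) - FP4_E2M1).abs().argmin(dim=-1)
    return torch.sign(scaled) * FP4_E2M1[idx] * scale

x = torch.randn(32)
print((x - fake_quant_mxfp4_group(x)).abs().max())  # per-group quantization error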
This model was quantized with AMD Quark using the following script:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, Llama4ForConditionalGeneration

from quark.torch import ModelQuantizer, ModelExporter
from quark.torch.export import ExporterConfig, JsonExporterConfig
from quark.torch.quantization import Config, QuantizationConfig, FP8E4M3PerTensorSpec

# Local helper that builds the global QuantizationConfig for the chosen scheme.
from customized_configuration import get_global_config

model_dir = MODEL_DIR      # path to the meta-llama/Llama-4-Scout-17B-16E checkpoint
export_path = EXPORT_DIR   # directory for the exported quantized model
quant_scheme = 'w_mxfp4_a_mxfp4_sym'
BATCH_SIZE = 1
NUM_CALIBRATION_DATA = 128
MAX_SEQ_LEN = 512

model = Llama4ForConditionalGeneration.from_pretrained(model_dir,
                                                       torch_dtype=torch.bfloat16,
                                                       device_map="auto")
model.eval()
print(f"device_map: {model.hf_device_map}")
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Load the dataset and get calibration data.
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
text_data = dataset["text"][:NUM_CALIBRATION_DATA]
tokenized_outputs = tokenizer(text_data,
                              return_tensors="pt",
                              padding=True,
                              truncation=True,
                              max_length=MAX_SEQ_LEN).to(model.device)
calib_dataloader = DataLoader(tokenized_outputs['input_ids'],
                              batch_size=BATCH_SIZE,
                              drop_last=True)

# Global MXFP4 config: symmetric weights and activations, group size 32,
# E8M0 (power-of-two) shared scales.
global_quant_config = get_global_config(quant_scheme,
                                        group_size=32,
                                        scale_format="e8m0",
                                        scale_calculation_mode="even")

# vLLM hardcodes quant_config=None for multi_modal_projector, vision_model and
# routers, so keep those modules unquantized.
exclude_layers = []
for name, module in model.named_modules():
    if name.endswith(".router"):
        exclude_layers.append(name)
    elif name.startswith("multi_modal_projector.") or name.startswith("vision_model."):
        exclude_layers.append(name)

# Quantization config for the kv-cache layers: their output tensors use the
# FP8 (E4M3) per-tensor spec.
KV_CACHE_SPEC = FP8E4M3PerTensorSpec(observer_method="min_max",
                                     is_dynamic=False).to_quantization_spec()
kv_cache_layer_names_for_llama = ["*k_proj", "*v_proj"]
kv_cache_quant_config = {
    name: QuantizationConfig(input_tensors=global_quant_config.input_tensors,
                             weight=global_quant_config.weight,
                             output_tensors=KV_CACHE_SPEC)
    for name in kv_cache_layer_names_for_llama
}
layer_quant_config = kv_cache_quant_config.copy()

quant_config = Config(
    global_quant_config=global_quant_config,
    layer_quant_config=layer_quant_config,
    exclude=exclude_layers + ["language_model.lm_head"],
)
print(quant_config)

quantizer = ModelQuantizer(quant_config, multi_device=True)
qmodel = quantizer.quantize_model(model, calib_dataloader)

# Freeze the quantized model for export.
freezed_model = quantizer.freeze(qmodel)

# Define the export config; record k_proj/v_proj as the kv-cache group.
export_config = ExporterConfig(json_export_config=JsonExporterConfig())
export_config.json_export_config.kv_cache_group = kv_cache_layer_names_for_llama

exporter = ModelExporter(config=export_config, export_dir=export_path)
with torch.no_grad():
    exporter.export_safetensors_model(freezed_model,
                                      quant_config=quant_config,
                                      tokenizer=AutoTokenizer.from_pretrained(model_dir))
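The exported checkpoint is intended to be served with vLLM, which (as noted in the script) keeps the routers, vision model and multi-modal projector unquantized. A minimal offline-inference sketch using vLLM's Python API, assuming a vLLM build that supports Quark MXFP4 checkpoints and an FP8 kv-cache on your hardware; the parallelism and sequence-length settings are illustrative placeholders, not part of the original script:

from vllm import LLM, SamplingParams

llm = LLM(model="hann-wang/Llama-4-Scout-17B-16E-WMXFP4-AMXFP4-KVFP8",
          kv_cache_dtype="fp8",     # k_proj/v_proj outputs were calibrated for an FP8 kv-cache
          tensor_parallel_size=8,   # adjust to the number of available GPUs
          max_model_len=8192)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
print(outputs[0].outputs[0].text)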
Model tree for hann-wang/Llama-4-Scout-17B-16E-WMXFP4-AMXFP4-KVFP8
Base model: meta-llama/Llama-4-Scout-17B-16E