Llama-4-Scout-17B-16E with MXFP4 weights and activations

MXFP4 Quantization

This model was quantized with AMD Quark using the following script:

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
from quark.torch import ModelQuantizer, ModelExporter
from quark.torch.export import ExporterConfig, JsonExporterConfig
from quark.torch.quantization import Config, QuantizationConfig, FP8E4M3PerTensorSpec
from customized_configuration import get_global_config
# Paths for the source model and the quantized export (set these before running).
model_dir = MODEL_DIR
export_path = EXPORT_DIR
# MXFP4 weights and activations, symmetric quantization.
quant_scheme = 'w_mxfp4_a_mxfp4_sym'
# Calibration settings.
BATCH_SIZE = 1
NUM_CALIBRATION_DATA = 128
MAX_SEQ_LEN = 512
# Load the original model in bfloat16, sharded across the available GPUs.
model = Llama4ForConditionalGeneration.from_pretrained(model_dir, torch_dtype=torch.bfloat16, device_map="auto")
model.eval()
print(f"device_map: {model.hf_device_map}")
tokenizer = AutoTokenizer.from_pretrained(model_dir)
# Load the dataset and get calibration data.
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
text_data = dataset["text"][:NUM_CALIBRATION_DATA]
tokenized_outputs = tokenizer(text_data,
                              return_tensors="pt",
                              padding=True,
                              truncation=True,
                              max_length=MAX_SEQ_LEN).to(model.device)
calib_dataloader = DataLoader(tokenized_outputs['input_ids'],
                              batch_size=BATCH_SIZE,
                              drop_last=True)
# Global MXFP4 config: 32-element groups with shared E8M0 scales.
global_quant_config = get_global_config(quant_scheme,
                                        group_size=32,
                                        scale_format="e8m0",
                                        scale_calculation_mode="even")
# vLLM hardcodes quant_config=None for multi_modal_projector, vision_model and
# the MoE routers, so leave those layers unquantized.
exclude_layers = []
for name, module in model.named_modules():
    if name.endswith(".router"):
        exclude_layers.append(name)
    elif name.startswith("multi_modal_projector.") or name.startswith("vision_model."):
        exclude_layers.append(name)
# KV-cache layers: keep the global MXFP4 config for inputs and weights, but
# quantize the output tensors (the KV cache) to FP8 per-tensor.
KV_CACHE_SPEC = FP8E4M3PerTensorSpec(
    observer_method="min_max", is_dynamic=False).to_quantization_spec()
kv_cache_layer_names_for_llama = ["*k_proj", "*v_proj"]
kv_cache_quant_config = {
    name:
    QuantizationConfig(input_tensors=global_quant_config.input_tensors,
                       weight=global_quant_config.weight,
                       output_tensors=KV_CACHE_SPEC)
    for name in kv_cache_layer_names_for_llama
}
layer_quant_config = kv_cache_quant_config.copy()
# Assemble the full config; the LM head is also kept unquantized.
quant_config = Config(
    global_quant_config=global_quant_config,
    layer_quant_config=layer_quant_config,
    exclude=exclude_layers + ["language_model.lm_head"],
)
print(quant_config)
# Calibrate and quantize the model, sharded across multiple devices.
quantizer = ModelQuantizer(quant_config, multi_device=True)
qmodel = quantizer.quantize_model(model, calib_dataloader)
# Freeze the quantized model for export.
freezed_model = quantizer.freeze(qmodel)
# Define export config.
export_config = ExporterConfig(json_export_config=JsonExporterConfig())
export_config.json_export_config.kv_cache_group = kv_cache_layer_names_for_llama
exporter = ModelExporter(config=export_config, export_dir=export_path)
with torch.no_grad():
    exporter.export_safetensors_model(freezed_model,
                                      quant_config=quant_config,
                                      tokenizer=AutoTokenizer.from_pretrained(model_dir))
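
The exported checkpoint is intended to be served with vLLM, which is why the routers, vision_model and multi_modal_projector are left unquantized above. The snippet below is a minimal, untested serving sketch; the kv_cache_dtype and tensor_parallel_size values are assumptions and should be adjusted for your vLLM version and hardware:

from vllm import LLM, SamplingParams
# Point at the local export directory or the published repository id.
llm = LLM(model="hann-wang/Llama-4-Scout-17B-16E-WMXFP4-AMXFP4-KVFP8",
          kv_cache_dtype="fp8",      # match the FP8 KV-cache spec used above (assumption)
          tensor_parallel_size=8)    # adjust to the number of available GPUs
outputs = llm.generate(["What is MXFP4 quantization?"],
                       SamplingParams(max_tokens=128))
print(outputs[0].outputs[0].text)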