bge-m3-onnx-int8 / export_onnx_int8.py
import argparse
import copy
import logging
import os
from collections import OrderedDict
import torch
from huggingface_hub import snapshot_download
from optimum.exporters.onnx import onnx_export_from_model
from optimum.exporters.onnx.model_configs import XLMRobertaOnnxConfig
from optimum.exporters.tasks import TasksManager
from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from torch import Tensor, nn
from transformers import AutoConfig, AutoModel
logger = logging.getLogger(__name__)
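
# Wraps BAAI/bge-m3 together with its extra colbert and sparse heads so that a single
# forward pass returns dense, sparse and ColBERT embeddings, making the full model
# exportable as one ONNX graph.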
class BGEM3InferenceModel(nn.Module):
def __init__(
self,
model_name: str = "BAAI/bge-m3",
colbert_dim: int = -1,
) -> None:
super().__init__()
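        # Download only the files needed for export: base weights, the two head weights, and the config.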
model_name = snapshot_download(
repo_id=model_name,
allow_patterns=[
"model.safetensors",
"colbert_linear.pt",
"sparse_linear.pt",
"config.json",
],
)
self.config = AutoConfig.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)
self.colbert_linear = torch.nn.Linear(
in_features=self.model.config.hidden_size,
out_features=(
self.model.config.hidden_size if colbert_dim == -1 else colbert_dim
),
)
self.sparse_linear = torch.nn.Linear(
in_features=self.model.config.hidden_size, out_features=1
)
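        # Load the pretrained weights for the colbert and sparse heads shipped with bge-m3.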
colbert_state_dict = torch.load(
os.path.join(model_name, "colbert_linear.pt"), map_location="cpu"
)
sparse_state_dict = torch.load(
os.path.join(model_name, "sparse_linear.pt"), map_location="cpu"
)
self.colbert_linear.load_state_dict(colbert_state_dict)
self.sparse_linear.load_state_dict(sparse_state_dict)
def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
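        # The dense embedding is the final hidden state of the [CLS] (first) token.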
return last_hidden_state[:, 0]
def sparse_embedding(self, last_hidden_state: Tensor) -> Tensor:
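        # Per-token sparse weights: ReLU over a 1-dim linear projection of each token's hidden state.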
with torch.no_grad():
return torch.relu(self.sparse_linear(last_hidden_state))
def colbert_embedding(
self, last_hidden_state: Tensor, attention_mask: Tensor
) -> Tensor:
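        # ColBERT vectors skip the [CLS] token and are zeroed out at padding positions.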
with torch.no_grad():
colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
colbert_vecs = colbert_vecs * attention_mask[:, 1:][:, :, None].float()
return colbert_vecs
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> dict[str, Tensor]:
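        # Run the encoder once; dense, sparse and ColBERT outputs are all derived from its hidden states.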
with torch.no_grad():
last_hidden_state = self.model(
input_ids=input_ids, attention_mask=attention_mask, return_dict=True
).last_hidden_state
output = {}
dense_vecs = self.dense_embedding(last_hidden_state)
output["dense_vecs"] = torch.nn.functional.normalize(dense_vecs, dim=-1)
sparse_vecs = self.sparse_embedding(last_hidden_state)
output["sparse_vecs"] = sparse_vecs
colbert_vecs = self.colbert_embedding(last_hidden_state, attention_mask)
output["colbert_vecs"] = torch.nn.functional.normalize(colbert_vecs, dim=-1)
return output
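
# ONNX config describing the three output tensors and their dynamic axes; the input
# axes (input_ids, attention_mask) are inherited from XLMRobertaOnnxConfig.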
class BGEM3OnnxConfig(XLMRobertaOnnxConfig):
@property
def outputs(self) -> dict[str, dict[int, str]]:
"""
Dict containing the axis definition of the output tensors to provide to the model.
Returns:
`Dict[str, Dict[int, str]]`: A mapping of each output name to a mapping of axis position to the axes symbolic name.
"""
return copy.deepcopy(
OrderedDict(
{
"dense_vecs": {0: "batch_size", 1: "embedding"},
"sparse_vecs": {0: "batch_size", 1: "token", 2: "weight"},
"colbert_vecs": {0: "batch_size", 1: "token", 2: "embedding"},
}
)
)
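
# Export the wrapped model to ONNX, then apply dynamic int8 quantization with ONNX Runtime.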
def main(
    output: str,
    opset: int | None,
    device: str,
    optimize: str | None,
    atol: float | None,
) -> None:
model = BGEM3InferenceModel()
bgem3_onnx_config = BGEM3OnnxConfig(model.config)
# Export to ONNX first
print("Exporting to ONNX...")
# Monkey-patch the library inference to return 'transformers'
original_infer = TasksManager.infer_library_from_model
TasksManager.infer_library_from_model = lambda model: "transformers"
try:
onnx_export_from_model(
model, # Use the full custom model
output=output,
task="feature-extraction",
custom_onnx_configs={"model": bgem3_onnx_config},
opset=opset,
optimize=optimize,
atol=atol,
device=device,
)
finally:
# Restore original function
TasksManager.infer_library_from_model = original_infer
print(f"ONNX model saved to: {output}")
# Apply quantization
print("Quantizing model...")
quantizer = ORTQuantizer.from_pretrained(output)
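    # avx512_vnni targets CPUs with AVX-512 VNNI instructions; is_static=False selects
    # dynamic quantization, so no calibration dataset is required.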
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
print("Applying dynamic int8 quantization...")
quantized_path = f"{output}_int8"
quantizer.quantize(
save_dir=quantized_path,
quantization_config=qconfig
)
print(f"Quantized model saved to: {quantized_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--output",
type=str,
default="onnx_model",
help="Path indicating the directory where to store the generated ONNX model.",
)
parser.add_argument(
"--opset",
type=int,
default=None,
help="If specified, ONNX opset version to export the model with. Otherwise, the default opset for the given model architecture will be used.",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help='The device to use to do the export. Defaults to "cpu".',
)
parser.add_argument(
"--optimize",
type=str,
default=None,
choices=["O1", "O2", "O3", "O4"],
help=(
"Allows to run ONNX Runtime optimizations directly during the export. Some of these optimizations are specific to ONNX Runtime, and the resulting ONNX will not be usable with other runtime as OpenVINO or TensorRT. Possible options:\n"
" - O1: Basic general optimizations\n"
" - O2: Basic and extended general optimizations, transformers-specific fusions\n"
" - O3: Same as O2 with GELU approximation\n"
" - O4: Same as O3 with mixed precision (fp16, GPU-only, requires `--device cuda`)"
),
)
parser.add_argument(
"--atol",
type=float,
default=None,
help="If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used.",
)
args = parser.parse_args()
main(args.output, args.opset, args.device, args.optimize, args.atol)
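
# Example end-to-end usage (a minimal sketch; the quantized file name below follows the
# ORTQuantizer default suffix and may differ in practice):
#
#   python export_onnx_int8.py --output onnx_model --device cpu
#
#   import onnxruntime as ort
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
#   session = ort.InferenceSession("onnx_model_int8/model_quantized.onnx")
#   encoded = tokenizer("hello world", return_tensors="np")
#   dense_vecs, sparse_vecs, colbert_vecs = session.run(
#       None,
#       {"input_ids": encoded["input_ids"], "attention_mask": encoded["attention_mask"]},
#   )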