import os

import torch.onnx
import torch
from torch import nn
from torch.nn import functional as F
import onnx
from transformers import AutoModel


def export_onnx(example_input: torch.Tensor, model: nn.Module, onnx_model_name) -> None:
    """Export ``model`` to ONNX by tracing it with ``example_input``.

    The exported graph has a single named input (``'input'``) and a named
    output (``'output'``); axis 0 of both is declared dynamic under the
    name ``'batch_size'``.

    Args:
        example_input: Tensor passed to ``torch.onnx.export`` as the
            example model argument.
        model: The module to export.
        onnx_model_name: Destination handed to ``torch.onnx.export``. When
            it is a filesystem path (``str`` or ``os.PathLike``) its parent
            directory is created if missing; any other object (e.g. a
            file-like buffer) is passed through untouched.
    """
    # Make sure the target directory exists so the export can write the file.
    if isinstance(onnx_model_name, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(onnx_model_name))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.onnx.export(
        model,
        example_input,
        onnx_model_name,
        # NOTE(review): export_params=False is kept from the original. For
        # pretrained checkpoints one would normally expect the weights to be
        # stored in the file (export_params=True) -- confirm this is intended.
        export_params=False,
        opset_version=10,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    )


def _export_pretrained(repo_id: str, input_shape, label: str, onnx_path: str) -> None:
    """Load a pretrained model, print its pooled output shape, export it.

    Args:
        repo_id: Identifier passed to ``AutoModel.from_pretrained``.
        input_shape: Shape of the all-ones example input tensor.
        label: Variant name shown in the printed message (e.g. ``"RN50"``).
        onnx_path: Destination path of the exported ONNX file.
    """
    example_input = torch.ones(*input_shape)
    model = AutoModel.from_pretrained(repo_id)

    # The forward pass is only a sanity check of the output shape, so there
    # is no reason to record an autograd graph for it.
    with torch.no_grad():
        example_output = model(example_input)['pooler_output']
    print(f"Example output for LVM-Med ({label})'s shape: {example_output.shape}")

    export_onnx(example_input, model, onnx_model_name=onnx_path)


if __name__ == "__main__":
    # Export LVM-Med (RN50 version)
    _export_pretrained(
        'ngctnnnn/lvmmed_rn50',
        (1, 3, 1024, 1024),
        "RN50",
        "onnx_model/lvmmed_rn50.onnx",
    )

    # Export LVM-Med (ViT)
    _export_pretrained(
        'ngctnnnn/lvmmed_vit',
        (1, 3, 224, 224),
        "ViT",
        "onnx_model/lvmmed_vit.onnx",
    )