LVM-Med / onnx_model /torch2onnx.py
duynhm's picture
Initial commit
be2715b
raw
history blame
1.55 kB
import torch.onnx
import torch
from torch import nn
from torch.nn import functional as F
import onnx
from transformers import AutoModel
def export_onnx(example_input: torch.Tensor,
                model,
                onnx_model_name,
                *,
                export_params: bool = True,
                opset_version: int = 10) -> None:
    """Export ``model`` to an ONNX file by tracing it with ``example_input``.

    The exported graph has one named input, ``'input'``, and its first output
    is named ``'output'``. Dimension 0 of both is declared dynamic
    (``'batch_size'``), so the ONNX model accepts any batch size. All other
    dimensions of ``example_input`` are fixed in the exported graph.

    Args:
        example_input: Tensor fed to ``model`` while tracing the export.
        model: The ``torch.nn.Module`` to export.
        onnx_model_name: Destination path (or file-like object) for the
            ``.onnx`` file. A path's parent directory must already exist.
        export_params: If ``True`` (default), the trained weights are stored
            inside the ONNX file. If ``False``, the weights are NOT stored
            and become extra graph inputs. That is only useful for exporting
            an untrained architecture.
        opset_version: ONNX opset to target. Defaults to 10, the value this
            script has always used.
    """
    torch.onnx.export(
        model,
        example_input,
        onnx_model_name,
        # Embed the pretrained weights. The former hard-coded False produced a
        # weight-less graph whose parameters had to be fed at inference time.
        export_params=export_params,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        # NOTE(review): if the model returns several tensors (e.g. a
        # transformers ModelOutput), only the FIRST one is named 'output' and
        # gets the dynamic batch axis below. Confirm that is the intended one.
        output_names=['output'],
        # Only the batch dimension (axis 0) is dynamic.
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    )
def _export_pretrained(repo_id, input_shape, label, onnx_model_name) -> None:
    """Load a pretrained LVM-Med checkpoint, sanity-check it, and export it.

    Args:
        repo_id: Model id passed to ``AutoModel.from_pretrained``.
        input_shape: Shape of the all-ones example input, e.g.
            ``(1, 3, 1024, 1024)``. It is used both for the sanity-check
            forward pass and for tracing the ONNX export.
        label: Human-readable model variant shown in the printed message.
        onnx_model_name: Destination path of the ``.onnx`` file.
    """
    import os  # local import keeps the script's top-level imports unchanged

    example_input = torch.ones(*input_shape)
    model = AutoModel.from_pretrained(repo_id)

    # Sanity-check forward pass. Gradients are never used, so skip building
    # the autograd graph.
    with torch.no_grad():
        example_output = model(example_input)['pooler_output']
    print(f"Example output for LVM-Med ({label})'s shape: {example_output.shape}")

    # Make sure the destination directory exists before the export writes to it.
    out_dir = os.path.dirname(onnx_model_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    export_onnx(example_input, model, onnx_model_name=onnx_model_name)


if __name__ == "__main__":
    # Export LVM-Med (RN50 version)
    _export_pretrained('ngctnnnn/lvmmed_rn50', (1, 3, 1024, 1024), 'RN50',
                       "onnx_model/lvmmed_rn50.onnx")
    # Export LVM-Med (ViT)
    _export_pretrained('ngctnnnn/lvmmed_vit', (1, 3, 224, 224), 'ViT',
                       "onnx_model/lvmmed_vit.onnx")