llava-v1.5-7b-gpt4OCR
Part of a collection (5 items) of llava-v1.5-7b checkpoints fine-tuned for OCR tasks.
Fine-tuned version of liuhaotian/llava-v1.5-7b on OCR and object-detection data using LoRA (adapter config at bottom) to improve OCR captioning and bounding-box generation.
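LoRA keeps the pretrained weights frozen and trains only a low-rank update added to selected layers. As a reminder of what the adapter computes, here is a minimal pure-Python sketch of one adapted linear layer. The class name `LoRALinear` and the defaults `r=8`, `alpha=16.0` are hypothetical; this card's actual rank, scaling, and target modules are whatever its adapter config specifies.

```python
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class LoRALinear:
    """Frozen base weight W0 plus a trainable rank-r update (alpha / r) * B @ A."""

    def __init__(self, w0: Matrix, r: int = 8, alpha: float = 16.0, seed: int = 0):
        d_out, d_in = len(w0), len(w0[0])
        rng = random.Random(seed)
        self.w0 = w0  # pretrained weight (d_out x d_in), never updated
        # A (r x d_in) starts small and random, B (d_out x r) starts at zero,
        # so before training the adapted layer reproduces the base layer exactly.
        self.a = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]
        self.b = [[0.0] * r for _ in range(d_out)]
        self.scale = alpha / r

    def forward(self, x: Vector) -> Vector:
        """h = W0 x + (alpha / r) * B (A x)."""
        base = matvec(self.w0, x)
        update = matvec(self.b, matvec(self.a, x))
        return [h + self.scale * u for h, u in zip(base, update)]

    def merged_weight(self) -> Matrix:
        """Fold the adapter into the base weight: W0 + (alpha / r) * B @ A."""
        r = len(self.a)
        return [
            [
                w + self.scale * sum(self.b[i][k] * self.a[k][j] for k in range(r))
                for j, w in enumerate(row)
            ]
            for i, row in enumerate(self.w0)
        ]
```

Only `a` and `b` would be trained; because the update can be merged into `W0` afterwards, a LoRA fine-tune can be exported as an ordinary set of full weights.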
The two datasets used for fine-tuning are:
We use 10k samples from GRIT, keeping only samples whose image-caption CLIP similarity is greater than 0.35 and whose captions contain no proper nouns (filtered using spaCy).
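The GRIT filtering step can be sketched as follows, under several assumptions: `GritSample`, `filter_grit`, and `make_spacy_proper_noun_check` are hypothetical names, the CLIP similarity is assumed to be precomputed per sample, "proper noun" is taken to mean any token spaCy tags as `PROPN`, and the sketch keeps the first 10k passing samples because the card does not say how the 10k were chosen.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, List


@dataclass
class GritSample:
    # Hypothetical record layout; the real GRIT schema may differ.
    image_url: str
    caption: str
    clip_similarity: float  # assumed precomputed image-caption CLIP score


def filter_grit(
    samples: Iterable[GritSample],
    has_proper_noun: Callable[[str], bool],
    min_similarity: float = 0.35,
    limit: int = 10_000,
) -> List[GritSample]:
    """Keep samples strictly above the CLIP threshold whose captions have no proper nouns."""
    kept: List[GritSample] = []
    for s in samples:
        if s.clip_similarity > min_similarity and not has_proper_noun(s.caption):
            kept.append(s)
            if len(kept) >= limit:
                break
    return kept


def make_spacy_proper_noun_check(model: str = "en_core_web_sm") -> Callable[[str], bool]:
    """Build a has_proper_noun predicate from a spaCy pipeline that includes a POS tagger."""
    import spacy  # imported lazily so the rest of the sketch runs without spaCy installed

    nlp = spacy.load(model)
    return lambda text: any(tok.pos_ == "PROPN" for tok in nlp(text))
```

With spaCy and an English pipeline installed, the filter would be called as `filter_grit(samples, make_spacy_proper_noun_check())`; any other proper-noun detector with the same `str -> bool` signature can be swapped in.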
Use the code below to get started with the model:
import argparse
import base64
import json
import time
from pathlib import Path

from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler
# Placeholder paths: point them at the downloaded GGUF files
chat_handler = Llava15ChatHandler(clip_model_path="your-path-to-mmproj-model-f16.gguf")
llm = Llama(
    model_path="your-path-to-llava-v1.5-7b-ocr-pretrain.Q2_K.gguf",
    chat_handler=chat_handler,
    n_ctx=2048,  # increase n_ctx to fit the image embedding (576 tokens for LLaVA-1.5) plus the text
    logits_all=True,  # needed to make LLaVA work
    n_gpu_layers=-1,  # offload all layers to the GPU
)
def inference(url, question):
    """Run one image + question chat completion and add timing statistics."""
    start = time.perf_counter()
    output = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "You are an assistant who answers user questions"},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": url}},
                    {"type": "text", "text": question},
                ],
            },
        ],
        temperature=1.0,
    )
    elapsed = time.perf_counter() - start
    return {
        **output,
        "completion_time": elapsed,
        "tokens-per-second": output["usage"]["completion_tokens"] / elapsed,
    }
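def img_base64(path):
    """Encode a local image file as a base64 data URI."""
    # Path.suffix includes the dot (".png"), so strip it; the MIME subtype for .jpg is "jpeg"
    ext = path.suffix.lstrip(".").lower()
    if ext == "jpg":
        ext = "jpeg"
    with open(path, "rb") as f:
        data = f.read()
    return f"data:image/{ext};base64," + base64.b64encode(data).decode("utf-8")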
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--url",
        type=str,
        default="https://adquick-public.imgix.net/landing+images/media_formats/billboard-carvana.png?auto=format",
        help="URL or local path of an image for inference",
    )
    parser.add_argument("--question", "-q", type=str, default="generate a descriptive caption for this image.")
    args = parser.parse_args()

    url = args.url
    # Anything that is not already a remote URL or a data URI is treated as a local file
    if not url.startswith(("http://", "https://", "data:")):
        url = img_base64(Path(url))

    outputs = inference(url, args.question)
    print(json.dumps(outputs, indent=4))
Example output (the generated caption for the default image):
The image features an advertisement billboard with a vibrant red car being transported by trucks, set against a clear sky backdrop. Emblazoned across the top is the phrase "Buy your next CAR from your COUCH." Below, the text reads "CARVANA" on the left and "carvana" on the right side, while a smaller "Bluegrass Outdoor" logo is noted at the bottom center.
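The script prints the entire response dict. llama-cpp-python's `create_chat_completion` returns an OpenAI-style dict, so the caption itself sits at `choices[0]["message"]["content"]`. A small helper, `summarize_completion` (a hypothetical name, not part of the original script), could pull out the text together with the timing fields that `inference` adds:

```python
from typing import Any, Dict


def summarize_completion(outputs: Dict[str, Any]) -> str:
    """Return the generated text followed by the timing fields added by inference()."""
    # OpenAI-style chat completion: the reply is in the first choice's message
    text = outputs["choices"][0]["message"]["content"].strip()
    seconds = outputs["completion_time"]
    tps = outputs["tokens-per-second"]
    return f"{text}\n[{seconds:.2f} s, {tps:.1f} tokens/s]"
```

In the script, `print(summarize_completion(outputs))` would then print just the caption and a one-line timing summary instead of the full JSON.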