Multicentury-HTR-Demo / onnx_text_recognition.py
MikkoLipsanen's picture
Speed up crop_line function
37b60ea verified
raw
history blame
5.25 kB
from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import TrOCRProcessor
import numpy as np
import onnxruntime
import math
import cv2
import os
class TextRecognition:
    """Recognise text on cropped line images with a TrOCR processor and an ONNX model.

    The processor and the model are loaded once in the constructor. If either
    load fails, the error is printed and the corresponding attribute is left
    as ``None``; ``predict_text`` then raises a ``RuntimeError``.
    """

    def __init__(self,
                 processor_path,
                 model_path,
                 device='cuda:0',
                 half_precision=False,
                 line_threshold=10):
        """Store the configuration and load the processor and the model.

        Args:
            processor_path: Path or hub id given to ``TrOCRProcessor.from_pretrained``.
            model_path: Path or hub id given to ``ORTModelForVision2Seq.from_pretrained``.
            device: Device string; input tensors are moved to it before
                generation and it selects the ONNX Runtime execution provider.
            half_precision: Stored on the instance; no method of this class reads it.
            line_threshold: Maximum number of line images per ``predict_text`` call.
        """
        self.device = device
        self.half_precision = half_precision
        self.line_threshold = line_threshold
        self.processor_path = processor_path
        self.model_path = model_path
        self.processor = self.init_processor()
        self.recognition_model = self.init_recognition_model()

    def init_processor(self):
        """Load the TrOCR processor; return ``None`` (after printing) on failure."""
        try:
            return TrOCRProcessor.from_pretrained(self.processor_path, token=True)
        except Exception as e:
            # Best-effort load: report the problem and let the caller continue.
            print('Failed to initialize processor: %s' % e)
            return None

    def init_recognition_model(self):
        """Load the ONNX text recognition model; return ``None`` (after printing) on failure."""
        sess_options = onnxruntime.SessionOptions()
        sess_options.intra_op_num_threads = 3
        sess_options.inter_op_num_threads = 3
        # Keep the execution provider consistent with the requested device
        # instead of always asking for CUDA. The default 'cuda:0' still
        # selects the CUDA provider.
        if str(self.device).startswith('cuda'):
            provider = "CUDAExecutionProvider"
        else:
            provider = "CPUExecutionProvider"
        try:
            return ORTModelForVision2Seq.from_pretrained(
                self.model_path,
                token=True,
                session_options=sess_options,
                provider=provider,
            )
        except Exception as e:
            # Best-effort load: report the problem and let the caller continue.
            print('Failed to load the text recognition model: %s' % e)
            return None

    def crop_line(self, image, polygon, height, width):
        """Crop one text line and paint everything outside its polygon white.

        Args:
            image: Page image as a NumPy array (rows first, then columns).
            polygon: Sequence of ``(x, y)`` vertices, list or array.
            height: Unused; kept for interface compatibility.
            width: Unused; kept for interface compatibility.

        Returns:
            The bounding-box crop of ``image`` in which pixels outside the
            polygon mask are set to 255 and pixels inside are kept.
        """
        # Work on an integer copy so list input is accepted and the caller's
        # polygon is never modified.
        poly = np.array(polygon, dtype=np.int32).reshape(-1, 2)
        # Clip vertices to the image so that the crop and the mask always
        # have the same size, even for polygons reaching outside the image.
        img_h, img_w = image.shape[:2]
        poly = np.clip(poly, 0, [img_w - 1, img_h - 1]).astype(np.int32)

        x, y, w, h = cv2.boundingRect(poly)
        cropped_image = image[y:y + h, x:x + w]

        # Filled polygon mask in crop-local coordinates.
        mask = np.zeros((h, w), dtype=np.uint8)
        local_poly = poly - np.array([[x, y]], dtype=np.int32)
        cv2.drawContours(mask, [local_poly], -1, (255, 255, 255), -1, cv2.LINE_AA)

        # Image content where the mask is set, zero elsewhere.
        res = cv2.bitwise_and(cropped_image, cropped_image, mask=mask)
        # White canvas, zeroed where the mask is set.
        wbg = np.ones_like(cropped_image, np.uint8) * 255
        cv2.bitwise_not(wbg, wbg, mask=mask)
        # Overlay the masked crop on the white background.
        return wbg + res

    def crop_lines(self, polygons, image, height, width):
        """Return a list of line images cropped following the detected polygon coordinates."""
        return [self.crop_line(image, polygon, height, width) for polygon in polygons]

    def get_scores(self, lgscores):
        """Return ``math.exp`` of every log score, in order."""
        return [math.exp(lgscore) for lgscore in lgscores]

    def predict_text(self, cropped_lines):
        """Predict the text content of one batch of cropped line images.

        Args:
            cropped_lines: List of line images for a single batch.

        Returns:
            ``(scores, generated_text)``: exponentiated sequence scores and
            the decoded strings, one per input image. An empty input yields
            ``([], [])``.

        Raises:
            RuntimeError: If the processor or the model failed to load.
        """
        if len(cropped_lines) == 0:
            return [], []
        if self.processor is None or self.recognition_model is None:
            raise RuntimeError(
                'Text recognition is not available: the processor or the model failed to load.'
            )
        pixel_values = self.processor(cropped_lines, return_tensors="pt").pixel_values
        generated_dict = self.recognition_model.generate(
            pixel_values.to(self.device),
            max_new_tokens=128,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # NOTE(review): 'sequences_scores' is presumably only present when the
        # model's generation config uses beam search - confirm.
        generated_ids = generated_dict['sequences']
        lgscores = generated_dict['sequences_scores']
        scores = self.get_scores(lgscores.tolist())
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        return scores, generated_text

    def get_text_lines(self, cropped_lines):
        """Run recognition over all line images, at most ``line_threshold`` per batch.

        Returns:
            ``(scores, generated_text)`` in the order of ``cropped_lines``;
            ``([], [])`` when there are no lines.
        """
        total = len(cropped_lines)
        if total == 0:
            # Nothing to recognise; do not hand an empty batch to the processor.
            return [], []
        if total <= self.line_threshold:
            return self.predict_text(cropped_lines)

        scores, generated_text = [], []
        n = math.ceil(total / self.line_threshold)
        for i in range(n):
            print(i)
            start = int(i * self.line_threshold)
            end = int(min(start + self.line_threshold, total))
            sc, gt = self.predict_text(cropped_lines[start:end])
            scores += sc
            print(gt)
            generated_text += gt
        return scores, generated_text

    def get_res_dict(self, polygons, generated_text, height, width, image_name, line_confs, scores):
        """Combine the results of one image into a dictionary.

        One entry is produced per element of ``generated_text``; the other
        lists are indexed at the same position, so a shorter list raises
        ``IndexError``.
        """
        line_dicts = [
            {
                'polygon': polygons[i],
                'text': text,
                'conf': line_confs[i],
                'text_conf': scores[i],
            }
            for i, text in enumerate(generated_text)
        ]
        return {'img_name': image_name, 'height': height, 'width': width, 'text_lines': line_dicts}

    def process_lines(self, polygons, image, height, width):
        """Crop every polygon from ``image`` and return the recognised text per line.

        The scores computed alongside the text are discarded here.
        """
        # Crop line images
        print('starting text generation')
        cropped_lines = self.crop_lines(polygons, image, height, width)
        print('cropped lines')
        # Get text predictions
        scores, generated_text = self.get_text_lines(cropped_lines)
        return generated_text