File size: 782 Bytes
b448ba9
 
 
53e37f7
b448ba9
 
 
 
 
 
 
 
 
 
 
 
3eacf71
 
b448ba9
 
efac340
3b1b05f
b448ba9
 
875b22c
5e2a201
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import Pipeline


class QAAssessmentPipeline(Pipeline):
    """Custom ``transformers`` pipeline that scores a piece of text with ``self.model``.

    The input is handed to the model unchanged (no tokenization happens in this
    class). The first element of the model output is reported, rounded to four
    decimal places, under the ``"ocr_quality_score"`` key.

    NOTE(review): this assumes ``self.model`` accepts the raw input directly and
    returns an indexable result whose first element is a scalar score (a Python
    number, NumPy scalar, or one-element tensor). Confirm this against the model
    the pipeline is built with.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split keyword arguments into (preprocess, forward, postprocess) parameter dicts.

        No stage of this pipeline takes extra parameters, so all three dicts are
        empty.

        A ``text=`` keyword is deliberately *not* forwarded. ``Pipeline`` calls
        ``preprocess(inputs, **preprocess_params)``, and ``inputs`` already binds
        to ``preprocess``'s own ``text`` parameter. Forwarding ``text`` therefore
        always failed with ``TypeError: got multiple values for argument 'text'``.
        Pass the text to score positionally instead, e.g. ``pipe("some text")``.
        """
        return {}, {}, {}

    def preprocess(self, text, **kwargs):
        """Return ``text`` unchanged; ``_forward`` passes it straight to the model."""
        return text

    def _forward(self, text, **kwargs):
        """Run ``self.model`` on the unprocessed input and return its raw output."""
        return self.model(text)

    def postprocess(self, outputs, **kwargs):
        """Format the model output as a JSON-serialisable dictionary.

        Args:
            outputs: Raw model output returned by ``_forward``; only ``outputs[0]``
                is used.

        Returns:
            ``{"ocr_quality_score": score}``, where ``score`` is a plain Python
            ``float`` rounded to four decimal places.
        """
        # Convert with float() before rounding. round() on a NumPy float32 (or a
        # tensor, if it supports round() at all) keeps that type, and json.dumps
        # cannot serialise it. float() yields a builtin float.
        score = float(outputs[0])
        return {"ocr_quality_score": round(score, 4)}