File size: 3,881 Bytes
5fa1a76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
With that in mind, let's create a function to encode a batch of examples in the dataset:

def _locate_answer_tokens(word_ids, token_type_ids, word_idx_start, word_idx_end):
    """Map an answer's word span to a token span inside the words segment.

    The words segment is taken to be the tokens between the first and the last
    position whose ``token_type_id`` equals 1.

    Args:
        word_ids: per-token word index for one encoded example (entries that do
            not equal a word index never match).
        token_type_ids: per-token segment ids for the same example.
        word_idx_start: index of the first word of the answer.
        word_idx_end: index of the last word of the answer.

    Returns:
        ``(start, end)``: the first token of the start word and the last token
        of the end word, or ``None`` when either boundary word has no token in
        the segment (for example because it was truncated away).
    """
    if 1 not in token_type_ids:
        return None
    segment_first = token_type_ids.index(1)
    segment_last = len(token_type_ids) - 1 - token_type_ids[::-1].index(1)
    segment = range(segment_first, segment_last + 1)

    start = next((pos for pos in segment if word_ids[pos] == word_idx_start), None)
    end = next((pos for pos in reversed(segment) if word_ids[pos] == word_idx_end), None)
    if start is None or end is None:
        return None
    return start, end


def encode_dataset(examples, max_length=512):
    """Encode a batch of examples and label each with its answer token span.

    Args:
        examples: batch mapping with the columns ``"question"``, ``"words"``,
            ``"boxes"``, ``"answer"`` and ``"image"``; each column holds one
            entry per example.
        max_length: length every encoded sequence is padded or truncated to.

    Returns:
        The tokenizer's encoding for the batch, extended with ``"image"``
        (copied from ``examples``) and with ``"start_positions"`` /
        ``"end_positions"``: one token index per example. Both positions are
        the index of the CLS token when the answer is not found in the words,
        or when the answer is not fully present in the encoded sequence.
    """
    questions = examples["question"]
    words = examples["words"]
    boxes = examples["boxes"]
    answers = examples["answer"]

    # Encode the whole batch at once; labels are computed per example below.
    encoding = tokenizer(questions, words, boxes, max_length=max_length, padding="max_length", truncation=True)
    start_positions = []
    end_positions = []

    for i, (example_words, answer) in enumerate(zip(words, answers)):
        cls_index = encoding["input_ids"][i].index(tokenizer.cls_token_id)
        # Default label: the CLS token stands for "no usable answer span".
        start_position = end_position = cls_index

        # Case-insensitive search for the answer among the example's words.
        words_example = [word.lower() for word in example_words]
        match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split())
        if match:
            span = _locate_answer_tokens(
                encoding.word_ids(i),
                encoding["token_type_ids"][i],
                word_idx_start,
                word_idx_end,
            )
            # Only accept the span when BOTH boundaries survived truncation;
            # a half-found answer would otherwise yield an inconsistent label
            # (one real token index paired with the CLS index).
            if span is not None:
                start_position, end_position = span

        start_positions.append(start_position)
        end_positions.append(end_position)

    encoding["image"] = examples["image"]
    encoding["start_positions"] = start_positions
    encoding["end_positions"] = end_positions
    return encoding

Now that we have this preprocessing function, we can encode the entire dataset:

# Run encode_dataset over each split in batches of two examples, dropping all
# of the split's original columns so only the returned fields remain.
encoded_train_dataset = dataset_with_ocr["train"].map(
    encode_dataset,
    batched=True,
    batch_size=2,
    remove_columns=dataset_with_ocr["train"].column_names,
)
encoded_test_dataset = dataset_with_ocr["test"].map(
    encode_dataset,
    batched=True,
    batch_size=2,
    remove_columns=dataset_with_ocr["test"].column_names,
)

Let's check what the features of the encoded dataset look like:

# Notebook-style inspection of the encoded training split's column schema;
# the lines that follow show the value this expression displays.
encoded_train_dataset.features
{'image': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='uint8', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'bbox': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'start_positions': Value(dtype='int64', id=None),
 'end_positions': Value(dtype='int64', id=None)}

Evaluation
Evaluation for document question answering requires a significant amount of postprocessing.