|
With that in mind, let's create a function to encode a batch of examples in the dataset: |
|
|
|
def encode_dataset(examples, max_length=512):
    questions = examples["question"]
    words = examples["words"]
    boxes = examples["boxes"]
    answers = examples["answer"]

    # encode the batch of examples and initialize the start_positions and end_positions
    encoding = tokenizer(questions, words, boxes, max_length=max_length, padding="max_length", truncation=True)
    start_positions = []
    end_positions = []

    # loop through the examples in the batch
    for i in range(len(questions)):
        cls_index = encoding["input_ids"][i].index(tokenizer.cls_token_id)

        # find the position of the answer in the example's words
        words_example = [word.lower() for word in words[i]]
        answer = answers[i]
        match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split())

        if match:
            # if a match is found, use token_type_ids to find where the document words start in the encoding
            token_type_ids = encoding["token_type_ids"][i]
            token_start_index = 0
            while token_type_ids[token_start_index] != 1:
                token_start_index += 1

            token_end_index = len(encoding["input_ids"][i]) - 1
            while token_type_ids[token_end_index] != 1:
                token_end_index -= 1

            word_ids = encoding.word_ids(i)[token_start_index : token_end_index + 1]
            start_position = cls_index
            end_position = cls_index

            # loop over word_ids and increase token_start_index until it matches the answer position in words
            # once it matches, save the token_start_index as the start_position of the answer in the encoding
            for id in word_ids:
                if id == word_idx_start:
                    start_position = token_start_index
                else:
                    token_start_index += 1

            # similarly, loop over word_ids starting from the end to find the end_position of the answer
            for id in word_ids[::-1]:
                if id == word_idx_end:
                    end_position = token_end_index
                else:
                    token_end_index -= 1

            start_positions.append(start_position)
            end_positions.append(end_position)
        else:
            start_positions.append(cls_index)
            end_positions.append(cls_index)

    encoding["image"] = examples["image"]
    encoding["start_positions"] = start_positions
    encoding["end_positions"] = end_positions

    return encoding
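
Note that encode_dataset relies on the subfinder helper defined earlier in this guide to locate the answer within the OCR'd words of a document. If you don't have it at hand, a minimal sketch of such a helper (returning the first match together with its start and end word indices, and a falsy result otherwise) could look like this:

def subfinder(words_list, answer_list):
    # return the first occurrence of answer_list inside words_list,
    # together with its start and end word indices; (None, 0, 0) if there is no match
    if not answer_list:
        return None, 0, 0
    for idx in range(len(words_list)):
        if words_list[idx] == answer_list[0] and words_list[idx : idx + len(answer_list)] == answer_list:
            return answer_list, idx, idx + len(answer_list) - 1
    return None, 0, 0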
|
|
|
Now that we have this preprocessing function, we can encode the entire dataset: |
|
|
|
encoded_train_dataset = dataset_with_ocr["train"].map(
    encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["train"].column_names
)
encoded_test_dataset = dataset_with_ocr["test"].map(
    encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["test"].column_names
)
|
|
|
Let's check what the features of the encoded dataset look like: |
|
|
|
encoded_train_dataset.features
{'image': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='uint8', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'bbox': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'start_positions': Value(dtype='int64', id=None),
 'end_positions': Value(dtype='int64', id=None)}
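
As a quick sanity check, you can decode the labeled span of one encoded example and verify that it matches the annotated answer. This assumes the tokenizer and encoded_train_dataset from above are still in scope; examples where no match was found will decode to the [CLS] token instead:

example = encoded_train_dataset[0]
start, end = example["start_positions"], example["end_positions"]
print(tokenizer.decode(example["input_ids"][start : end + 1]))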
|
|
|
Evaluation |
|
Evaluation for document question answering requires a significant amount of postprocessing. |
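In practice, that postprocessing boils down to turning the model's start and end logits back into an answer string and comparing it to the ground truth with a metric such as ANLS. As a rough sketch of the span-extraction step only (not the exact procedure used for the benchmark, and with a hypothetical helper name), it might look like this:

import torch

def logits_to_answer(start_logits, end_logits, input_ids, tokenizer):
    # pick the most likely start and end token positions independently
    start_idx = int(torch.argmax(start_logits))
    end_idx = int(torch.argmax(end_logits))
    # treat an invalid span as "no answer" in this simplified sketch
    if end_idx < start_idx:
        return ""
    return tokenizer.decode(input_ids[start_idx : end_idx + 1], skip_special_tokens=True)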