|
tokenizer: PreTrainedTokenizerBase |
|
padding: Union[bool, str, PaddingStrategy] = True |
|
max_length: Optional[int] = None |
|
pad_to_multiple_of: Optional[int] = None |
|
def call(self, features):
    """Collate multiple-choice features into a single batch of tensors.

    Every feature is a mapping whose non-label values hold one entry per
    answer choice. The choices of all features are flattened into one list,
    handed to ``self.tokenizer.pad`` and the result is viewed back as
    ``(batch_size, num_choices, -1)``.

    Args:
        features: Non-empty list of mappings. Each one has an
            ``"input_ids"`` entry and a label stored under ``"label"`` or
            ``"labels"`` (the key is detected on the first feature only).
            The mappings are not modified.

    Returns:
        A dict with one reshaped tensor per key returned by
        ``self.tokenizer.pad`` plus ``"labels"``, an int64 tensor built from
        the per-feature labels.

    Raises:
        IndexError: If ``features`` is empty.
        KeyError: If a feature lacks the detected label key.
    """
    # NOTE(review): this method is named `call`; it looks like `__call__`
    # with the underscores lost during extraction -- confirm against the
    # upstream file before relying on `collator(features)`.
    label_name = "label" if "label" in features[0].keys() else "labels"

    # Read the labels without popping them so the caller's feature dicts
    # stay intact and the same features can be collated again.
    labels = [feature[label_name] for feature in features]

    batch_size = len(features)
    num_choices = len(features[0]["input_ids"])

    # One flat entry per (feature, choice) pair, in feature-major order.
    # The label key is skipped here because it is not a per-choice value.
    flattened_features = [
        {k: v[i] for k, v in feature.items() if k != label_name}
        for feature in features
        for i in range(num_choices)
    ]

    batch = self.tokenizer.pad(
        flattened_features,
        padding=self.padding,
        max_length=self.max_length,
        pad_to_multiple_of=self.pad_to_multiple_of,
        return_tensors="pt",
    )

    # Restore the (batch, choice) leading dimensions that flattening removed.
    batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
    batch["labels"] = torch.tensor(labels, dtype=torch.int64)
    return batch
|
</pt> |
|
<tf>py |
|
|
|
from dataclasses import dataclass |
|
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy |
|
from typing import Optional, Union |
|
import tensorflow as tf |
|
@dataclass |
|
class DataCollatorForMultipleChoice: |
|
""" |
|
Data collator that will dynamically pad the inputs for multiple choice received. |
|
""" |