def is_not_too_long(input_ids):
    input_length = len(input_ids)
    return input_length < 200

dataset = dataset.filter(is_not_too_long, input_columns=["input_ids"])
len(dataset)
8259
Next, create a basic train/test split:
dataset = dataset.train_test_split(test_size=0.1)
Data collator
In order to combine multiple examples into a batch, you need to define a custom data collator.
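As a rough illustration of the core job a collator does, the sketch below pads variable-length `input_ids` to the longest example in the batch and builds a matching attention mask. It is not the collator used in this guide: `pad_collate` and `pad_id` are hypothetical names chosen for the example, and it works on plain Python lists, whereas a real collator would typically return framework tensors and handle any label columns as well.

```python
from typing import Dict, List


def pad_collate(
    examples: List[Dict[str, List[int]]], pad_id: int = 0
) -> Dict[str, List[List[int]]]:
    """Pad every `input_ids` list to the longest one in the batch.

    `pad_id` is an assumed padding value; use your tokenizer's pad token id.
    """
    max_len = max(len(ex["input_ids"]) for ex in examples)
    input_ids, attention_mask = [], []
    for ex in examples:
        ids = ex["input_ids"]
        n_pad = max_len - len(ids)
        # Real tokens are marked 1 in the mask, padding positions 0.
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Two examples of different lengths become one rectangular batch.
batch = pad_collate([{"input_ids": [5, 6, 7]}, {"input_ids": [8]}])
```

Padding to the longest example in each batch, rather than to a fixed global length, keeps batches as small as the data allows.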