def train_transforms(example_batch):
    images = [jitter(x) for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor(images, labels)
    return inputs


def val_transforms(example_batch):
    images = [x for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor(images, labels)
    return inputs
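To apply the jitter over the entire dataset, use the 🤗 Datasets [~datasets.Dataset.set_transform] function. The transform is applied on the fly, when examples are accessed, rather than to the whole dataset up front.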
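To make the on-the-fly behaviour concrete, here is a minimal pure-Python sketch. It does not use 🤗 Datasets: `LazyDataset` is a hypothetical stand-in that only mimics the idea of registering a batch transform and running it at access time, and `toy_train_transforms` replaces `jitter` and `image_processor` with placeholder arithmetic. Treat it as an illustration of the idea under those assumptions, not as the library's implementation.

```python
class LazyDataset:
    """Hypothetical stand-in (not the 🤗 Datasets class) with a lazy transform."""

    def __init__(self, columns):
        self.columns = columns      # dict of column name -> list of values
        self._transform = None

    def set_transform(self, transform):
        # Only store the function; no data is touched here.
        self._transform = transform

    def __len__(self):
        return len(next(iter(self.columns.values())))

    def __getitem__(self, index):
        # Build a batch of size one, as the transform expects lists per column.
        batch = {name: [values[index]] for name, values in self.columns.items()}
        if self._transform is not None:
            batch = self._transform(batch)
        # Unwrap the single-element batch back into one example.
        return {name: values[0] for name, values in batch.items()}


calls = []  # records the batch size each time the transform runs


def toy_train_transforms(example_batch):
    calls.append(len(example_batch["image"]))
    # Placeholder for jitter(x): doubling stands in for an image augmentation.
    images = [x * 2 for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    # Placeholder for image_processor(images, labels).
    return {"pixel_values": images, "labels": labels}


train_ds = LazyDataset({"image": [1, 2, 3], "annotation": [0, 1, 0]})
train_ds.set_transform(toy_train_transforms)
calls_after_registration = list(calls)  # still empty: nothing has run yet

sample = train_ds[1]  # the transform runs now, on this example only
```

In this sketch the stored columns are never modified, so a random augmentation would be redrawn each time an example is read, and no transformed copy of the dataset is kept.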