```py
def train_transforms(example_batch):
    images = [jitter(x) for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor(images, labels)
    return inputs


def val_transforms(example_batch):
    images = [x for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor(images, labels)
    return inputs
```

To apply the `jitter` over the entire dataset, use the 🤗 Datasets [`~datasets.Dataset.set_transform`] function.
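To make the on-access behavior concrete, here is a minimal pure-Python sketch of the idea behind `set_transform`: the function is stored and only run when an example is read, so the stored data is never modified. Everything in it is hypothetical and for illustration only. `LazyDataset` is a toy stand-in, not the 🤗 Datasets class; `toy_jitter` replaces the real `jitter` with a small random shift on integers; and the returned keys `pixel_values` and `labels` are an assumption about what `image_processor` produces.

```python
import random


class LazyDataset:
    """Toy stand-in (hypothetical, not the 🤗 Datasets class) that applies a transform on access."""

    def __init__(self, columns):
        self.columns = columns  # dict mapping column name -> list of values
        self.transform = None

    def set_transform(self, transform):
        # Only store the function; nothing is computed or copied here.
        self.transform = transform

    def __getitem__(self, index):
        # Build a one-example batch (dict of lists), the shape the transform functions expect.
        batch = {name: [values[index]] for name, values in self.columns.items()}
        if self.transform is not None:
            batch = self.transform(batch)
        # Unwrap the single example from the batch.
        return {name: values[0] for name, values in batch.items()}


def toy_jitter(pixel, rng):
    # Stand-in for the real jitter: shift a value by -1, 0, or +1.
    return pixel + rng.choice([-1, 0, 1])


def make_train_transforms(seed=0):
    rng = random.Random(seed)

    def train_transforms(example_batch):
        images = [toy_jitter(x, rng) for x in example_batch["image"]]
        labels = [x for x in example_batch["annotation"]]
        # Assumed output keys, mirroring what an image processor might return.
        return {"pixel_values": images, "labels": labels}

    return train_transforms


# Integers stand in for images and segmentation maps.
train_ds = LazyDataset({"image": [10, 20, 30], "annotation": [0, 1, 0]})
train_ds.set_transform(make_train_transforms())

sample = train_ds[1]  # the transform runs here, at access time
```

Because the transform runs each time an example is read, a random augmentation can produce a different result on every access, while the underlying `image` column stays unchanged.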