images = [ | |
train_data_augmentation(convert_to_tf_tensor(image.convert("RGB"))) for image in example_batch["image"] | |
] | |
example_batch["pixel_values"] = [tf.transpose(tf.squeeze(image)) for image in images] | |
return example_batch | |
def preprocess_val(example_batch): | |
"""Apply val_transforms across a batch.""" |