# vqgan-jax-encoding-alamy

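Encoding notebook for the Alamy dataset: encodes every image into VQGAN-JAX token indices and saves them, together with each image's URL and caption, as JSONL split files.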

In [1]:
import numpy as np
from tqdm import tqdm

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
import math

import webdataset as wds

import jax
from jax import pmap

## Dataset and Parameters

In [None]:
from pathlib import Path

# --- Input: Alamy shards in webdataset format, fetched over HTTP ---
# `{000..895}` is a brace range that webdataset expands into one URL per shard.
_shard_pattern = 'https://s3.us-west-1.wasabisys.com/doodlebot-wasabi/datasets/alamy/webdataset/alamy-{000..895}.tar'

# Stream every shard through curl so temporary network / server errors are retried.
# This shouldn't be necessary when using reliable servers. `|| true` makes the pipe
# command succeed even if curl ultimately fails.
shards = f'pipe:curl -s --retry 5 --retry-delay 5 -L {_shard_pattern} || true'

# Estimated number of samples in the dataset (not an exact count).
length = 44710810

# --- Output: directory for the encoded files ---
encoded_output = Path.home() / 'data' / 'alamy' / 'encoded'

# --- Loader settings ---
num_workers = 8   # Larger values seemed to be less reliable in this case.
batch_size = 128  # Per device.

In [3]:
# Global ("super") batch size: `batch_size` images per device times the device count.
# Use a smaller size for testing.
bs = jax.device_count() * batch_size
# Number of global batches needed to cover the estimated `length` (rounded up).
batches = math.ceil(length / bs)

In [4]:
def center_crop(image, max_size=256):
    """Resize `image` with LANCZOS, then cut out a centered `max_size` x `max_size` square.

    `TF.resize` is given a single int, so torchvision matches the image's smaller edge
    to `max_size`.

    Note: upscaling is allowed too, so small images get enlarged. We should exclude
    small images.
    """
    side = max_size
    resized = TF.resize(image, side, interpolation=InterpolationMode.LANCZOS)
    return TF.center_crop(resized, output_size=[side, side])

def _to_channels_last(tensor):
    """Move the leading (channel) axis to the end: (C, H, W) -> (H, W, C)."""
    # Reorder, we need dimensions last.
    return tensor.permute(1, 2, 0)


# Image pipeline: center crop -> tensor -> channels-last layout.
preprocess_image = T.Compose([center_crop, T.ToTensor(), _to_channels_last])

# Is there a shortcut for this?
def extract_from_json(item):
    """Copy `caption` and `url` from the sample's decoded `json` entry to top-level keys.

    The sample dict is modified in place and also returned, so the function can be
    used directly as a `.map` stage.
    """
    metadata = item['json']
    for key in ('caption', 'url'):
        item[key] = metadata[key]
    return item

In [7]:
# Log exceptions to a file ('errors.txt' in the working directory unless overridden).
def ignore_and_log(exn, log_path='errors.txt'):
    """webdataset exception handler: append `exn` to a log file and keep going.

    Args:
        exn: the exception raised while reading or processing a sample.
        log_path: file the message is appended to; defaults to 'errors.txt' in the
            current working directory (the previous hardcoded location).

    Returns:
        True, so the pipeline skips the failing item and continues.
    """
    # Explicit UTF-8: messages may contain non-ASCII URLs/captions, and a locale-dependent
    # default encoding could make the handler itself raise.
    with open(log_path, 'a', encoding='utf-8') as f:
        f.write(f'{exn}\n')
    return True

# Handler shared by all webdataset stages. `wds.warn_and_continue` warns and skips the
# failing item; assign `ignore_and_log` here instead to record errors to a file.
exception_handler = wds.warn_and_continue

In [8]:
dataset = wds.WebDataset(
    shards,
    length=batches,             # Hint so `len` is implemented; counted in global batches of `bs`
    shardshuffle=False,         # Keep same order for encoded files for easier bookkeeping
    handler=exception_handler,  # Ignore read errors instead of failing. See also: `warn_and_continue`
)

# Every stage that can fail on a single bad sample gets the same handler. Then one corrupt
# image or one record missing json/caption/url is skipped instead of raising inside a
# DataLoader worker and aborting the whole run.
dataset = (
    dataset
    .decode('pil', handler=exception_handler)                # decode image with PIL
    .map(extract_from_json, handler=exception_handler)       # lift caption/url out of the json entry
    .map_dict(jpg=preprocess_image, handler=exception_handler)
    .to_tuple('url', 'jpg', 'caption')  # keep only url (for reference), image, caption
    # Group samples into global batches of `bs` here; the DataLoader is therefore
    # created with batch_size=None and does no batching of its own.
    .batched(bs)
)

In [10]:
%%time
# Smoke test: time fetching and preprocessing the first global batch before the full run.
urls, images, captions = next(iter(dataset))

CPU times: user 8min 26s, sys: 12.5 s, total: 8min 38s
Wall time: 14.4 s


In [7]:
# Sanity check: images should be channels-last, (batch, height, width, channels).
images.shape

torch.Size([1024, 256, 256, 3])

### Torch DataLoader

In [8]:
# `dataset` already yields complete global batches, so the DataLoader's own batching is
# disabled (batch_size=None); it is used only for parallel loading across workers.
dl = torch.utils.data.DataLoader(
    dataset,
    num_workers=num_workers,
    batch_size=None,
)

## VQGAN-JAX model

In [9]:
from vqgan_jax.modeling_flax_vqgan import VQModel

We'll use a VQGAN trained with Taming Transformers and converted to a JAX model.

In [10]:
# Pretrained VQGAN checkpoint identifier (presumably a Hugging Face Hub id).
vqgan_checkpoint = "flax-community/vqgan_f16_16384"
model = VQModel.from_pretrained(vqgan_checkpoint)

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.


## Encoding

Encoding is simple. `shard` splits each "superbatch" evenly across the local devices, and `pmap` runs the encoder on every device in parallel. The cell below is all it takes to create our encoding function, which is compiled (jitted) on first use.

In [11]:
from flax.training.common_utils import shard
from functools import partial

In [12]:
def _encode_to_indices(batch):
    """Per-device body: run `model.encode` on this device's slice and keep only the indices."""
    # Not sure if we should `replicate` params, does not seem to have any effect.
    # (`model` is read from the notebook scope rather than passed as an argument.)
    _, code_indices = model.encode(batch)
    return code_indices


# Map the encoder over the leading (device) axis of a sharded batch; "batch" names that axis.
encode = jax.pmap(_encode_to_indices, axis_name="batch")

### Encoding loop

In [13]:
import os
import pandas as pd

def _encode_images(images):
    """Encode one torch image batch on all local devices.

    Returns a numpy array with one row of VQGAN indices per input image.
    """
    batch = images.numpy()
    # Fold any leading singleton axes (e.g. from a DataLoader with batch_size=1) into the
    # batch axis. Unlike `squeeze()`, this never drops the batch axis of a 1-image batch.
    batch = batch.reshape((-1,) + batch.shape[-3:])
    num_images = batch.shape[0]

    # `shard` splits the leading axis evenly across local devices, so a short final batch
    # is zero-padded to a multiple of the device count; the padding rows are dropped below.
    num_devices = jax.local_device_count()
    pad = -num_images % num_devices
    if pad:
        filler = np.zeros((pad,) + batch.shape[1:], dtype=batch.dtype)
        batch = np.concatenate([batch, filler])

    encoded = encode(shard(batch))
    # Transfer to host once instead of once per row.
    encoded = np.asarray(encoded).reshape(-1, encoded.shape[-1])
    return encoded[:num_images]


def _indices_to_strings(encoded):
    """Render each row of an integer array as a compact '[i,j,...]' string."""
    return [
        np.array2string(row, separator=',', max_line_width=50000, formatter={'int': lambda x: str(x)})
        for row in encoded
    ]


def _write_jsonl_batch(file, urls, captions, encodings):
    """Append one JSON record per image (keys: url, caption, encoding) to an open text file."""
    batch_df = pd.DataFrame.from_dict({"url": urls, "caption": captions, "encoding": encodings})
    payload = batch_df.to_json(orient='records', lines=True)
    # Depending on the pandas version, `lines=True` output may lack a trailing newline.
    # Without one, the last record of a batch and the first record of the next batch
    # would end up on the same line and corrupt the JSONL file.
    if payload and not payload.endswith('\n'):
        payload += '\n'
    file.write(payload)


def encode_captioned_dataset(dataloader, output_dir, save_every=14):
    """Encode every `(urls, images, captions)` batch and write the results as JSONL splits.

    Each image batch is sharded across the local JAX devices and encoded with the pmapped
    `encode` function. Each image is written as one JSON record with keys "url", "caption"
    and "encoding" (its VQGAN indices rendered as a "[i,j,...]" string).

    Args:
        dataloader: iterable yielding `(urls, images, captions)` batches. `images` is a
            torch tensor, assumed channels-last (batch, height, width, channels) as
            produced by `preprocess_image` — TODO confirm for other loaders.
        output_dir: directory for the split files; created if missing. A `str` path is
            accepted too.
        save_every: number of batches per output file. A new file
            `split_{index:05x}.jsonl` is started every `save_every` batches.
    """
    from pathlib import Path

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Saving strategy:
    # - Create a new file every so often to prevent excessive file seeking.
    # - Save each batch after processing.
    # - Keep the file open until we are done with it, and always close the last one,
    #   even if the loop is interrupted.
    file = None
    try:
        for n, (urls, images, captions) in enumerate(tqdm(dataloader)):
            if n % save_every == 0:
                if file is not None:
                    file.close()
                split_num = n // save_every
                file = open(output_dir / f'split_{split_num:05x}.jsonl', 'w')

            encoded = _encode_images(images)
            _write_jsonl_batch(file, urls, captions, _indices_to_strings(encoded))
    finally:
        if file is not None:
            file.close()

Start a new output file every 318 iterations (batches). With a total batch size of 1024, this should produce splits of roughly 500 MB each.

In [14]:
# Batches per output file: ~500 MB per split at a total batch size of 1024.
save_every = 318

In [None]:
# Full run: encode every batch from `dl` and write JSONL splits under `encoded_output`.
encode_captioned_dataset(dl, encoded_output, save_every=save_every)

 2%|█▌ | 1085/43663 [31:58<20:43:42, 1.75s/it]

----