{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d0b72877",
   "metadata": {},
   "source": [
    "# vqgan-jax-encoding-alamy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba7b31e6",
   "metadata": {},
   "source": [
    "Encoding notebook for the Alamy dataset: every image is encoded into VQGAN code indices, which are saved together with the image URL and caption as JSON Lines files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3b59489e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import torchvision.transforms as T\n",
    "import torchvision.transforms.functional as TF\n",
    "from torchvision.transforms import InterpolationMode\n",
    "\n",
    "import webdataset as wds\n",
    "\n",
    "import jax\n",
    "from jax import pmap"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7c4c1e6",
   "metadata": {},
   "source": [
    "## Dataset and Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13c6631b",
   "metadata": {},
   "outputs": [],
   "source": [
    "shards = 'https://s3.us-west-1.wasabisys.com/doodlebot-wasabi/datasets/alamy/webdataset/alamy-{000..895}.tar'\n",
    "\n",
    "# Enable curl retries to try to work around temporary network / server errors.\n",
    "# This shouldn't be necessary when using reliable servers.\n",
    "shards = f'pipe:curl -s --retry 5 --retry-delay 5 -L {shards} || true'\n",
    "\n",
    "length = 44710810  # Estimated number of samples\n",
    "\n",
    "# Output directory for encoded files\n",
    "encoded_output = Path.home()/'data'/'alamy'/'encoded'\n",
    "\n",
    "batch_size = 128  # Per device\n",
    "num_workers = 8  # Using larger numbers seemed to be less reliable in this case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3435fb85",
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = batch_size * jax.device_count()  # Total batch size across devices; use a smaller size for testing\n",
    "batches = math.ceil(length / bs)  # Estimated number of batches, used as the dataset length hint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "669b35df",
   "metadata": {},
   "outputs": [],
   "source": [
    "def center_crop(image, max_size=256):\n",
    "    \"\"\"Resize `image` so its shorter side is `max_size`, then center-crop a `max_size` x `max_size` square.\"\"\"\n",
    "    # Note: we allow upscaling too. We should exclude small images.\n",
    "    image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
    "    image = TF.center_crop(image, output_size=2 * [max_size])\n",
    "    return image\n",
    "\n",
    "# PIL image -> float tensor of shape (H, W, C).\n",
    "preprocess_image = T.Compose([\n",
    "    center_crop,\n",
    "    T.ToTensor(),\n",
    "    lambda t: t.permute(1, 2, 0)  # Reorder, we need dimensions last\n",
    "])\n",
    "\n",
    "def extract_from_json(item):\n",
    "    \"\"\"Copy `caption` and `url` from the sample's JSON metadata to top-level keys so `to_tuple` can select them.\"\"\"\n",
    "    metadata = item['json']\n",
    "    item['caption'] = metadata['caption']\n",
    "    item['url'] = metadata['url']\n",
    "    return item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "369d9719",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ignore_and_log(exn):\n",
    "    \"\"\"Exception handler that appends `exn` to a hardcoded `errors.txt` file and skips the sample.\"\"\"\n",
    "    with open('errors.txt', 'a') as f:\n",
    "        f.write(f'{exn}\\n')\n",
    "    return True\n",
    "\n",
    "# Handler shared by every pipeline stage. Swap in `ignore_and_log` to keep a record\n",
    "# of failures, or `wds.ignore_and_continue` to skip them silently.\n",
    "exception_handler = wds.warn_and_continue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5149b6d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = wds.WebDataset(shards,\n",
    "    length=batches,             # Hint so `len` is implemented\n",
    "    shardshuffle=False,         # Keep same order for encoded files for easier bookkeeping\n",
    "    handler=exception_handler,  # Skip read errors instead of failing\n",
    ")\n",
    "\n",
    "# The same handler is passed to every stage, so a corrupt image or a sample with missing\n",
    "# metadata is skipped instead of aborting the whole encoding run.\n",
    "dataset = (dataset\n",
    "    .decode('pil', handler=exception_handler)  # decode image with PIL\n",
    "    .map(extract_from_json, handler=exception_handler)\n",
    "    .map_dict(jpg=preprocess_image, handler=exception_handler)\n",
    "    .to_tuple('url', 'jpg', 'caption', handler=exception_handler)  # keep only url (for reference), image, caption\n",
    "    .batched(bs))  # batch here rather than in the DataLoader; the choice does not affect speed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8cac98cb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 8min 26s, sys: 12.5 s, total: 8min 38s\n",
      "Wall time: 14.4 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "urls, images, captions = next(iter(dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cd268fbf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1024, 256, 256, 3])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "images.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44d50a51",
   "metadata": {},
   "source": [
    "### Torch DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e2df5e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `batch_size=None`: the dataset already yields complete batches.\n",
    "dl = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=num_workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a354472b",
   "metadata": {},
   "source": [
    "## VQGAN-JAX model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2fcf01d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from vqgan_jax.modeling_flax_vqgan import VQModel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9daa636d",
   "metadata": {},
   "source": [
    "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
] }, { "cell_type": "code", "execution_count": 10, "id": "47a8b818", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n" ] } ], "source": [ "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")" ] }, { "cell_type": "markdown", "id": "62ad01c3", "metadata": {}, "source": [ "## Encoding" ] }, { "cell_type": "markdown", "id": "20357f74", "metadata": {}, "source": [ "Encoding is really simple using `shard` to automatically distribute \"superbatches\" across devices, and `pmap`. This is all it takes to create our encoding function, that will be jitted on first use." ] }, { "cell_type": "code", "execution_count": 11, "id": "6686b004", "metadata": {}, "outputs": [], "source": [ "from flax.training.common_utils import shard\n", "from functools import partial" ] }, { "cell_type": "code", "execution_count": 12, "id": "322a4619", "metadata": {}, "outputs": [], "source": [ "@partial(jax.pmap, axis_name=\"batch\")\n", "def encode(batch):\n", " # Not sure if we should `replicate` params, does not seem to have any effect\n", " _, indices = model.encode(batch)\n", " return indices" ] }, { "cell_type": "markdown", "id": "14375a41", "metadata": {}, "source": [ "### Encoding loop" ] }, { "cell_type": "code", "execution_count": 13, "id": "ff6c10d4", "metadata": {}, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "\n", "def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n", " output_dir.mkdir(parents=True, exist_ok=True)\n", "\n", " # Saving strategy:\n", " # - Create a new file every so often to prevent excessive file seeking.\n", " # - Save each batch after processing.\n", " # - Keep the file open until we are done with it.\n", " file = None \n", " for n, (urls, images, captions) in enumerate(tqdm(dataloader)):\n", " if (n % save_every == 0):\n", " if file is not None:\n", " file.close()\n", " split_num = n // 
save_every\n", " file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n", "\n", " images = shard(images.numpy().squeeze())\n", " encoded = encode(images)\n", " encoded = encoded.reshape(-1, encoded.shape[-1])\n", "\n", " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n", " batch_df = pd.DataFrame.from_dict({\"url\": urls, \"caption\": captions, \"encoding\": encoded_as_string})\n", " batch_df.to_json(file, orient='records', lines=True)" ] }, { "cell_type": "markdown", "id": "09ff75a3", "metadata": {}, "source": [ "Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024." ] }, { "cell_type": "code", "execution_count": 14, "id": "96222bb4", "metadata": {}, "outputs": [], "source": [ "save_every = 318" ] }, { "cell_type": "code", "execution_count": null, "id": "7704863d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 2%|█▌ | 1085/43663 [31:58<20:43:42, 1.75s/it]" ] } ], "source": [ "encode_captioned_dataset(dl, encoded_output, save_every=save_every)" ] }, { "cell_type": "markdown", "id": "8953dd84", "metadata": {}, "source": [ "----" ] } ], "metadata": { "interpreter": { "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }