|
import glob |
|
import os |
|
from typing import List, Optional, Tuple |
|
|
|
import logging |
|
import numpy as np |
|
import torchvision.transforms.functional as TF |
|
import PIL |
|
from PIL import Image |
|
from torchvision.datasets import VisionDataset |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class PathDataset(VisionDataset): |
|
def __init__( |
|
self, |
|
root: List[str], |
|
loader: None = None, |
|
transform: Optional[str] = None, |
|
extra_transform: Optional[str] = None, |
|
mean: Optional[List[float]] = None, |
|
std: Optional[List[float]] = None, |
|
): |
|
super().__init__(root=root) |
|
|
|
PIL.Image.MAX_IMAGE_PIXELS = 256000001 |
|
|
|
self.files = [] |
|
for folder in self.root: |
|
self.files.extend( |
|
sorted(glob.glob(os.path.join(folder, "**", "*.jpg"), recursive=True)) |
|
) |
|
self.files.extend( |
|
sorted(glob.glob(os.path.join(folder, "**", "*.png"), recursive=True)) |
|
) |
|
|
|
self.transform = transform |
|
self.extra_transform = extra_transform |
|
self.mean = mean |
|
self.std = std |
|
|
|
self.loader = loader |
|
|
|
logger.info(f"loaded {len(self.files)} samples from {root}") |
|
|
|
assert (mean is None) == (std is None) |
|
|
|
def __len__(self) -> int: |
|
return len(self.files) |
|
|
|
def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]: |
|
path = self.files[idx] |
|
|
|
if self.loader is not None: |
|
return self.loader(path), None |
|
|
|
img = Image.open(path).convert("RGB") |
|
if self.transform is not None: |
|
img = self.transform(img) |
|
img = TF.to_tensor(img) |
|
if self.mean is not None and self.std is not None: |
|
img = TF.normalize(img, self.mean, self.std) |
|
return img, None |
|
|