# Hosting-page metadata captured with this file (not part of the module):
# PyTorch
# ssl-aasist
# custom_code
# ash56's picture
# Add files using upload-large-folder tool
# 6789f6f verified
# raw
# history blame
# 1.79 kB
import glob
import logging
import os
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import PIL
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision.datasets import VisionDataset
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class PathDataset(VisionDataset):
    """Dataset over every ``.jpg`` / ``.png`` file found under one or more folders.

    Each folder is searched recursively. Within a folder, all ``.jpg`` matches
    (sorted) are listed before all ``.png`` matches (sorted); folders are
    processed in the order given. Only these two lowercase patterns are
    globbed.

    Every item is a ``(sample, None)`` pair: there is no label, the second
    element is always ``None``.

    Args:
        root: A folder path or a list of folder paths to scan. A single
            string/path is treated as one folder.
        loader: Optional callable taking a file path. When set,
            ``__getitem__`` returns ``(loader(path), None)`` and skips the
            built-in PIL decoding, ``transform`` and normalisation entirely.
        transform: Optional callable applied to the decoded RGB PIL image
            before it is converted to a tensor.
        extra_transform: Stored on the instance but not used by this class.
            NOTE(review): presumably consumed by a caller or subclass; confirm.
        mean: Per-channel mean passed to ``TF.normalize``. Must be supplied
            together with ``std`` (both or neither).
        std: Per-channel standard deviation passed to ``TF.normalize``.

    Raises:
        AssertionError: If exactly one of ``mean`` / ``std`` is given.
    """

    def __init__(
        self,
        root: Union[str, List[str]],
        loader: Optional[Callable[[str], Any]] = None,
        transform: Optional[Callable[[Any], Any]] = None,
        extra_transform: Optional[Any] = None,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
    ):
        # Check argument consistency first so a bad call fails before the
        # (potentially slow) recursive directory scan below.
        assert (mean is None) == (
            std is None
        ), "mean and std must be provided together (both or neither)"

        super().__init__(root=root)

        # Module-wide Pillow setting (affects all PIL users in the process):
        # raises the pixel-count threshold of the decompression-bomb check.
        PIL.Image.MAX_IMAGE_PIXELS = 256000001

        # A lone string would otherwise be iterated character by character,
        # globbing under "/", "d", "a", ... instead of the intended folder.
        if isinstance(self.root, (str, os.PathLike)):
            folders = [self.root]
        else:
            folders = self.root

        self.files = []
        for folder in folders:
            # Keep the original ordering: all sorted .jpg, then all sorted .png.
            for pattern in ("*.jpg", "*.png"):
                matches = glob.glob(
                    os.path.join(folder, "**", pattern), recursive=True
                )
                self.files.extend(sorted(matches))

        # Assigned after super().__init__ so these override whatever the base
        # class stored under the same attribute names.
        self.transform = transform
        self.extra_transform = extra_transform
        self.mean = mean
        self.std = std
        self.loader = loader

        logger.info("loaded %d samples from %s", len(self.files), root)

    def __len__(self) -> int:
        """Return the number of image files discovered at construction time."""
        return len(self.files)

    def __getitem__(self, idx) -> Tuple[Any, None]:
        """Load the sample at position ``idx``.

        Returns:
            ``(sample, None)``. ``sample`` is ``loader(path)`` when a custom
            loader was supplied; otherwise it is the image decoded as RGB,
            passed through ``transform`` (if any), converted with
            ``TF.to_tensor`` and, when ``mean``/``std`` are set, normalised
            with ``TF.normalize``.
        """
        path = self.files[idx]

        # A custom loader bypasses the whole built-in pipeline.
        if self.loader is not None:
            return self.loader(path), None

        # The context manager closes the underlying file deterministically;
        # ``convert`` returns an independent image that outlives the handle.
        with Image.open(path) as source:
            img = source.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        img = TF.to_tensor(img)

        if self.mean is not None and self.std is not None:
            img = TF.normalize(img, self.mean, self.std)

        return img, None