File size: 2,945 Bytes
be2715b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from torch.utils.data import (
DataLoader
)
from dataloader.dataset import (
SegmentationDataset,
AugmentedSegmentationDataset
)
def sam_dataloader(cfg):
    """Build the train/valid/test dataloaders (and datasets) for SAM fine-tuning.

    Since the output of SAM's mask decoder is 256x256 by default (without
    using a postprocessing/upscaling step), mask ground truth is resized to
    256x256 so it can be compared directly against the raw prediction, while
    images are kept at SAM's 1024 input size.

    Args:
        cfg: configuration object with ``base`` (dataset_name, num_workers,
            pin_memory), ``dataloader`` (train/valid/test image and mask
            directories) and ``train`` (per-split batch sizes) sections.

    Returns:
        Tuple of (train_loader, val_loader, test_loader, val_dataset,
        test_dataset).

    Raises:
        NameError: if ``cfg.base.dataset_name`` is not a known dataset.
    """
    # (image_size, mask_size) passed to every dataset; hoisted so the
    # 1024/256 pairing is defined in exactly one place.
    scale = (1024, 256)

    # Shared DataLoader kwargs. NOTE(review): "fork" is only available on
    # POSIX platforms with the fork start method — confirm if Windows/macOS
    # support is needed.
    loader_args = dict(num_workers=cfg.base.num_workers,
                       pin_memory=cfg.base.pin_memory,
                       multiprocessing_context="fork")

    dataset_name = cfg.base.dataset_name
    if dataset_name in ("buidnewprocess", "kvasir", "isiconlytrain", "drive"):
        train_dataset = SegmentationDataset(dataset_name,
                                            cfg.dataloader.train_dir_img,
                                            cfg.dataloader.train_dir_mask,
                                            scale=scale)
    elif dataset_name in ("bts", "las_mri", "las_ct"):
        # These datasets go through the augmented pipeline for training.
        train_dataset = AugmentedSegmentationDataset(dataset_name,
                                                     cfg.dataloader.train_dir_img,
                                                     cfg.dataloader.train_dir_mask,
                                                     scale=scale)
    else:
        raise NameError(f"[Error] Dataset {dataset_name} is either in wrong format or not yet implemented!")

    # Validation and test always use the plain (non-augmented) dataset.
    val_dataset = SegmentationDataset(dataset_name,
                                      cfg.dataloader.valid_dir_img,
                                      cfg.dataloader.valid_dir_mask,
                                      scale=scale)
    test_dataset = SegmentationDataset(dataset_name,
                                       cfg.dataloader.test_dir_img,
                                       cfg.dataloader.test_dir_mask,
                                       scale=scale)

    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=cfg.train.train_batch_size,
                              **loader_args)
    val_loader = DataLoader(val_dataset,
                            shuffle=False,
                            drop_last=True,
                            batch_size=cfg.train.valid_batch_size,
                            **loader_args)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             drop_last=True,
                             batch_size=cfg.train.test_batch_size,
                             **loader_args)

    return train_loader, val_loader, test_loader, val_dataset, test_dataset