from torch.utils.data import DataLoader
from dataloader.dataset import (
    SegmentationDataset,
    AugmentedSegmentationDataset
)

def sam_dataloader(cfg):
    """
    Build the train/val/test DataLoaders for SAM fine-tuning.

    SAM's mask decoder predicts 256x256 masks by default (i.e., without
    the upscaling postprocessing step), so the ground-truth masks are
    resized to 256x256 to match the raw predictions.
    """
    loader_args = dict(num_workers=cfg.base.num_workers,
                       pin_memory=cfg.base.pin_memory)
    # Select the dataset class: the bts/las datasets use the variant
    # with training-time augmentation.
    if cfg.base.dataset_name in ["buidnewprocess", "kvasir", "isiconlytrain", "drive"]:
        train_dataset = SegmentationDataset(cfg.base.dataset_name,
                                            cfg.dataloader.train_dir_img,
                                            cfg.dataloader.train_dir_mask,
                                            scale=(1024, 256))
    elif cfg.base.dataset_name in ["bts", "las_mri", "las_ct"]:
        train_dataset = AugmentedSegmentationDataset(cfg.base.dataset_name,
                                                     cfg.dataloader.train_dir_img,
                                                     cfg.dataloader.train_dir_mask,
                                                     scale=(1024, 256))
    else:
        raise NameError(f"[Error] Dataset {cfg.base.dataset_name} is either in the wrong format or not yet implemented!")

    val_dataset = SegmentationDataset(cfg.base.dataset_name,
                                      cfg.dataloader.valid_dir_img,
                                      cfg.dataloader.valid_dir_mask,
                                      scale=(1024, 256))
    test_dataset = SegmentationDataset(cfg.base.dataset_name,
                                       cfg.dataloader.test_dir_img,
                                       cfg.dataloader.test_dir_mask,
                                       scale=(1024, 256))
    # The "fork" start method spawns workers by copying the parent process
    # instead of re-importing the module (note: fork is unavailable on Windows).
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=cfg.train.train_batch_size,
                              multiprocessing_context="fork",
                              **loader_args)
    # drop_last=True discards the final incomplete batch so every
    # evaluation batch has the same size.
    val_loader = DataLoader(val_dataset,
                            shuffle=False,
                            drop_last=True,
                            batch_size=cfg.train.valid_batch_size,
                            multiprocessing_context="fork",
                            **loader_args)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             drop_last=True,
                             batch_size=cfg.train.test_batch_size,
                             multiprocessing_context="fork",
                             **loader_args)
    
    return train_loader, val_loader, test_loader, val_dataset, test_dataset
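

# --- Minimal usage sketch (not from this repo): builds a config with the
# attribute layout sam_dataloader expects. The directory paths and batch
# sizes below are hypothetical placeholders; an OmegaConf DictConfig is
# assumed here since it supports the cfg.base.* attribute access used above.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "base": {"dataset_name": "kvasir", "num_workers": 4, "pin_memory": True},
        "dataloader": {
            "train_dir_img": "data/kvasir/train/images",   # hypothetical path
            "train_dir_mask": "data/kvasir/train/masks",   # hypothetical path
            "valid_dir_img": "data/kvasir/valid/images",   # hypothetical path
            "valid_dir_mask": "data/kvasir/valid/masks",   # hypothetical path
            "test_dir_img": "data/kvasir/test/images",     # hypothetical path
            "test_dir_mask": "data/kvasir/test/masks",     # hypothetical path
        },
        "train": {"train_batch_size": 4, "valid_batch_size": 1, "test_batch_size": 1},
    })
    train_loader, val_loader, test_loader, val_ds, test_ds = sam_dataloader(cfg)
    print(f"{len(train_loader)} train batches, {len(val_ds)} val samples")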