File size: 5,295 Bytes
be2715b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import math
import os
import random
import numpy as np
from torch.utils.data import Dataset
import nibabel
from scipy import ndimage
import glob
from skimage.io import imsave
from utils.dataset_prepare import split_data, save_fileLabel_3D
def __itensity_normalize_one_volume__(image):
    """Clip, standardise and rescale one intensity volume into ``[0, 1]``.

    Normalisation follows Med3D as used by this script:
    clip intensities to the [0.5, 99.5] percentile range, shift to zero mean
    and unit standard deviation, divide by 10, then clamp to ``[0, 1]``
    (so every voxel at or below the mean maps to 0).

    Args:
        image: array-like intensity volume (the whole volume, any shape).
            It is NOT modified.

    Returns:
        A new float64 array of the same shape with values in ``[0, 1]``.
        A constant volume (zero standard deviation) yields all zeros rather
        than NaNs.
    """
    # Work on a float array and let np.clip allocate the result: clipping in
    # place would silently alter the caller's array (e.g. a cached
    # nibabel get_fdata() result).
    image = np.asarray(image, dtype=np.float64)
    top_per = np.percentile(image, 99.5)
    bot_per = np.percentile(image, 0.5)
    image = np.clip(image, bot_per, top_per)

    std = np.std(image)
    if std == 0:
        # No contrast at all: (x - mean) / 0 would be NaN everywhere.
        return np.zeros_like(image)
    image = (image - np.mean(image)) / std
    return np.clip(image / 10.0, 0.0, 1.0)
def __training_data_process__(data, label):
    """Extract and prepare the voxel arrays of one image/label pair.

    No cropping or resizing is done; both volumes keep their native shape.

    Args:
        data: loaded image object exposing ``get_fdata()`` (nibabel image)
            holding the intensity volume.
        label: loaded image object exposing ``get_fdata()`` holding the
            segmentation label volume.

    Returns:
        ``(data, label)`` float arrays: ``data`` normalised into ``[0, 1]``
        by ``__itensity_normalize_one_volume__``, and ``label`` with the raw
        label codes listed in ``label_map`` below replaced by small values
        that fit in ``uint8``. Codes not in the table are left unchanged.
    """
    data = data.get_fdata()
    data = __itensity_normalize_one_volume__(data)

    # Raw label code -> remapped value. Presumably the MM-WHS structure
    # codes (TODO confirm against the dataset documentation).
    label_map = {
        205: 30,
        420: 60,
        500: 90,
        550: 120,
        600: 150,
        820: 180,
        850: 210,
    }
    raw_label = label.get_fdata()
    # Remap a copy so nibabel's cached array is left untouched, and match
    # against the untouched raw codes so a value written by one mapping can
    # never be re-matched as the source code of a later one.
    label = raw_label.copy()
    for src, dst in label_map.items():
        label[raw_label == src] = dst
    return data, label
def preprocess_vol(img_name, label_name):
    """Load one image/label volume pair and convert both to ``uint8`` arrays.

    No axis transposition is performed here; arrays keep nibabel's
    on-disk axis order.

    Args:
        img_name: path to the intensity volume file.
        label_name: path to the matching label volume file.

    Returns:
        ``(img, mask)`` uint8 arrays of identical shape: the normalised image
        scaled from ``[0, 1]`` to ``[0, 255]`` (truncated, not rounded) and
        the remapped label volume.

    Raises:
        FileNotFoundError: if either path is not an existing file.
        ValueError: if image and label shapes differ, or if a label value
            does not fit in ``uint8`` (e.g. a code missing from the remap
            table), which a plain ``astype('uint8')`` would silently wrap.
    """
    # Explicit checks instead of ``assert``, which is stripped under ``python -O``.
    for path in (img_name, label_name):
        if not os.path.isfile(path):
            raise FileNotFoundError("No such volume file: {}".format(path))
    # nibabel.load raises on failure rather than returning None, so no
    # None-check is needed on its result.
    img = nibabel.load(img_name)
    mask = nibabel.load(label_name)

    img_array, mask_array = __training_data_process__(img, mask)
    if img_array.shape != mask_array.shape:
        raise ValueError(
            "img shape:{} is not equal to mask shape:{}".format(img_array.shape, mask_array.shape)
        )

    out_of_range = (mask_array < 0) | (mask_array > 255)
    if out_of_range.any():
        bad_values = np.unique(mask_array[out_of_range]).tolist()
        raise ValueError(
            "{}: label values {} do not fit in uint8".format(label_name, bad_values)
        )
    return (img_array * 255).astype('uint8'), mask_array.astype('uint8')
#if __name__ == '__main__':
def _export_split_slices(data_dir, dest_dir, dest_dir_label):
    """Write every image/label volume in ``data_dir`` as per-slice PNG files.

    Files whose name contains ``'label'`` are skipped as inputs. Each remaining
    image's label file is the same file name with ``'image'`` replaced by
    ``'label'``. Slices are taken along array axis 2 (the depth dimension,
    per the original author's note). They are written as
    ``<stem>_frame_<NNN>.png`` into ``dest_dir`` (image) and
    ``dest_dir_label`` (label).
    """
    os.makedirs(dest_dir, exist_ok=True)
    os.makedirs(dest_dir_label, exist_ok=True)
    # Sorted only for a deterministic processing/log order; os.listdir's
    # order is arbitrary.
    for vol_name in sorted(os.listdir(data_dir)):
        if 'label' in vol_name:
            continue
        # Substitute inside the file name only: rewriting the joined path
        # would also corrupt any directory component containing 'image'.
        label_vol_name = vol_name.replace('image', 'label')
        img, mask = preprocess_vol(
            os.path.join(data_dir, vol_name),
            os.path.join(data_dir, label_vol_name),
        )
        print(img.shape, mask.shape)

        img_stem = vol_name.split('.')[0]
        label_stem = label_vol_name.split('.')[0]
        for depth in range(img.shape[2]):
            frame_suffix = "_frame_{:03d}.png".format(depth)
            imsave(os.path.join(dest_dir, img_stem + frame_suffix),
                   img[:, :, depth], check_contrast=False)
            imsave(os.path.join(dest_dir_label, label_stem + frame_suffix),
                   mask[:, :, depth], check_contrast=False)


def MMWHS_MR_Heart_split():
    """Slice the MMWHS_MR_Heart train/valid/test volumes into PNG frames.

    For each split, reads ``./dataset_demo/MMWHS_MR_Heart/Raw/<split>/`` and
    writes image slices to ``./dataset_demo/MMWHS_MR_Heart/<split>/`` and
    label slices to ``./dataset_demo/MMWHS_MR_Heart/<split>_labels/``, then
    calls ``save_fileLabel_3D`` for the dataset.
    """
    dataset_name = "MMWHS_MR_Heart"
    root = './dataset_demo/MMWHS_MR_Heart'
    for split in ('train', 'valid', 'test'):
        _export_split_slices(
            data_dir='{}/Raw/{}/'.format(root, split),
            dest_dir='{}/{}/'.format(root, split),
            dest_dir_label='{}/{}_labels/'.format(root, split),
        )
    save_fileLabel_3D(dataset_name)