|
import math |
|
import os |
|
import random |
|
import numpy as np |
|
from torch.utils.data import Dataset |
|
import nibabel |
|
from scipy import ndimage |
|
import glob |
|
from skimage.io import imsave |
|
from utils.dataset_prepare import split_data, save_fileLabel_3D |
|
|
|
|
|
def __itensity_normalize_one_volume__(image):
    """Normalize the intensities of one volume into the range [0, 1].

    Steps:
      1. Clamp values to the [0.5th, 99.5th] percentile range of the input.
      2. Standardize the clamped values (subtract mean, divide by std).
      3. Divide by 10 and clip to [0, 1]. As a consequence, every voxel at
         or below the mean ends up at 0, and 1 is only reached 10 standard
         deviations above the mean.

    Args:
        image: array-like of intensities. It is never modified; all work
            is done on a private copy.

    Returns:
        A new floating-point array with the same shape as ``image`` and
        values in [0, 1]. A volume whose clamped values are all identical
        (zero standard deviation) yields all zeros instead of NaNs.
    """
    # Work on a copy so the caller's array is left untouched.
    volume = np.array(image)
    if not np.issubdtype(volume.dtype, np.floating):
        # Integer input would silently truncate the percentile bounds
        # when they are written back into the array.
        volume = volume.astype(np.float64)

    # Both bounds are taken from the unclamped data.
    top_per = np.percentile(volume, 99.5)
    bot_per = np.percentile(volume, 0.5)
    volume[volume > top_per] = top_per
    volume[volume < bot_per] = bot_per

    spread = np.std(volume)
    if spread == 0:
        # Constant volume: standardizing would divide 0 by 0.
        return np.zeros_like(volume)

    volume = (volume - np.mean(volume)) / spread
    volume = volume / 10.0
    return np.clip(volume, 0.0, 1.0)
|
|
|
|
|
def __training_data_process__(data, label):
    """Turn a loaded image/label pair into processed voxel arrays.

    Args:
        data: image object exposing ``get_fdata()``.
        label: label object exposing ``get_fdata()``.

    Returns:
        ``(intensities, annotations)``: the image voxels after
        ``__itensity_normalize_one_volume__`` and the label voxels with
        the raw label codes below replaced by multiples of 30.
    """
    intensities = data.get_fdata()
    annotations = label.get_fdata()

    intensities = __itensity_normalize_one_volume__(intensities)

    # Raw label value -> replacement value. Every replacement is <= 210,
    # so it survives the uint8 cast applied by preprocess_vol.
    # NOTE(review): any raw value not listed here is left unchanged.
    label_remap = (
        (205, 30),
        (420, 60),
        (500, 90),
        (550, 120),
        (600, 150),
        (820, 180),
        (850, 210),
    )
    for raw_value, gray_value in label_remap:
        annotations[annotations == raw_value] = gray_value

    return intensities, annotations
|
|
|
|
|
def preprocess_vol(img_name, label_name):
    """Load an image/label volume pair and return them as uint8 arrays.

    Args:
        img_name: path of the image volume, loaded with ``nibabel.load``.
        label_name: path of the matching label volume.

    Returns:
        ``(image, mask)``: the processed image multiplied by 255 and cast
        to uint8, and the processed label array cast to uint8.

    Raises:
        AssertionError: if either path is not an existing file, or the two
            processed arrays differ in shape.
            NOTE: these checks are ``assert`` statements and are therefore
            skipped when Python runs with ``-O``.
    """
    for required_path in (img_name, label_name):
        assert os.path.isfile(required_path)

    volume_nii = nibabel.load(img_name)
    assert volume_nii is not None
    label_nii = nibabel.load(label_name)
    assert label_nii is not None

    intensities, annotations = __training_data_process__(volume_nii, label_nii)
    shapes = (intensities.shape, annotations.shape)
    assert shapes[0] == shapes[1], "img shape:{} is not equal to mask shape:{}".format(*shapes)

    image_u8 = (intensities * 255).astype('uint8')
    mask_u8 = annotations.astype('uint8')
    return image_u8, mask_u8
|
|
|
|
|
def _export_split_slices(data_dir, dest_dir, dest_dir_label):
    """Save every volume found in ``data_dir`` as one PNG per slice.

    Entries whose name contains ``'label'`` are skipped as inputs; for each
    remaining entry the label volume is looked up under the same name with
    ``'image'`` replaced by ``'label'``. Slices are taken along the third
    array axis and written as ``<stem>_frame_<NNN>.png``, images into
    ``dest_dir`` and masks into ``dest_dir_label``.
    """
    # List first so a missing source directory fails before anything is
    # created; sort so runs are processed (and logged) in a stable order.
    vol_names = sorted(os.listdir(data_dir))

    os.makedirs(dest_dir, exist_ok=True)
    os.makedirs(dest_dir_label, exist_ok=True)

    for vol_name in vol_names:
        if 'label' in vol_name:
            continue

        # Substitute in the file name only, never in the directory part.
        label_vol_name = vol_name.replace('image', 'label')
        img, mask = preprocess_vol(
            os.path.join(data_dir, vol_name),
            os.path.join(data_dir, label_vol_name),
        )
        print(img.shape, mask.shape)

        img_stem = vol_name.split('.')[0]
        mask_stem = label_vol_name.split('.')[0]
        for depth in range(img.shape[2]):
            frame_suffix = '_frame_' + str(depth).zfill(3) + '.png'
            imsave(os.path.join(dest_dir, img_stem + frame_suffix), img[:, :, depth], check_contrast=False)
            imsave(os.path.join(dest_dir_label, mask_stem + frame_suffix), mask[:, :, depth], check_contrast=False)


def MMWHS_CT_Heart_split():
    """Export the train, valid and test volumes of MMWHS_CT_Heart as PNGs.

    For each split, volumes are read from
    ``./dataset_demo/MMWHS_CT_Heart/Raw/<split>/`` and their slices are
    written to ``./dataset_demo/MMWHS_CT_Heart/<split>/`` (images) and
    ``./dataset_demo/MMWHS_CT_Heart/<split>_labels/`` (masks). Afterwards
    ``save_fileLabel_3D`` is called with the dataset name.
    NOTE(review): presumably that helper indexes the exported files;
    confirm against utils.dataset_prepare.
    """
    dataset_name = "MMWHS_CT_Heart"
    root = './dataset_demo/MMWHS_CT_Heart/'

    for split in ('train', 'valid', 'test'):
        _export_split_slices(
            root + 'Raw/' + split + '/',
            root + split + '/',
            root + split + '_labels/',
        )

    save_fileLabel_3D(dataset_name)