r"""
Defines the `BasicSegmentationDataset` and `CoronaryDataset` classes, which extend the `Dataset` and \
`BasicSegmentationDataset` classes, respectively. Each class defines the specific methods needed for \
data processing and a method :func:`__getitem__` to return samples.
"""
import logging
from glob import glob
from os import listdir
from os.path import join, splitext
from typing import Dict, List

import numpy
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

import utils.utils
from utils.augment import RetinaPolicy


class BasicSegmentationDataset(Dataset):
r"""
Implements a basic dataset for segmentation tasks, with methods for image and mask scaling and normalization. \
The filenames of the segmentation ground truths must be equal to the filenames of the images to be segmented, \
except for a possible suffix.
Args:
imgs_dir (str): path to the directory containing the images to be segmented.
masks_dir (str): path to the directory containing the segmentation ground truths.
scale (float, optional): image scale, between 0 and 1, to be used in the segmentation.
mask_suffix (str, optional): suffix to be added to an image's filename to obtain its
ground truth filename.
"""
def __init__(
self, imgs_dir: str, masks_dir: str, scale: float = 1, mask_suffix: str = ""
):
self.imgs_dir = imgs_dir
self.masks_dir = masks_dir
self.scale = scale
self.mask_suffix = mask_suffix
assert 0 < scale <= 1, "Scale must be between 0 and 1"
self.ids = [
splitext(file)[0] for file in listdir(imgs_dir) if not file.startswith(".")
]
logging.info(f"Creating dataset with {len(self.ids)} examples")
def __len__(self) -> int:
r"""
Returns the size of the dataset.
"""
return len(self.ids)
@classmethod
    def preprocess(cls, pil_img: Image.Image, scale: float) -> numpy.ndarray:
        r"""
        Preprocesses an `Image`, rescaling it and returning it as a NumPy array in
        the CHW format.
        Args:
            pil_img (Image): object of class `Image` to be preprocessed.
            scale (float): image scale, between 0 and 1.
        """
w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, "Scale is too small"
pil_img = pil_img.resize((newW, newH))
img_nd = numpy.array(pil_img)
if len(img_nd.shape) == 2:
img_nd = numpy.expand_dims(img_nd, axis=2)
# HWC to CHW
img_trans = img_nd.transpose((2, 0, 1))
        # Normalize pixel values to [0, 1]
        if img_trans.max() > 1:
            img_trans = img_trans / 255
return img_trans
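
    # Illustration (hypothetical values, kept as a comment so the module stays
    # import-safe): a 100x80 RGB image at scale 0.5 comes back as a CHW float
    # array with values in [0, 1]:
    #
    #   img = Image.new("RGB", (100, 80))
    #   arr = BasicSegmentationDataset.preprocess(img, scale=0.5)
    #   arr.shape  # (3, 40, 50)
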
    def __getitem__(self, i) -> Dict[str, List[torch.FloatTensor]]:
        r"""
        Returns a dict with two single-element lists: the image tensor and the corresponding mask tensor, both in CHW format.
        """
idx = self.ids[i]
        mask_file = glob(join(self.masks_dir, idx + self.mask_suffix + ".*"))
        img_file = glob(join(self.imgs_dir, idx + ".*"))
assert (
len(mask_file) == 1
), f"Either no mask or multiple masks found for the ID {idx}: {mask_file}"
assert (
len(img_file) == 1
), f"Either no image or multiple images found for the ID {idx}: {img_file}"
mask = Image.open(mask_file[0])
img = Image.open(img_file[0])
assert (
img.size == mask.size
), f"Image and mask {idx} should be the same size, but are {img.size} and {mask.size}"
img = self.preprocess(img, self.scale)
mask = self.preprocess(mask, self.scale)
return {
"image": [torch.from_numpy(img).type(torch.FloatTensor)],
"mask": [torch.from_numpy(mask).type(torch.FloatTensor)],
}
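

# Usage sketch (hypothetical paths and suffix, kept as a comment so the module
# stays import-safe): mask files are assumed to be named "<image>_mask.<ext>":
#
#   dataset = BasicSegmentationDataset(
#       "data/imgs/", "data/masks/", scale=0.5, mask_suffix="_mask"
#   )
#   sample = dataset[0]
#   img, mask = sample["image"][0], sample["mask"][0]  # FloatTensors, CHW
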
class CoronaryDataset(BasicSegmentationDataset):
r"""
    Implements a dataset for the coronary artery segmentation task.
Args:
imgs_dir (str): path to the directory containing the images to be segmented.
masks_dir (str): path to the directory containing the segmentation ground truths.
scale (float, optional): image scale, between 0 and 1, to be used in the segmentation.
augmentation_ratio (int, optional): number of augmentations to generate per image.
crop_size (int, optional): size of the square image to be fed to the model.
aug_policy (str, optional): data augmentation policy.
"""
# Number of classes, including the background class
n_classes = 2
    # Maps mask grayscale value to mask class index
gray2class_mapping = {0: 0, 255: 1}
# Maps mask grayscale value to mask RGB value
gray2rgb_mapping = {0: (0, 0, 0), 255: (255, 255, 255)}
    # Maps mask RGB value to mask class index
    rgb2class_mapping = {(0, 0, 0): 0, (255, 255, 255): 1}
def __init__(
self,
imgs_dir: str,
masks_dir: str,
scale: float = 1,
augmentation_ratio: int = 0,
crop_size: int = 512,
aug_policy: str = "retina",
):
super().__init__(imgs_dir, masks_dir, scale)
self.augmentation_ratio = augmentation_ratio
self.policy = aug_policy
self.crop_size = crop_size
@classmethod
    def mask_img2class_mask(cls, pil_mask: Image.Image, scale: float) -> numpy.ndarray:
r"""
Preprocesses a grayscale `Image` containing a segmentation mask, rescaling it, converting its grayscale values \
to class indices and returning it as a NumPy array in the CHW format.
Args:
            pil_mask (Image): object of class `Image` to be preprocessed.
scale (float): image scale, between 0 and 1.
"""
w, h = pil_mask.size
newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, "Scale is too small"
pil_mask = pil_mask.resize((newW, newH))
if pil_mask.mode != "L":
pil_mask = pil_mask.convert(mode="L")
mask_nd = numpy.array(pil_mask)
if len(mask_nd.shape) == 2:
mask_nd = numpy.expand_dims(mask_nd, axis=2)
# HWC to CHW
mask = mask_nd.transpose((2, 0, 1))
        # Map grayscale values {0, 255} to class indices {0, 1}
        mask = mask / 255
return mask
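
    # Illustration (hypothetical values, as a comment): a 2x2 {0, 255} mask
    # comes back as a 1HW array of class indices:
    #
    #   m = Image.fromarray(numpy.array([[0, 255], [255, 0]], dtype=numpy.uint8))
    #   CoronaryDataset.mask_img2class_mask(m, scale=1.0)
    #   # -> shape (1, 2, 2), values [[0., 1.], [1., 0.]]
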
    @classmethod
    def one_hot2mask(
        cls, one_hot_mask: torch.FloatTensor, shape: str = "CHW"
    ) -> numpy.ndarray:
        r"""
        Returns the single-channel index mask (HW or NHW) corresponding to the CHW or NCHW one-hot encoded one.
        """
        assert shape in ("CHW", "NCHW"), "Invalid input shape specified"
        # The class dimension is axis 0 for CHW and axis 1 for NCHW
        axis = 0 if shape == "CHW" else 1
        return numpy.argmax(one_hot_mask.detach().numpy(), axis=axis)
    @classmethod
    def mask2one_hot(
        cls, mask_tensor: torch.Tensor, output_shape: str = "NHWC"
    ) -> torch.Tensor:
        r"""
        Converts the received N1HW tensor of class indices to a one-hot encoded `LongTensor` in the NHWC shape. \
        Can return the NCHW shape if specified.
        Args:
            mask_tensor (Tensor): N1HW tensor to be one-hot encoded.
            output_shape (str): NHWC or NCHW.
        """
        assert output_shape in ("NHWC", "NCHW"), "Invalid output shape specified"
        # F.one_hot requires an integer tensor; N1HW -> N1HWC -> NHWC
        one_hot = F.one_hot(mask_tensor.long(), cls.n_classes).squeeze(1)
        if output_shape == "NHWC":
            return one_hot
        # Move the class dimension to position 1: NHWC -> NCHW
        return one_hot.permute(0, 3, 1, 2)
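
    # Round-trip sketch (hypothetical values, as a comment): one-hot encode an
    # N1HW index mask, then recover the indices with one_hot2mask:
    #
    #   m = torch.tensor([[[[0, 1], [1, 0]]]])            # shape (1, 1, 2, 2)
    #   oh = CoronaryDataset.mask2one_hot(m, "NCHW")      # shape (1, 2, 2, 2)
    #   CoronaryDataset.one_hot2mask(oh.float(), "NCHW")  # [[0, 1], [1, 0]]
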
    @classmethod
    def class2gray(cls, mask: numpy.ndarray) -> numpy.ndarray:
        r"""
        Replaces the class labels in a NumPy-array-represented mask by their grayscale values, according to `gray2class_mapping`.
        """
        assert (
            len(cls.gray2class_mapping) == cls.n_classes
        ), f"Number of class mappings - {len(cls.gray2class_mapping)} - should be the same as the number of classes - {cls.n_classes}"
        # Write into a copy so a grayscale value that collides with a later
        # class label cannot be remapped twice
        gray = mask.copy()
        for color, label in cls.gray2class_mapping.items():
            gray[mask == label] = color
        return gray
@classmethod
    def gray2rgb(cls, img: Image.Image) -> Image.Image:
r"""
Converts a grayscale image into an RGB one, according to gray2rgb_mapping.
"""
        rgb_img = Image.new("RGB", img.size)
        # Map each grayscale pixel to its RGB triple, pixel by pixel
        for x in range(img.size[0]):
for y in range(img.size[1]):
rgb_img.putpixel((x, y), cls.gray2rgb_mapping[img.getpixel((x, y))])
return rgb_img
@classmethod
    def mask2image(cls, mask: numpy.ndarray) -> Image.Image:
r"""
Converts a one-channel mask (1HW) with class indices into an RGB image, according to gray2class_mapping and gray2rgb_mapping.
"""
return cls.gray2rgb(Image.fromarray(cls.class2gray(mask).astype(numpy.uint8)))
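
    # Sketch (hypothetical values, as a comment): turn an HW index mask into a
    # black-and-white RGB image via class2gray and gray2rgb:
    #
    #   m = numpy.array([[0, 1], [1, 0]])
    #   CoronaryDataset.mask2image(m)  # 2x2 RGB image, white where class == 1
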
def augment(self, image, mask, policy="retina", augmentation_ratio=0):
"""
Returns a list with the original image and mask and augmented versions of them.
The number of augmented images and masks is equal to the specified augmentation_ratio.
The policy is chosen by the policy argument
"""
tf_imgs = []
tf_masks = []
# Data Augmentation
        for _ in range(augmentation_ratio):
            # Select the policy
            if policy == "retina":
                aug_policy = RetinaPolicy(
                    crop_dims=[self.crop_size, self.crop_size], brightness=[0.9, 1.1]
                )
            else:
                raise ValueError(f"Unknown augmentation policy: {policy}")
            # Apply the transformation
            tf_image, tf_mask = aug_policy(image, mask)
# Further process the images and masks
tf_image = self.preprocess(tf_image, self.scale)
tf_mask = self.mask_img2class_mask(tf_mask, self.scale)
tf_image = torch.from_numpy(tf_image).type(torch.FloatTensor)
tf_mask = torch.from_numpy(tf_mask).type(torch.FloatTensor)
tf_imgs.append(tf_image)
tf_masks.append(tf_mask)
        # Random-crop the original image and mask to the model input size
        i, j, h, w = transforms.RandomCrop.get_params(
            image, [self.crop_size, self.crop_size]
        )
image = transforms.functional.crop(image, i, j, h, w)
mask = transforms.functional.crop(mask, i, j, h, w)
image = self.preprocess(image, self.scale)
mask = self.mask_img2class_mask(mask, self.scale)
image = torch.from_numpy(image).type(torch.FloatTensor)
mask = torch.from_numpy(mask).type(torch.FloatTensor)
tf_imgs.insert(0, image)
tf_masks.insert(0, mask)
return (tf_imgs, tf_masks)
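
    # Sketch (hypothetical, as a comment; `ds` is an instantiated
    # CoronaryDataset): with augmentation_ratio=2 each list holds the cropped
    # original plus two augmented variants:
    #
    #   imgs, masks = ds.augment(image, mask, policy="retina", augmentation_ratio=2)
    #   len(imgs) == len(masks) == 3  # index 0 is the cropped original
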
    def __getitem__(self, i) -> Dict[str, List[torch.FloatTensor]]:
        r"""
        Returns a dict with two lists of tensors: images of shape CHW and the corresponding masks of shape 1HW.
        """
idx = self.ids[i]
        mask_file = glob(join(self.masks_dir, idx + ".*"))
        img_file = glob(join(self.imgs_dir, idx + ".*"))
assert (
len(mask_file) == 1
), f"Either no mask or multiple masks found for the ID {idx}: {mask_file}"
assert (
len(img_file) == 1
), f"Either no image or multiple images found for the ID {idx}: {img_file}"
mask = Image.open(mask_file[0])
image = Image.open(img_file[0])
        # Here we apply any changes to the image that we want for our specific prediction task
maskArray = numpy.array(mask).astype("uint8")
imageArray = numpy.array(image).astype("uint8")
        # Alternative channel encoding (kept for reference): mark the skeleton
        # endpoints of the centreline instead of a crude mask
        # endPoints = utils.utils.skelEndpoints(maskArray)
        # imageArray[:, :, -1] = endPoints.astype(numpy.uint8) * 255
        # Overwrite the last image channel with a crude mask derived from the ground truth
        crudeMask = utils.utils.crudeMaskGenerator(maskArray)
        imageArray[:, :, -1] = crudeMask.astype(numpy.uint8)
        # Convert back to a PIL Image object
        image = Image.fromarray(imageArray.astype(numpy.uint8))
assert (
image.size == mask.size
), f"Image and mask {idx} should be the same size, but are {image.size} and {mask.size}"
images, masks = self.augment(
image, mask, policy=self.policy, augmentation_ratio=self.augmentation_ratio
)
return {"image": images, "mask": masks}