r"""
Defines the `BasicSegmentationDataset` and `CoronaryDataset` classes, which extend the
`Dataset` and `BasicSegmentationDataset` classes, respectively. Each class defines the
specific methods needed for data processing and a :func:`__getitem__` method to return
samples.
"""
import logging
from glob import glob
from os import listdir
from os.path import splitext
from typing import Dict, List

import cv2
import numpy
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset

import utils.utils
from utils.augment import *  # expected to provide RetinaPolicy and `transforms`


class BasicSegmentationDataset(Dataset):
    r"""
    Implements a basic dataset for segmentation tasks, with methods for image and mask
    scaling and normalization. The filenames of the segmentation ground truths must
    match the filenames of the images to be segmented, except for an optional suffix.

    Args:
        imgs_dir (str): path to the directory containing the images to be segmented.
        masks_dir (str): path to the directory containing the segmentation ground truths.
        scale (float, optional): image scale, between 0 and 1, to be used in the segmentation.
        mask_suffix (str, optional): suffix appended to an image's filename to obtain its
            ground truth filename.
    """

    def __init__(
        self, imgs_dir: str, masks_dir: str, scale: float = 1, mask_suffix: str = ""
    ):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        self.mask_suffix = mask_suffix
        assert 0 < scale <= 1, "Scale must be between 0 and 1"
        self.ids = [
            splitext(file)[0] for file in listdir(imgs_dir) if not file.startswith(".")
        ]
        logging.info(f"Creating dataset with {len(self.ids)} examples")

    def __len__(self) -> int:
        r"""
        Returns the size of the dataset.
        """
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img: Image.Image, scale: float) -> numpy.ndarray:
        r"""
        Preprocesses an `Image`, rescaling it and returning it as a NumPy array in
        the CHW format, with values normalized to [0, 1].

        Args:
            pil_img (Image): object of class `Image` to be preprocessed.
            scale (float): image scale, between 0 and 1.
        """
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, "Scale is too small"
        pil_img = pil_img.resize((newW, newH))
        img_nd = numpy.array(pil_img)
        if len(img_nd.shape) == 2:
            # Add a channel dimension to grayscale images
            img_nd = numpy.expand_dims(img_nd, axis=2)
        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255
        return img_trans
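
    # For example (illustrative): a 100x80 RGB image preprocessed at scale 0.5
    # becomes a (3, 40, 50) array with values in [0, 1].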

    def __getitem__(self, i) -> Dict[str, List[torch.FloatTensor]]:
        r"""
        Returns a dict with two single-element lists: the image tensor and the
        corresponding mask tensor.
        """
        idx = self.ids[i]
        # NOTE: path concatenation assumes imgs_dir and masks_dir end with "/"
        mask_file = glob(self.masks_dir + idx + self.mask_suffix + ".*")
        img_file = glob(self.imgs_dir + idx + ".*")
        assert (
            len(mask_file) == 1
        ), f"Either no mask or multiple masks found for the ID {idx}: {mask_file}"
        assert (
            len(img_file) == 1
        ), f"Either no image or multiple images found for the ID {idx}: {img_file}"
        mask = Image.open(mask_file[0])
        img = Image.open(img_file[0])
        assert (
            img.size == mask.size
        ), f"Image and mask {idx} should be the same size, but are {img.size} and {mask.size}"
        img = self.preprocess(img, self.scale)
        mask = self.preprocess(mask, self.scale)
        return {
            "image": [torch.from_numpy(img).type(torch.FloatTensor)],
            "mask": [torch.from_numpy(mask).type(torch.FloatTensor)],
        }
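

# Illustrative usage of BasicSegmentationDataset (the paths below are hypothetical;
# the glob-based file lookup assumes directory paths end with a trailing "/"):
#
#   dataset = BasicSegmentationDataset("data/imgs/", "data/masks/", scale=0.5)
#   sample = dataset[0]
#   image, mask = sample["image"][0], sample["mask"][0]  # CHW float tensors in [0, 1]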


class CoronaryDataset(BasicSegmentationDataset):
    r"""
    Implements a dataset for the coronary artery segmentation task.

    Args:
        imgs_dir (str): path to the directory containing the images to be segmented.
        masks_dir (str): path to the directory containing the segmentation ground truths.
        scale (float, optional): image scale, between 0 and 1, to be used in the segmentation.
        augmentation_ratio (int, optional): number of augmentations to generate per image.
        crop_size (int, optional): size of the square image to be fed to the model.
        aug_policy (str, optional): data augmentation policy.
    """

    # Number of classes, including the background class
    n_classes = 2
    # Maps mask grayscale value to mask class index
    gray2class_mapping = {0: 0, 255: 1}
    # Maps mask grayscale value to mask RGB value
    gray2rgb_mapping = {0: (0, 0, 0), 255: (255, 255, 255)}
    # Maps mask RGB value to mask class index
    rgb2class_mapping = {(0, 0, 0): 0, (255, 255, 255): 1}

    def __init__(
        self,
        imgs_dir: str,
        masks_dir: str,
        scale: float = 1,
        augmentation_ratio: int = 0,
        crop_size: int = 512,
        aug_policy: str = "retina",
    ):
        super().__init__(imgs_dir, masks_dir, scale)
        self.augmentation_ratio = augmentation_ratio
        self.policy = aug_policy
        self.crop_size = crop_size

    @classmethod
    def mask_img2class_mask(cls, pil_mask: Image.Image, scale: float) -> numpy.ndarray:
        r"""
        Preprocesses a grayscale `Image` containing a segmentation mask, rescaling it,
        converting its grayscale values to class indices and returning it as a NumPy
        array in the CHW format.

        Args:
            pil_mask (Image): object of class `Image` to be preprocessed.
            scale (float): image scale, between 0 and 1.
        """
        w, h = pil_mask.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, "Scale is too small"
        pil_mask = pil_mask.resize((newW, newH))
        if pil_mask.mode != "L":
            pil_mask = pil_mask.convert(mode="L")
        mask_nd = numpy.array(pil_mask)
        if len(mask_nd.shape) == 2:
            mask_nd = numpy.expand_dims(mask_nd, axis=2)
        # HWC to CHW
        mask = mask_nd.transpose((2, 0, 1))
        # Grayscale values {0, 255} map to class indices {0, 1}
        mask = mask / 255
        return mask
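
    # For example (illustrative): a 100x80 binary {0, 255} mask at scale 1 becomes
    # a (1, 80, 100) array of class indices in {0, 1}.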

    @classmethod
    def one_hot2mask(
        cls, one_hot_mask: torch.FloatTensor, shape: str = "CHW"
    ) -> numpy.ndarray:
        r"""
        Returns the one-channel mask (HW, or NHW for batches) corresponding to a
        CHW or NCHW one-hot encoded one.
        """
        # The class dimension is axis 0 for CHW tensors and axis 1 for NCHW ones
        axis = 1 if shape == "NCHW" else 0
        return numpy.argmax(one_hot_mask.detach().numpy(), axis=axis)
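
    # For example (illustrative): a (2, H, W) one-hot tensor with shape="CHW"
    # yields an (H, W) array of class indices in {0, 1}.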

    @classmethod
    def mask2one_hot(
        cls, mask_tensor: torch.Tensor, output_shape: str = "NHWC"
    ) -> torch.Tensor:
        r"""
        Converts a tensor of class indices in the N1HW shape into a one-hot encoded
        `LongTensor` in the NHWC shape, or in the NCHW shape if specified.

        Args:
            mask_tensor (Tensor): N1HW tensor of class indices to be one-hot encoded.
            output_shape (str): NHWC or NCHW.
        """
        assert output_shape in ("NHWC", "NCHW"), "Invalid output shape specified"
        # F.one_hot requires integer indices; squeeze the singleton channel so the
        # encoding is appended as the last dimension: N1HW -> NHW -> NHWC
        one_hot = F.one_hot(mask_tensor.squeeze(1).long(), cls.n_classes)
        if output_shape == "NHWC":
            return one_hot
        # NHWC -> NCHW
        return one_hot.permute(0, 3, 1, 2)
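
    # For example (illustrative): a (4, 1, 64, 64) index tensor one-hot encodes to
    # (4, 64, 64, 2) for NHWC, or (4, 2, 64, 64) for NCHW.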

    @classmethod
    def class2gray(cls, mask: numpy.ndarray) -> numpy.ndarray:
        r"""
        Replaces the class labels in a NumPy-array mask with their grayscale values,
        according to `gray2class_mapping`.
        """
        assert (
            len(cls.gray2class_mapping) == cls.n_classes
        ), f"Number of class mappings - {len(cls.gray2class_mapping)} - should be the same as the number of classes - {cls.n_classes}"
        for color, label in cls.gray2class_mapping.items():
            mask[mask == label] = color
        return mask

    @classmethod
    def gray2rgb(cls, img: Image.Image) -> Image.Image:
        r"""
        Converts a grayscale image into an RGB one, according to `gray2rgb_mapping`.
        """
        rgb_img = Image.new("RGB", img.size)
        for x in range(img.size[0]):
            for y in range(img.size[1]):
                rgb_img.putpixel((x, y), cls.gray2rgb_mapping[img.getpixel((x, y))])
        return rgb_img

    @classmethod
    def mask2image(cls, mask: numpy.ndarray) -> Image.Image:
        r"""
        Converts a one-channel mask (HW) with class indices into an RGB image,
        according to `gray2class_mapping` and `gray2rgb_mapping`.
        """
        return cls.gray2rgb(Image.fromarray(cls.class2gray(mask).astype(numpy.uint8)))
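
    # For example (illustrative): mask2image(one_hot2mask(pred)) renders a model's
    # CHW one-hot prediction `pred` as a black-and-white PIL image.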

    def augment(self, image, mask, policy="retina", augmentation_ratio=0):
        r"""
        Returns a list with the original image and mask, plus augmented versions of
        them. The number of augmented images and masks is equal to the specified
        `augmentation_ratio`; the augmentation policy is chosen by the `policy` argument.
        """
        tf_imgs = []
        tf_masks = []
        # Data augmentation
        for i in range(augmentation_ratio):
            # Select the policy (RetinaPolicy comes from utils.augment)
            if policy == "retina":
                aug_policy = RetinaPolicy(
                    crop_dims=[self.crop_size, self.crop_size], brightness=[0.9, 1.1]
                )
            else:
                raise ValueError(f"Unknown augmentation policy: {policy}")
            # Apply the transformation
            tf_image, tf_mask = aug_policy(image, mask)
            # Further process the images and masks
            tf_image = self.preprocess(tf_image, self.scale)
            tf_mask = self.mask_img2class_mask(tf_mask, self.scale)
            tf_image = torch.from_numpy(tf_image).type(torch.FloatTensor)
            tf_mask = torch.from_numpy(tf_mask).type(torch.FloatTensor)
            tf_imgs.append(tf_image)
            tf_masks.append(tf_mask)
        # Random-crop the original image and mask to the model's input size
        i, j, h, w = transforms.RandomCrop.get_params(
            image, [self.crop_size, self.crop_size]
        )
        image = transforms.functional.crop(image, i, j, h, w)
        mask = transforms.functional.crop(mask, i, j, h, w)
        image = self.preprocess(image, self.scale)
        mask = self.mask_img2class_mask(mask, self.scale)
        image = torch.from_numpy(image).type(torch.FloatTensor)
        mask = torch.from_numpy(mask).type(torch.FloatTensor)
        # The un-augmented crop goes first in the returned lists
        tf_imgs.insert(0, image)
        tf_masks.insert(0, mask)
        return (tf_imgs, tf_masks)
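
    # For example (illustrative): with augmentation_ratio=2, augment returns lists
    # of 3 image tensors and 3 mask tensors, the un-augmented crop first.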

    def __getitem__(self, i) -> Dict[str, List[torch.FloatTensor]]:
        r"""
        Returns a dict with two lists of tensors: images, of shape CHW, and the
        corresponding masks, of shape 1HW.
        """
        idx = self.ids[i]
        mask_file = glob(f"{self.masks_dir}{idx}.*")
        img_file = glob(f"{self.imgs_dir}{idx}.*")
        assert (
            len(mask_file) == 1
        ), f"Either no mask or multiple masks found for the ID {idx}: {mask_file}"
        assert (
            len(img_file) == 1
        ), f"Either no image or multiple images found for the ID {idx}: {img_file}"
        mask = Image.open(mask_file[0])
        image = Image.open(img_file[0])
        # Here we apply any changes to the image that we want for our specific
        # prediction task
        maskArray = numpy.array(mask).astype("uint8")
        imageArray = numpy.array(image).astype("uint8")
        # Alternative: mark the skeleton endpoints (start and end of the centreline)
        # in the last channel instead of the crude mask:
        # endPoints = utils.utils.skelEndpoints(maskArray)
        # imageArray[:, :, -1] = endPoints.astype(numpy.uint8) * 255
        # Replace the last image channel with a crude mask derived from the ground truth
        crudeMask = utils.utils.crudeMaskGenerator(maskArray)
        imageArray[:, :, -1] = crudeMask.astype(numpy.uint8)
        # Reconvert to a PIL Image object
        image = Image.fromarray(imageArray.astype(numpy.uint8))
        assert (
            image.size == mask.size
        ), f"Image and mask {idx} should be the same size, but are {image.size} and {mask.size}"
        images, masks = self.augment(
            image, mask, policy=self.policy, augmentation_ratio=self.augmentation_ratio
        )
        return {"image": images, "mask": masks}