Spaces:
Sleeping
Sleeping
| import argparse | |
| import logging | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from torch.backends import cudnn | |
| from torchvision import transforms | |
| import segmentation_models_pytorch.segmentation_models_pytorch as smp | |
| from utils.dataset import CoronaryDataset | |
| """ | |
| This uses a PyTorch coronary segmentation model (EfficientNetPlusPlus) that has been trained using a freely available dataset of labelled coronary angiograms from: http://personal.cimat.mx:8181/~ivan.cruz/DB_Angiograms.html | |
| The input is a raw angiogram image, and the output is a segmentation mask of all the arteries. This output will be used as the 'first guess' to speed up artery annotation. | |
| """ | |
def predict_img(net, dataset_class, full_img, device, scale_factor=1, n_classes=3):
    """Run the network on one image and return a mask at the image's size.

    Args:
        net: Model invoked as ``net(img)`` on a batched float32 tensor; it is
            switched to eval mode before inference.
        dataset_class: Object providing ``preprocess(full_img, scale_factor)``
            (whose result is handed to ``torch.from_numpy``) and
            ``one_hot2mask(tensor)``.
        full_img: Source image. Only ``full_img.size`` is read here and it is
            interpreted as ``(width, height)`` -- the PIL convention;
            presumably a ``PIL.Image``, confirm against the caller.
        device: Device the input tensor is moved to before the forward pass.
        scale_factor: Forwarded unchanged to ``dataset_class.preprocess``.
        n_classes: Number of possible values that can be predicted for a
            given pixel (2 for a standard binary, black-or-white task).
            Values above 1 select the softmax / ``one_hot2mask`` path,
            otherwise the sigmoid / 0.5-threshold path is used.

    Returns:
        ``dataset_class.one_hot2mask(full_mask)`` when ``n_classes > 1``,
        otherwise the boolean tensor ``full_mask > 0.5``.
    """
    net.eval()

    img = torch.from_numpy(dataset_class.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)  # add the batch dimension
    img = img.to(device=device, dtype=torch.float32)

    # Resize expects (height, width). Passing a single int only fixes the
    # *smaller* edge and keeps the aspect ratio, which yields a mask of the
    # wrong size whenever the image is taller than it is wide.
    width, height = full_img.size

    with torch.no_grad():
        output = net(img)

        if n_classes > 1:
            probs = torch.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)  # drop the batch dimension again

        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((height, width)),
                transforms.ToTensor(),
            ]
        )
        full_mask = tf(probs.cpu())

    if n_classes > 1:
        return dataset_class.one_hot2mask(full_mask)
    return full_mask > 0.5
def get_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with the attributes ``model`` (defaults to
        ``"MODEL.pth"``), ``input`` (required, one or more filenames) and
        ``output`` (one or more filenames, ``None`` when omitted).
    """
    parser = argparse.ArgumentParser(
        description="Predict masks from input images",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model",
        "-m",
        default="MODEL.pth",
        metavar="FILE",
        help="Specify the file in which the model is stored",
    )
    parser.add_argument(
        "--input",
        "-i",
        metavar="INPUT",
        nargs="+",
        help="filenames of input images",
        required=True,
    )
    parser.add_argument(
        "--output",
        "-o",
        # Was "INPUT" (copy-paste slip), which mislabelled the usage text.
        metavar="OUTPUT",
        nargs="+",
        help="Filenames of output images",
    )
    return parser.parse_args(argv)