# apec-segment / predict.py
# FatemehT's picture
# style: run pre-commit
# 8e6512c
import argparse
import logging
import os
import torch
import torch.nn as nn
from PIL import Image
from torch.backends import cudnn
from torchvision import transforms
import segmentation_models_pytorch.segmentation_models_pytorch as smp
from utils.dataset import CoronaryDataset
"""
This uses a PyTorch coronary segmentation model (EfficientNetPlusPlus) that has been trained on a freely available dataset of labelled coronary angiograms from: http://personal.cimat.mx:8181/~ivan.cruz/DB_Angiograms.html
The input is a raw angiogram image, and the output is a segmentation mask of all the arteries. This output will be used as the 'first guess' to speed up artery annotation.
"""
def predict_img(net, dataset_class, full_img, device, scale_factor=1, n_classes=3):
    """Predict a segmentation mask for a single input image.

    Args:
        net: Segmentation model. It is put into eval mode and called on a
            batch containing the single preprocessed image.
        dataset_class: Object providing ``preprocess(full_img, scale_factor)``
            (its result is passed to ``torch.from_numpy``) and
            ``one_hot2mask(full_mask)`` (only used when ``n_classes > 1``).
        full_img: Input image. Only ``full_img.size`` is read directly here;
            it is assumed to be a PIL image, i.e. ``size`` is
            ``(width, height)`` -- TODO confirm against callers.
        device: Device the input tensor is moved to before inference.
        scale_factor: Forwarded unchanged to ``dataset_class.preprocess``.
        n_classes: Number of possible values that can be predicted for a
            given pixel. In a standard binary segmentation task, this will
            be 2, i.e. black or white.

    Returns:
        ``dataset_class.one_hot2mask(full_mask)`` when ``n_classes > 1``;
        otherwise a boolean tensor that is True wherever the resized
        probability map exceeds 0.5.
    """
    net.eval()

    # Build a single-image batch on the target device.
    img = torch.from_numpy(dataset_class.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    # Resize with an explicit (height, width) target. Passing a single int to
    # ``transforms.Resize`` only matches the *smaller* edge, which yields a
    # wrongly sized mask for portrait images (height > width) and can be off
    # by a pixel in width when ``scale_factor`` is not 1.
    target_size = (full_img.size[1], full_img.size[0])

    with torch.no_grad():
        output = net(img)

        # Multi-class outputs are normalised across the channel dimension;
        # a single-channel output is treated as a per-pixel probability.
        if n_classes > 1:
            probs = torch.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        # Drop the batch dimension added above.
        probs = probs.squeeze(0)

        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(target_size),
                transforms.ToTensor(),
            ]
        )
        full_mask = tf(probs.cpu())

    if n_classes > 1:
        return dataset_class.one_hot2mask(full_mask)
    return full_mask > 0.5
def get_args(argv=None):
    """Parse the command-line arguments of the prediction script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``argparse`` reads the arguments from
            ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model`` (str,
        default ``"MODEL.pth"``), ``input`` (list of str, required) and
        ``output`` (list of str, or ``None`` when not given).
    """
    parser = argparse.ArgumentParser(
        description="Predict masks from input images",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # parser.add_argument('-d', '--dataset', type=str, help='Specifies the dataset to be used', dest='dataset', required=True)
    parser.add_argument(
        "--model",
        "-m",
        default="MODEL.pth",
        metavar="FILE",
        help="Specify the file in which the model is stored",
    )
    parser.add_argument(
        "--input",
        "-i",
        metavar="INPUT",
        nargs="+",
        help="filenames of input images",
        required=True,
    )
    # The metavar is OUTPUT (it was a copy-pasted "INPUT"), so the generated
    # help text labels the output filenames correctly.
    parser.add_argument(
        "--output",
        "-o",
        metavar="OUTPUT",
        nargs="+",
        help="Filenames of output images",
    )
    return parser.parse_args(argv)