Spaces:
Sleeping
Sleeping
File size: 2,505 Bytes
b5cbaa6 8e6512c b5cbaa6 8e6512c b5cbaa6 8e6512c b5cbaa6 8e6512c b5cbaa6 8e6512c b5cbaa6 8e6512c b5cbaa6 8e6512c b5cbaa6 8e6512c b5cbaa6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import argparse
import logging
import os
import torch
import torch.nn as nn
from PIL import Image
from torch.backends import cudnn
from torchvision import transforms
import segmentation_models_pytorch.segmentation_models_pytorch as smp
from utils.dataset import CoronaryDataset
"""
This script uses a PyTorch coronary-artery segmentation model (EfficientNetPlusPlus) trained on a freely available dataset of labelled coronary angiograms from: http://personal.cimat.mx:8181/~ivan.cruz/DB_Angiograms.html
The input is a raw angiogram image and the output is a segmentation mask of all the arteries. This mask is intended as a 'first guess' to speed up manual artery annotation.
"""
def predict_img(net, dataset_class, full_img, device, scale_factor=1, n_classes=3):
    """Run ``net`` on one image and return a mask at the image's original size.

    The image is preprocessed with ``dataset_class.preprocess``, passed through
    the network, converted to per-pixel probabilities, resized back to the
    exact size of ``full_img`` and finally turned into a mask.

    Args:
        net: Segmentation network (a ``torch.nn.Module``); it is switched to
            eval mode. Its output is assumed to be ``(1, channels, h, w)`` —
            TODO confirm against the model definition.
        dataset_class: Object providing ``preprocess(img, scale_factor)``
            (returning a NumPy array, presumably ``(C, h, w)`` — verify) and,
            for the multi-class case, ``one_hot2mask(probs)``.
        full_img: Source image. Assumed to be a PIL image, so
            ``full_img.size`` is ``(width, height)``.
        device: Device the input tensor is moved to before inference.
        scale_factor: Forwarded unchanged to ``dataset_class.preprocess``.
        n_classes: Selects the activation. ``> 1`` applies a softmax over the
            channel axis (multi-class); otherwise a per-pixel sigmoid is
            applied and thresholded at 0.5 (single-channel binary output).

    Returns:
        For ``n_classes > 1``: whatever ``dataset_class.one_hot2mask`` returns
        for the ``(channels, height, width)`` probability tensor.
        Otherwise: a boolean tensor of shape ``(1, height, width)`` that is
        ``True`` where the probability exceeds 0.5.
    """
    net.eval()
    batch = torch.from_numpy(dataset_class.preprocess(full_img, scale_factor))
    # Add a batch dimension and match the network's expected dtype/device.
    batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)
        if n_classes > 1:
            probs = torch.softmax(logits, dim=1)
        else:
            probs = torch.sigmoid(logits)

        # Resize to the exact (height, width) of the source image. The former
        # ``transforms.Resize(height)`` only matched the *shorter* edge, so
        # tall images came back with the wrong shape. The ToPILImage/ToTensor
        # round trip also quantised probabilities to 8 bits and rejected
        # tensors with more than 4 channels.
        width, height = full_img.size
        probs = nn.functional.interpolate(
            probs, size=(height, width), mode="bilinear", align_corners=False
        )
        full_mask = probs.squeeze(0).cpu()

    if n_classes > 1:
        return dataset_class.one_hot2mask(full_mask)
    return full_mask > 0.5
def get_args(argv=None):
    """Parse command-line options for mask prediction.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``model`` (str, default ``"MODEL.pth"``),
        ``input`` (non-empty list of str, required) and ``output`` (list of
        str, or ``None`` when ``--output``/``-o`` is not given).
    """
    parser = argparse.ArgumentParser(
        description="Predict masks from input images",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model",
        "-m",
        default="MODEL.pth",
        metavar="FILE",
        help="Specify the file in which the model is stored",
    )
    parser.add_argument(
        "--input",
        "-i",
        metavar="INPUT",
        nargs="+",
        help="filenames of input images",
        required=True,
    )
    parser.add_argument(
        "--output",
        "-o",
        # Previously "INPUT", which mislabelled this option in --help output.
        metavar="OUTPUT",
        nargs="+",
        help="Filenames of output images",
    )
    return parser.parse_args(argv)
|