File size: 2,505 Bytes
b5cbaa6
 
 
 
 
 
 
8e6512c
b5cbaa6
 
 
8e6512c
b5cbaa6
8e6512c
b5cbaa6
8e6512c
 
 
b5cbaa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e6512c
b5cbaa6
 
 
8e6512c
b5cbaa6
 
 
 
 
 
 
 
8e6512c
 
 
 
b5cbaa6
8e6512c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5cbaa6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
import logging
import os

import torch
import torch.nn as nn
from PIL import Image
from torch.backends import cudnn
from torchvision import transforms

import segmentation_models_pytorch.segmentation_models_pytorch as smp
from utils.dataset import CoronaryDataset

"""
This uses a PyTorch coronary segmentation model (EfficientNetPlusPlus) that has been trained on a freely available dataset of labelled coronary angiograms from: http://personal.cimat.mx:8181/~ivan.cruz/DB_Angiograms.html
The input is a raw angiogram image, and the output is a segmentation mask of all the arteries. This output will be used as the 'first guess' to speed up artery annotation.
"""


def predict_img(net, dataset_class, full_img, device, scale_factor=1, n_classes=3):
    """Run a single angiogram image through the segmentation network.

    Args:
        net: trained segmentation model (a ``torch.nn.Module``).
        dataset_class: dataset object providing ``preprocess`` (PIL image ->
            numpy array) and, for multi-class output, ``one_hot2mask``
            (e.g. CoronaryDataset).
        full_img: input PIL image; the prediction is resized back to this
            image's dimensions.
        device: torch device on which inference is run.
        scale_factor: scale passed through to ``dataset_class.preprocess``.
        n_classes: number of possible values predictable for a given pixel.
            In a standard binary segmentation task this is 2 (black or white).

    Returns:
        For ``n_classes > 1``: whatever ``dataset_class.one_hot2mask``
        produces from the resized probability map. Otherwise: a boolean
        tensor of the sigmoid output thresholded at 0.5.
    """
    net.eval()

    img = torch.from_numpy(dataset_class.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)  # add batch dimension
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        # Multi-class: per-pixel class probabilities; binary: sigmoid.
        if n_classes > 1:
            probs = torch.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)  # drop batch dimension

        # Resize the prediction back to the exact original image size.
        # FIX: the previous Resize(full_img.size[1]) passed a single int
        # (the image height, since PIL .size is (width, height)), which
        # torchvision interprets as "match the smaller edge" — non-square
        # inputs produced masks that did not match the input dimensions.
        # Resize takes (height, width) when given a sequence.
        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((full_img.size[1], full_img.size[0])),
                transforms.ToTensor(),
            ]
        )

        full_mask = tf(probs.cpu())

    if n_classes > 1:
        return dataset_class.one_hot2mask(full_mask)
    else:
        return full_mask > 0.5


def get_args():
    """Parse command-line arguments for the prediction script.

    Returns:
        argparse.Namespace with ``model`` (weights file path, defaults to
        ``MODEL.pth``), ``input`` (required list of input image filenames)
        and ``output`` (optional list of output mask filenames).
    """
    parser = argparse.ArgumentParser(
        description="Predict masks from input images",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model",
        "-m",
        default="MODEL.pth",
        metavar="FILE",
        help="Specify the file in which the model is stored",
    )
    parser.add_argument(
        "--input",
        "-i",
        metavar="INPUT",
        nargs="+",
        help="filenames of input images",
        required=True,
    )
    # FIX: metavar was "INPUT" (copy-paste error) — this flag names outputs.
    parser.add_argument(
        "--output",
        "-o",
        metavar="OUTPUT",
        nargs="+",
        help="Filenames of output images",
    )

    return parser.parse_args()