style: run pre-commit
This view is limited to 50 files because it contains too many changes.

The commit applies pre-commit formatting across the repository: imports are sorted, long calls are wrapped onto multiple lines, and trailing whitespace is stripped; no functional changes are intended. For each file below, the changed hunks are shown cleaned up, as they read after the commit.
- README.md +9 -0
- angioPyFunctions.py +70 -61
- normalize_k1.py +3 -1
- predict.py +30 -14
- segmentation_models_pytorch/.github/workflows/tests.yml +2 -2
- segmentation_models_pytorch/.gitignore +1 -1
- segmentation_models_pytorch/HALLOFFAME.md +30 -31
- segmentation_models_pytorch/README.md +8 -8
- segmentation_models_pytorch/__init__.py +1 -1
- segmentation_models_pytorch/docker/Dockerfile +1 -1
- segmentation_models_pytorch/docs/conf.py +35 -26
- segmentation_models_pytorch/docs/insights.rst +8 -8
- segmentation_models_pytorch/docs/install.rst +1 -1
- segmentation_models_pytorch/docs/losses.rst +1 -1
- segmentation_models_pytorch/docs/models.rst +0 -2
- segmentation_models_pytorch/docs/quickstart.rst +1 -1
- segmentation_models_pytorch/docs/requirements.txt +1 -1
- segmentation_models_pytorch/misc/generate_table.py +6 -2
- segmentation_models_pytorch/segmentation_models_pytorch/__init__.py +34 -23
- segmentation_models_pytorch/segmentation_models_pytorch/__version__.py +1 -1
- segmentation_models_pytorch/segmentation_models_pytorch/base/__init__.py +2 -11
- segmentation_models_pytorch/segmentation_models_pytorch/base/heads.py +20 -9
- segmentation_models_pytorch/segmentation_models_pytorch/base/initialization.py +0 -1
- segmentation_models_pytorch/segmentation_models_pytorch/base/model.py +1 -1
- segmentation_models_pytorch/segmentation_models_pytorch/base/modules.py +55 -34
- segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/__init__.py +1 -1
- segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/decoder.py +19 -14
- segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/model.py +40 -45
- segmentation_models_pytorch/segmentation_models_pytorch/efficientunetplusplus/decoder.py +96 -61
- segmentation_models_pytorch/segmentation_models_pytorch/efficientunetplusplus/model.py +21 -19
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/__init__.py +23 -14
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/_base.py +5 -4
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/_preprocessing.py +0 -1
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/densenet.py +20 -11
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/dpn.py +7 -6
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/efficientnet.py +8 -10
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/inceptionresnetv2.py +8 -5
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/inceptionv4.py +13 -9
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/mobilenet.py +1 -2
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/resnet.py +10 -13
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/senet.py +2 -2
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_regnet.py +189 -178
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_res2net.py +82 -81
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_resnest.py +130 -129
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_sknet.py +57 -44
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/vgg.py +5 -4
- segmentation_models_pytorch/segmentation_models_pytorch/encoders/xception.py +31 -18
- segmentation_models_pytorch/segmentation_models_pytorch/fpn/__init__.py +1 -1
- segmentation_models_pytorch/segmentation_models_pytorch/fpn/decoder.py +35 -21
- segmentation_models_pytorch/segmentation_models_pytorch/fpn/model.py +6 -5
README.md
CHANGED

@@ -24,3 +24,12 @@ This software allows single arteries to be segmented given a few clicks on a sin
 ...a website should pop up in your browser!

 You need to create a /Dicom folder and put some angiography DICOMs in there
+
+
+# How to run the project
+## Create virtual environment and activate it
+```bash
+uv venv
+source .venv/bin/activate
+uv pip install -r requirements.txt
+```
angioPyFunctions.py
CHANGED

Formatting-only changes (imports sorted, long calls wrapped, spacing normalised):

@@ -1,90 +1,98 @@
import os

os.environ.setdefault("ASTROPY_SKIP_CONFIG_UPDATE", "1")

import astropy.config.configuration as _astro_config
import numpy
import scipy.interpolate
import scipy.ndimage
import scipy.optimize
import skimage.filters
import skimage.morphology
from PIL import Image

import predict

if not hasattr(_astro_config, "update_default_config"):

    def _noop_update_default_config(*args, **kwargs):
        return None

    _astro_config.update_default_config = _noop_update_default_config

import astropy.units as u
import cv2
import pooch
from fil_finder import FilFinder2D
from tqdm import tqdm

import utils.dataset

colourTableHex = {
    "LAD": "#f03b20",
    "D": "#fd8d3c",
    "CX": "#31a354",
    "OM": "#74c476",
    "RCA": "#08519c",
    "AM": "#3182bd",
    "LM": "#984ea3",
}

colourTableList = {}

for item in colourTableHex.keys():
    ### WARNING HACK: The colours go in backwards here for some reason perhaps related to RGBA?
    colourTableList[item] = [
        int(colourTableHex[item][5:7], 16),
        int(colourTableHex[item][3:5], 16),
        int(colourTableHex[item][1:3], 16),
    ]


def skeletonise(maskArray):
    # if len(maskArray.shape) == 3:
    maskArray = cv2.cvtColor(maskArray, cv2.COLOR_BGR2GRAY)

    skeleton = skimage.morphology.skeletonize(maskArray.astype("bool"))

    # Process the skeleton and find the longest path
    fil = FilFinder2D(
        skeleton.astype("uint8"),
        distance=250 * u.pc,
        mask=skeleton,
        beamwidth=10.0 * u.pix,
    )
    fil.preprocess_image(flatten_percent=85)
    fil.create_mask(border_masking=True, verbose=False, use_existing_mask=True)
    fil.medskel(verbose=False)
    fil.analyze_skeletons(
        branch_thresh=400 * u.pix, skel_thresh=10 * u.pix, prune_criteria="length"
    )

    # add image arrays dictionary
    # tifffile.imwrite(os.path.join(arteryFolder, "skel.tif"), fil.skeleton.astype('<u1')*255)

    skel = fil.skeleton.astype("<u1") * 255

    return skel


def skelEndpoints(skel):
    # skel[skel!=0] = 1
    skel = numpy.uint8(skel > 0)

    # Apply the convolution.
    kernel = numpy.uint8([[1, 1, 1], [1, 10, 1], [1, 1, 1]])
    src_depth = -1
    filtered = cv2.filter2D(skel, src_depth, kernel)

    # Look through to find the value of 11.
    # This returns a mask of the endpoints, but if you
    # just want the coordinates, you could simply
    # return np.where(filtered==11)
    out = numpy.zeros_like(skel)
    out[numpy.where(filtered == 11)] = 1
    endCoords = numpy.where(filtered == 11)
    endCoords = list(zip(*endCoords))
    startPoint = endCoords[0]
    endPoint = endCoords[1]
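The `filtered == 11` test in `skelEndpoints` works because an endpoint pixel collects the centre weight (10) plus exactly one neighbouring skeleton pixel (1). A minimal sketch on a toy skeleton (the toy array is made up for illustration):

```python
import numpy
import cv2

skel = numpy.zeros((5, 5), dtype=numpy.uint8)
skel[2, 1:4] = 1  # a three-pixel horizontal segment

kernel = numpy.uint8([[1, 1, 1], [1, 10, 1], [1, 1, 1]])
filtered = cv2.filter2D(skel, -1, kernel)

# Endpoints score 10 (centre) + 1 (single neighbour) = 11, while interior
# skeleton pixels score 10 + 2 = 12, so == 11 isolates the two ends.
print(numpy.argwhere(filtered == 11))  # [[2 1] [2 3]]
```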
@@ -109,16 +117,15 @@ def skelPointsInOrder(skel, startPoint=None):
    skelLength = len(skelPoints)

    # Loop through the skeleton starting with startPoint, deleting the starting point from the skelPoints list, and finding the closest pixel. This is appended to orderedPoints. startPoint now becomes the last point to be appended.
    startPointCopy = startPoint  # copied as we are going to loop and overwrite, but want to also keep the original startPoint
    orderedPoints = []

    while len(skelPoints) > 1:
        skelPoints.remove(startPointCopy)

        # Calculate the point that is closest to the start point
        diffs = numpy.abs(numpy.array(skelPoints) - numpy.array(startPointCopy))
        dists = numpy.sum(diffs, axis=1)  # l1-distance
        closest_point_index = numpy.argmin(dists)
        closestPoint = skelPoints[closest_point_index]
        orderedPoints.append(closestPoint)

@@ -145,7 +152,7 @@ def skelSplinerWithThickness(skel, EDT, smoothing=50, order=3, decimation=2):
    x = x[::decimation]
    y = y[::decimation]

    # NOTE: Should the EDT be median filtered? I wonder in fact if doing so will reduce the accuracy of the model.
    # EDT = skimage.filters.median(EDT)

    t = EDT[y, x]

@@ -156,8 +163,7 @@
    print(x.shape, y.shape, t.shape)

    tcko, uo = scipy.interpolate.splprep([y, x, t], s=smoothing, k=order, per=False)

    return tcko
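`skelSplinerWithThickness` returns the `splprep` spline representation `tcko`, which bundles the centreline y/x coordinates with the EDT value (the local distance to the mask edge) along it. A self-contained usage sketch with stand-in data (the toy points and the sample count of 200 are arbitrary choices for illustration):

```python
import numpy
import scipy.interpolate

# Stand-in for the function's inputs: a short centreline with thickness values.
y = numpy.array([0.0, 1.0, 2.0, 3.0, 4.0])
x = numpy.array([0.0, 0.5, 1.5, 3.0, 5.0])
t = numpy.array([2.0, 2.2, 2.1, 1.8, 1.5])  # EDT sampled along the skeleton
tcko, uo = scipy.interpolate.splprep([y, x, t], s=0, k=3, per=False)

# Evaluate the spline densely: a smoothed centreline plus local thickness.
u_samples = numpy.linspace(0.0, 1.0, 200)
y_s, x_s, t_s = scipy.interpolate.splev(u_samples, tcko)
```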
@@ -192,8 +198,12 @@ def arterySegmentation(inputImage, groundTruthPoints, segmentationModelWeights=N
    )

    if inputImage.shape[0] != 512 and inputImage.shape[1] != 512:
        ratioYX = numpy.array(
            [512.0 / inputImage.shape[0], 512.0 / inputImage.shape[1]]
        )
        print(
            f"arterySegmentation(): Rescaling image to 512x512 by {ratioYX=}, and also applying this to input points"
        )
        inputImage = scipy.ndimage.zoom(inputImage, ratioYX)
        points = groundTruthPoints.copy() * ratioYX
        print(inputImage.shape)

@@ -202,33 +212,32 @@
    imageSize = inputImage.shape

    n_classes = 2  # binary output

    net = predict.smp.Unet(
        encoder_name="inceptionresnetv2",
        encoder_weights="imagenet",
        in_channels=3,
        classes=n_classes,
    )

    net = predict.nn.DataParallel(net)

    device = predict.torch.device(
        "cuda" if predict.torch.cuda.is_available() else "cpu"
    )
    net.to(device=device)

    net.load_state_dict(
        predict.torch.load(segmentationModelWeights, map_location=device)
    )

    orig_image = Image.fromarray(inputImage)

    image = predict.Image.new("RGB", imageSize, (0, 0, 0))
    image.paste(orig_image, (0, 0))

    imageArray = numpy.array(image).astype("uint8")

    # Clear last channels
    imageArray[:, :, -1] = 0

@@ -242,13 +251,13 @@
    for y, x in [startPoint, endPoint]:
        y = int(numpy.round(y))
        x = int(numpy.round(x))
        imageArray[y - 2 : y + 2, x - 2 : x + 2, 1] = 255

    # All other points on Channel 2
    for y, x in points[1:-1]:
        y = int(numpy.round(y))
        x = int(numpy.round(x))
        imageArray[y - 2 : y + 2, x - 2 : x + 2, 2] = 255

    image = Image.fromarray(imageArray.astype(numpy.uint8))

@@ -257,19 +266,19 @@
        dataset_class=utils.dataset.CoronaryDataset,
        full_img=image,
        scale_factor=1,
        device=device,
    )

    return mask


def maskOutliner(labelledArtery, outlineThickness=3):
    # Compute the boundary of the mask
    contours, _ = cv2.findContours(
        labelledArtery, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    tmp = numpy.zeros_like(labelledArtery)
    boundary = cv2.drawContours(tmp, contours, -1, (255, 255, 255), outlineThickness)
    boundary = boundary > 0

    return boundary
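A quick sanity check for `maskOutliner`, assuming the function above is in scope and fed a single-channel uint8 mask (the synthetic rectangle is made up for illustration):

```python
import numpy

artery = numpy.zeros((64, 64), dtype=numpy.uint8)
artery[20:40, 25:35] = 255  # a filled rectangular "vessel"

outline = maskOutliner(artery, outlineThickness=2)
# outline is boolean: True only in a thin band around the rectangle's
# perimeter, produced by cv2.findContours followed by cv2.drawContours.
print(outline.dtype, outline.sum())
```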
normalize_k1.py
CHANGED

The only change is that the rescale_intensity call is re-wrapped:

@@ -17,7 +17,9 @@ def normalize_image(
    img = img.resize(target_size, Image.Resampling.BICUBIC)

    arr = np.array(img, dtype=np.float32)
    arr = exposure.rescale_intensity(
        arr, in_range="image", out_range=(png_low, png_high)
    )
    arr = np.clip(arr, png_low, png_high)
    arr = ((arr - png_low) / (png_high - png_low) * 255.0).astype(np.uint8)
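The three lines above implement a full-range contrast stretch followed by an 8-bit quantisation. A toy run, with `png_low`/`png_high` assumed to be 0 and 255 for this sketch (the script takes its own bounds):

```python
import numpy as np
from skimage import exposure

png_low, png_high = 0.0, 255.0  # assumed bounds for this sketch
arr = np.array([[10.0, 50.0], [100.0, 210.0]], dtype=np.float32)

# Stretch the image's own min/max onto [png_low, png_high] ...
arr = exposure.rescale_intensity(arr, in_range="image", out_range=(png_low, png_high))
# ... then clip and map linearly onto 0..255 as uint8.
arr = np.clip(arr, png_low, png_high)
arr = ((arr - png_low) / (png_high - png_low) * 255.0).astype(np.uint8)
print(arr)  # [[  0  51] [114 255]]
```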
predict.py
CHANGED

@@ -5,17 +5,17 @@ import os
import torch
import torch.nn as nn
from PIL import Image
from torch.backends import cudnn
from torchvision import transforms

import segmentation_models_pytorch.segmentation_models_pytorch as smp
from utils.dataset import CoronaryDataset

"""
This uses a pytorch coronary segmentation model (EfficientNetPLusPlus) that has been trained using a freely available dataset of labelled coronary angiograms from: http://personal.cimat.mx:8181/~ivan.cruz/DB_Angiograms.html
The input is a raw angiogram image, and the output is a segmentation mask of all the arteries. This output will be used as the 'first guess' to speed up artery annotation.
"""


def predict_img(net, dataset_class, full_img, device, scale_factor=1, n_classes=3):
    # NOTE n_classes is the number of possible values that can be predicted for a given pixel. In a standard binary segmentation task, this will be 2 i.e. black or white

@@ -41,11 +41,11 @@
        [
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor(),
        ]
    )

    full_mask = tf(probs.cpu())

    if n_classes > 1:
        return dataset_class.one_hot2mask(full_mask)

@@ -54,12 +54,28 @@
def get_args():
    parser = argparse.ArgumentParser(
        description="Predict masks from input images",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # parser.add_argument('-d', '--dataset', type=str, help='Specifies the dataset to be used', dest='dataset', required=True)
    parser.add_argument(
        "--model",
        "-m",
        default="MODEL.pth",
        metavar="FILE",
        help="Specify the file in which the model is stored",
    )
    parser.add_argument(
        "--input",
        "-i",
        metavar="INPUT",
        nargs="+",
        help="filenames of input images",
        required=True,
    )
    parser.add_argument(
        "--output", "-o", metavar="INPUT", nargs="+", help="Filenames of output images"
    )

    return parser.parse_args()
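A sketch of what `get_args()` yields, with `get_args` from predict.py in scope and placeholder file names:

```python
import sys

sys.argv = ["predict.py", "-m", "MODEL.pth", "-i", "frame.png", "-o", "mask.png"]
args = get_args()
print(args.model)   # MODEL.pth
print(args.input)   # ['frame.png']  (nargs="+" always yields a list)
print(args.output)  # ['mask.png']
```

Equivalently, from a shell: `python predict.py --model MODEL.pth --input frame.png --output mask.png`.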
segmentation_models_pytorch/.github/workflows/tests.yml
CHANGED

Whitespace-only change (trailing whitespace stripped from two blank lines):

@@ -17,12 +17,12 @@ jobs:
    steps:
    - uses: actions/checkout@v2

    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: 3.6

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
segmentation_models_pytorch/.gitignore
CHANGED

Whitespace-only change on the final line (most likely a newline at end of file):

@@ -102,4 +102,4 @@ venv.bak/
/site

# mypy
.mypy_cache/
segmentation_models_pytorch/HALLOFFAME.md
CHANGED

Whitespace-only changes throughout (trailing spaces removed); the affected sections read:

@@ -5,7 +5,7 @@ Here you can find competitions, names of the winners and links to their solution

Please, follow these rules, when adding a solution to the "Hall of Fame":

1. Solution should be high rated (e.g. for Kaggle gold or silver medal)
2. There should be a description of the solution (post at the forum / code / blog post / paper / pre-print)

@@ -13,78 +13,77 @@ Please, follow these rules, when adding a solution to the "Hall of Fame":

### [Severstal: Steel Defect Detection](https://www.kaggle.com/c/severstal-steel-defect-detection)

- 1st place.
  [Wuxi Jiangsu](https://www.kaggle.com/rguo97),
  [Hongbo Zhu](https://www.kaggle.com/zhuhongbo),
  [Yizhuo Yu](https://www.kaggle.com/paffpaffyu)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254#latest-675874)]

- 5th place.
  [Guanshuo Xu](https://www.kaggle.com/wowfattie)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/117208#latest-675385)]

- 9th place.
  [Jacek Poplawski](https://www.linkedin.com/in/jacekpoplawski/)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114297#latest-660842)]

- 10th place.
  [Alexey Rozhkov](https://www.linkedin.com/in/alexisrozhkov)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114465#latest-659615)]

- 12th place.
  [Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/),
  [Ilya Dobrynin](https://www.linkedin.com/in/ilya-dobrynin-79a89b106/),
  [Denis Kolpakov](https://www.linkedin.com/in/denis-kolpakov-ab3137197/)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114309#latest-661404)]

- 31st place.
  [Insaf Ashrapov](https://www.linkedin.com/in/iashrapov/),
  [Igor Krashenyi](https://www.linkedin.com/in/igor-krashenyi-38b89b98),
  [Pavel Pleskov](https://www.linkedin.com/in/ppleskov),
  [Anton Zakharenkov](https://www.linkedin.com/in/anton-zakharenkov/),
  [Nikolai Popov](https://www.linkedin.com/in/nikolai-popov-b2157370/)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114383#latest-658438)]
  [[code](https://github.com/Diyago/Severstal-Steel-Defect-Detection)]

- 55th place.
  [Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114410#latest-672682)]
  [[code](https://github.com/khornlund/severstal-steel-defect-detection)]

- Efficiency round 1st place.
  [Stefan Stefanov](https://www.linkedin.com/in/stefan-stefanov-63a77b1)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/117486#latest-674229)]


### [Understanding Clouds from Satellite Images](https://www.kaggle.com/c/understanding_cloud_organization)

- 2nd place.
  [Andrey Kiryasov](https://www.kaggle.com/ekydna)
  [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118255#latest-678189)]

- 4th place.
  [Ching-Loong Seow](https://www.linkedin.com/in/clseow/)
  [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118016#latest-677333)]

- 34th place.
  [Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/)
  [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118250#latest-678176)]
  [[code](https://github.com/khornlund/understanding-cloud-organization)]

- 55th place.
  [Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/)
  [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118019#latest-678626)]

## Other platforms

### [MICCAI 2020 TN-SCUI challenge](https://tn-scui2020.grand-challenge.org/Home/)
- 1st place.
  [Mingyu Wang](https://github.com/WAMAWAMA)
  [[description](https://github.com/WAMAWAMA/TNSCUI2020-Seg-Rank1st)]
  [[code](https://github.com/WAMAWAMA/TNSCUI2020-Seg-Rank1st)]

### [Open Cities AI Challenge: Segmenting Buildings for Disaster Resilience](https://www.drivendata.org/competitions/60/building-segmentation-disaster-resilience/)
- 1st place.
  [Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/).
  [[code and description](https://github.com/qubvel/open-cities-challenge)]
segmentation_models_pytorch/README.md
CHANGED

Whitespace-only changes (trailing spaces removed); the affected hunks read:

@@ -1,8 +1,8 @@
<div align="center">

![logo]

**Python library with Neural Networks for Image
Segmentation based on [PyTorch](https://pytorch.org/).**

[![Generic badge]](https://segmentation-models-pytorch.readthedocs.io/en/latest/?badge=latest) <br> [![badge]](https://shields.io/)

@@ -14,7 +14,7 @@ The main features of this library are:
- 12 models architectures for binary and multi class segmentation (including legendary Unet)
- 104 available encoders
- All encoders have pre-trained weights for faster and better convergence

### [📚 Project Documentation 📚](http://smp.readthedocs.io/)

Visit [Read The Docs Project Page](https://segmentation-models-pytorch.readthedocs.io/en/latest/) or read following README to know more about Segmentation Models Pytorch (SMP for short) library

@@ -346,11 +346,11 @@ model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))
```

##### Auxiliary classification output
All models support `aux_params` parameters, which is default set to `None`.
If `aux_params = None` then classification auxiliary output is not created, else
model produce not only `mask`, but also `label` output with shape `NC`.
Classification head consists of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be
configured by `aux_params` as follows:
```python
aux_params=dict(
segmentation_models_pytorch/__init__.py
CHANGED

Whitespace-only change (newline at end of file):

@@ -1 +1 @@
from segmentation_models_pytorch import *
segmentation_models_pytorch/docker/Dockerfile
CHANGED

Whitespace-only change (newline at end of file):

@@ -1,3 +1,3 @@
FROM anibali/pytorch:cuda-9.0

RUN pip install segmentation-models-pytorch
segmentation_models_pytorch/docs/conf.py
CHANGED

Formatting-only changes; the changed hunks after the commit:

@@ -14,24 +14,28 @@
# import sys
# sys.path.insert(0, os.path.abspath('.'))

import datetime
import os
import re
import sys

sys.path.append("..")

# -- Project information -----------------------------------------------------

project = "Segmentation Models"
copyright = "{}, Pavel Yakubovskiy".format(datetime.datetime.now().year)
author = "Pavel Yakubovskiy"


def get_version():
    sys.path.append("../segmentation_models_pytorch")
    from __version__ import __version__ as version

    sys.path.pop(-1)
    return version


version = get_version()

# -- General configuration ---------------------------------------------------

@@ -41,15 +45,15 @@ version = get_version()
# ones.

extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.coverage",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx.ext.mathjax",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.

@@ -64,12 +68,14 @@ exclude_patterns = []
#

import sphinx_rtd_theme

html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]

# import karma_sphinx_theme
# html_theme = "karma_sphinx_theme"
import faculty_sphinx_theme

html_theme = "faculty_sphinx_theme"

# import catalyst_sphinx_theme

@@ -81,7 +87,7 @@ html_logo = "logo.png"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

# -- Extension configuration -------------------------------------------------

@@ -91,30 +97,33 @@ napoleon_include_init_with_doc = True
napoleon_numpy_docstring = False

autodoc_mock_imports = [
    "torch",
    "tqdm",
    "numpy",
    "timm",
    "pretrainedmodels",
    "torchvision",
    "efficientnet-pytorch",
    "segmentation_models_pytorch.encoders",
    "segmentation_models_pytorch.utils",
    # 'segmentation_models_pytorch.base',
]

autoclass_content = "both"
autodoc_typehints = "description"

# --- Work around to make autoclass signatures not (*args, **kwargs) ----------


class FakeSignature:
    def __getattribute__(self, *args):
        raise ValueError


def f(app, obj, bound_method):
    if "__new__" in obj.__name__:
        obj.__signature__ = FakeSignature()


def setup(app):
    app.connect("autodoc-before-process-signature", f)
segmentation_models_pytorch/docs/insights.rst
CHANGED

Whitespace-only changes (trailing spaces removed); the affected hunks read:

@@ -21,20 +21,20 @@ Each encoder should have following attributes and methods and be inherited from
.. code-block:: python

    class MyEncoder(torch.nn.Module, EncoderMixin):

        def __init__(self, **kwargs):
            super().__init__()

            # A number of channels for each encoder feature tensor, list of integers
            self._out_channels: List[int] = [3, 16, 64, 128, 256, 512]

            # A number of stages in decoder (in other words number of downsampling operations), integer
            # use in in forward pass to reduce number of returning features
            self._depth: int = 5

            # Default number of input channels in first Conv2d layer for encoder (usually 3)
            self._in_channels: int = 3

            # Define encoder modules below
            ...

@@ -90,12 +90,12 @@ For better understanding see more examples of encoder in smp.encoders module.
3. Aux classification output
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

All models support ``aux_params`` parameter, which is default set to ``None``.
If ``aux_params = None`` than classification auxiliary output is not created, else
model produce not only ``mask``, but also ``label`` output with shape ``(N, C)``.

Classification head consist of following layers:

1. GlobalPooling
2. Dropout (optional)
3. Linear

@@ -104,7 +104,7 @@ Classification head consist of following layers:
Example:

.. code-block:: python

    aux_params=dict(
        pooling='avg',  # one of 'avg', 'max'
        dropout=0.5,    # dropout ratio, default is None
segmentation_models_pytorch/docs/install.rst
CHANGED

Whitespace-only change:

@@ -5,4 +5,4 @@ Latest version from source:

.. code-block:: bash

    $ pip install -U git+https://github.com/jlcsilva/segmentation_models.pytorch
segmentation_models_pytorch/docs/losses.rst
CHANGED

Whitespace-only change:

@@ -1,7 +1,7 @@
📉 Losses
=========

Collection of popular semantic segmentation losses. Adapted from
an awesome repo with pytorch utils https://github.com/BloodAxe/pytorch-toolbelt

Constants
segmentation_models_pytorch/docs/models.rst
CHANGED

Two trailing blank lines removed at end of file:

@@ -48,5 +48,3 @@ DeepLabV3
 DeepLabV3+
 ~~~~~~~~~~
 .. autoclass:: segmentation_models_pytorch.DeepLabV3Plus
-
-
segmentation_models_pytorch/docs/quickstart.rst
CHANGED

Whitespace-only change:

@@ -6,7 +6,7 @@
Segmentation model is just a PyTorch nn.Module, which can be created as easy as:

.. code-block:: python

    import segmentation_models_pytorch as smp

    model = smp.Unet(
segmentation_models_pytorch/docs/requirements.txt
CHANGED

Whitespace-only change (newline at end of file):

@@ -1,2 +1,2 @@
faculty-sphinx-theme==0.2.2
six==1.15.0
segmentation_models_pytorch/misc/generate_table.py
CHANGED

Formatting-only changes; the changed hunk after the commit:

@@ -10,11 +10,15 @@ COLUMNS = [
    "Params, M",
]


def wrap_row(r):
    return "|{}|".format(r)


header = "|".join([column.ljust(WIDTH, " ") for column in COLUMNS])
separator = "|".join(
    ["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1)
)

print(wrap_row(header))
print(wrap_row(separator))
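What the header/separator construction prints, shown as a runnable sketch with an assumed `WIDTH` (the script defines its own):

```python
WIDTH = 12  # assumed for this sketch
COLUMNS = ["Encoder", "Weights", "Params, M"]

header = "|".join([column.ljust(WIDTH, " ") for column in COLUMNS])
separator = "|".join(["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1))

print("|{}|".format(header))     # |Encoder     |Weights     |Params, M   |
print("|{}|".format(separator))  # |------------|:----------:|:----------:|
```

The first column is left-aligned and the rest are centre-aligned, presumably for the encoder tables in the README.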
segmentation_models_pytorch/segmentation_models_pytorch/__init__.py
CHANGED

Imports sorted and the architecture list/error message reformatted; the changed hunks after the commit:

@@ -1,23 +1,20 @@
from typing import Optional

import torch

from . import encoders, losses, utils
from .__version__ import __version__
from .deeplabv3 import DeepLabV3, DeepLabV3Plus
from .efficientunetplusplus import EfficientUnetPlusPlus
from .fpn import FPN
from .linknet import Linknet
from .manet import MAnet
from .pan import PAN
from .pspnet import PSPNet
from .resunet import ResUnet
from .resunetplusplus import ResUnetPlusPlus
from .unet import Unet
from .unetplusplus import UnetPlusPlus


def create_model(

@@ -28,18 +25,32 @@ def create_model(
    classes: int = 1,
    **kwargs,
) -> torch.nn.Module:
    """Models wrapper. Allows to create any model just with parametes"""

    archs = [
        Unet,
        UnetPlusPlus,
        MAnet,
        Linknet,
        FPN,
        PSPNet,
        DeepLabV3,
        DeepLabV3Plus,
        PAN,
        ResUnet,
        EfficientUnetPlusPlus,
        ResUnetPlusPlus,
    ]
    archs_dict = {a.__name__.lower(): a for a in archs}
    try:
        model_class = archs_dict[arch.lower()]
    except KeyError:
        raise KeyError(
            "Wrong architecture type `{}`. Avalibale options are: {}".format(
                arch,
                list(archs_dict.keys()),
            )
        )
    return model_class(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
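A usage sketch for `create_model`; the elided portion of the signature is assumed to include `arch`, `encoder_name`, and `encoder_weights`, as the lookup and the final call in the hunk suggest, and the default input channel count of 3 is assumed:

```python
import torch
import segmentation_models_pytorch.segmentation_models_pytorch as smp

model = smp.create_model(
    arch="unet",                 # looked up case-insensitively in archs_dict
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=2,
)
mask = model(torch.ones([1, 3, 64, 64]))
```

An unknown `arch` raises the `KeyError` above, listing the available keys (`unet`, `unetplusplus`, `manet`, ...).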
segmentation_models_pytorch/segmentation_models_pytorch/__version__.py
CHANGED

The __version__ assignment reformatted; after the commit:

@@ -1,3 +1,3 @@
VERSION = (0, 1, 3)

__version__ = ".".join(map(str, VERSION))
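The join turns the version tuple into the usual dotted string:

```python
VERSION = (0, 1, 3)
print(".".join(map(str, VERSION)))  # 0.1.3
```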
segmentation_models_pytorch/segmentation_models_pytorch/base/__init__.py
CHANGED

Multi-line imports collapsed and sorted; after the commit:

@@ -1,12 +1,3 @@
from .heads import ClassificationHead, SegmentationHead
from .model import SegmentationModel
from .modules import Attention, Conv2dReLU, PreActivatedConv2dReLU
segmentation_models_pytorch/segmentation_models_pytorch/base/heads.py
CHANGED

Formatting-only changes; the changed hunk after the commit:

@@ -1,22 +1,33 @@
import torch.nn as nn

from .modules import Activation, Flatten


class SegmentationHead(nn.Sequential):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        conv2d = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )
        upsampling = (
            nn.UpsamplingBilinear2d(scale_factor=upsampling)
            if upsampling > 1
            else nn.Identity()
        )
        activation = Activation(activation)
        super().__init__(conv2d, upsampling, activation)


class ClassificationHead(nn.Sequential):
    def __init__(
        self, in_channels, classes, pooling="avg", dropout=0.2, activation=None
    ):
        if pooling not in ("max", "avg"):
            raise ValueError(
                "Pooling should be one of ('max', 'avg'), got {}.".format(pooling)
            )
        pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1)
        flatten = Flatten()
        dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
        linear = nn.Linear(in_channels, classes, bias=True)
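A shape-check sketch for `SegmentationHead` with the class above in scope (toy sizes; `Activation(None)` is assumed to act as an identity, mirroring the `upsampling` and `dropout` fallbacks above):

```python
import torch

head = SegmentationHead(in_channels=16, out_channels=2, kernel_size=3, upsampling=2)
x = torch.randn(1, 16, 64, 64)
out = head(x)
# padding=kernel_size//2 keeps the 64x64 size through the conv;
# UpsamplingBilinear2d(scale_factor=2) then doubles it.
print(out.shape)  # torch.Size([1, 2, 128, 128])
```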
segmentation_models_pytorch/segmentation_models_pytorch/base/initialization.py
CHANGED

A blank line removed inside the loop:

@@ -3,7 +3,6 @@
import torch.nn as nn


def initialize_decoder(module):
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
            if m.bias is not None:
segmentation_models_pytorch/segmentation_models_pytorch/base/model.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
+
 from . import initialization as init
 
 
 class SegmentationModel(torch.nn.Module):
-
     def initialize(self):
         init.initialize_decoder(self.decoder)
         init.initialize_head(self.segmentation_head)
segmentation_models_pytorch/segmentation_models_pytorch/base/modules.py
CHANGED
@@ -6,22 +6,23 @@ try:
 except ImportError:
     InPlaceABN = None
 
+
 class PreActivatedConv2dReLU(nn.Sequential):
     """
     Pre-activated 2D convolution, as proposed in https://arxiv.org/pdf/1603.05027.pdf. Feature maps are processed by a normalization layer,
     followed by a ReLU activation and a 3x3 convolution.
     normalization
     """
+
     def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            padding=0,
-            stride=1,
-            use_batchnorm=True,
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        padding=0,
+        stride=1,
+        use_batchnorm=True,
     ):
-
         if use_batchnorm == "inplace" and InPlaceABN is None:
             raise RuntimeError(
                 "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
@@ -47,20 +48,21 @@ class PreActivatedConv2dReLU(nn.Sequential):
         )
         super(PreActivatedConv2dReLU, self).__init__(conv, bn, relu)
 
+
 class Conv2dReLU(nn.Sequential):
     """
     Block composed of a 3x3 convolution followed by a normalization layer and ReLU activation.
     """
+
     def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            padding=0,
-            stride=1,
-            use_batchnorm=True,
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        padding=0,
+        stride=1,
+        use_batchnorm=True,
     ):
-
         if use_batchnorm == "inplace" and InPlaceABN is None:
             raise RuntimeError(
                 "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
@@ -87,20 +89,33 @@ class Conv2dReLU(nn.Sequential):
 
         super(Conv2dReLU, self).__init__(conv, bn, relu)
 
+
 class DepthWiseConv2d(nn.Conv2d):
     "Depth-wise convolution operation"
+
     def __init__(self, channels, kernel_size=3, stride=1):
-        super().__init__(channels, channels, kernel_size, stride=stride, padding=kernel_size // 2, groups=channels)
+        super().__init__(
+            channels,
+            channels,
+            kernel_size,
+            stride=stride,
+            padding=kernel_size // 2,
+            groups=channels,
+        )
+
 
 class PointWiseConv2d(nn.Conv2d):
     "Point-wise (1x1) convolution operation"
+
     def __init__(self, in_channels, out_channels):
         super().__init__(in_channels, out_channels, kernel_size=1, stride=1)
 
+
 class SEModule(nn.Module):
     """
     Spatial squeeze & channel excitation attention module, as proposed in https://arxiv.org/abs/1709.01507.
     """
+
     def __init__(self, in_channels, reduction=16):
         super().__init__()
         self.cSE = nn.Sequential(
@@ -114,10 +129,12 @@ class SEModule(nn.Module):
     def forward(self, x):
         return x * self.cSE(x)
 
+
 class sSEModule(nn.Module):
     """
     Channel squeeze & spatial excitation attention module, as proposed in https://arxiv.org/abs/1808.08127.
     """
+
     def __init__(self, in_channels):
         super().__init__()
         self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
@@ -125,10 +142,12 @@ class sSEModule(nn.Module):
     def forward(self, x):
         return x * self.sSE(x)
 
+
 class SCSEModule(nn.Module):
     """
     Concurrent spatial and channel squeeze & excitation attention module, as proposed in https://arxiv.org/pdf/1803.02579.pdf.
     """
+
     def __init__(self, in_channels, reduction=16):
         super().__init__()
         self.cSE = nn.Sequential(
@@ -143,8 +162,8 @@ class SCSEModule(nn.Module):
     def forward(self, x):
         return x * self.cSE(x) + x * self.sSE(x)
 
-class ArgMax(nn.Module):
 
+class ArgMax(nn.Module):
     def __init__(self, dim=None):
         super().__init__()
         self.dim = dim
@@ -154,46 +173,47 @@ class ArgMax(nn.Module):
 
 
 class Activation(nn.Module):
-
     def __init__(self, name, **params):
-
         super().__init__()
 
-        if name is None or name == 'identity':
+        if name is None or name == "identity":
             self.activation = nn.Identity(**params)
-        elif name == 'sigmoid':
+        elif name == "sigmoid":
             self.activation = nn.Sigmoid()
-        elif name == 'softmax2d':
+        elif name == "softmax2d":
             self.activation = nn.Softmax(dim=1, **params)
-        elif name == 'softmax':
+        elif name == "softmax":
             self.activation = nn.Softmax(**params)
-        elif name == 'logsoftmax':
+        elif name == "logsoftmax":
             self.activation = nn.LogSoftmax(**params)
-        elif name == 'tanh':
+        elif name == "tanh":
             self.activation = nn.Tanh()
-        elif name == 'argmax':
+        elif name == "argmax":
             self.activation = ArgMax(**params)
-        elif name == 'argmax2d':
+        elif name == "argmax2d":
             self.activation = ArgMax(dim=1, **params)
         elif callable(name):
             self.activation = name(**params)
         else:
-            raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/tanh/None; got {}'.format(name))
+            raise ValueError(
+                "Activation should be callable/sigmoid/softmax/logsoftmax/tanh/None; got {}".format(
+                    name
+                )
+            )
 
     def forward(self, x):
         return self.activation(x)
 
 
 class Attention(nn.Module):
-
     def __init__(self, name, **params):
         super().__init__()
 
         if name is None:
             self.attention = nn.Identity(**params)
-        elif name == 'scse':
+        elif name == "scse":
             self.attention = SCSEModule(**params)
-        elif name == 'se':
+        elif name == "se":
             self.attention = SEModule(**params)
         else:
             raise ValueError("Attention {} is not implemented".format(name))
@@ -201,6 +221,7 @@ class Attention(nn.Module):
     def forward(self, x):
         return self.attention(x)
 
+
 class Flatten(nn.Module):
     def forward(self, x):
         return x.view(x.shape[0], -1)
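As a quick reference for the string names the reformatted Activation block accepts (a minimal sketch, assuming the vendored package is importable as segmentation_models_pytorch):

from segmentation_models_pytorch.base.modules import Activation

act = Activation("sigmoid")    # wraps nn.Sigmoid()
act = Activation("softmax2d")  # nn.Softmax(dim=1), i.e. softmax across the channel axis
act = Activation(None)         # nn.Identity(); useful when raw logits are wanted
# Activation("relu") would hit the ValueError above: "relu" is not an accepted name.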
segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/__init__.py
CHANGED
@@ -1 +1 @@
-from .model import DeepLabV3, DeepLabV3Plus
+from .model import DeepLabV3, DeepLabV3Plus
segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/decoder.py
CHANGED
@@ -61,14 +61,18 @@ class DeepLabV3PlusDecoder(nn.Module):
     ):
         super().__init__()
         if output_stride not in {8, 16}:
-            raise ValueError("Output stride should be 8 or 16, got {}.".format(output_stride))
+            raise ValueError(
+                "Output stride should be 8 or 16, got {}.".format(output_stride)
+            )
 
         self.out_channels = out_channels
         self.output_stride = output_stride
 
         self.aspp = nn.Sequential(
             ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
-            SeparableConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            SeparableConv2d(
+                out_channels, out_channels, kernel_size=3, padding=1, bias=False
+            ),
             nn.BatchNorm2d(out_channels),
             nn.ReLU(),
         )
@@ -77,9 +81,11 @@ class DeepLabV3PlusDecoder(nn.Module):
         self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
 
         highres_in_channels = encoder_channels[-4]
-        highres_out_channels = 48   # proposed by authors of paper
+        highres_out_channels = 48  # proposed by authors of paper
         self.block1 = nn.Sequential(
-            nn.Conv2d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False),
+            nn.Conv2d(
+                highres_in_channels, highres_out_channels, kernel_size=1, bias=False
+            ),
             nn.BatchNorm2d(highres_out_channels),
             nn.ReLU(),
         )
@@ -149,7 +155,7 @@ class ASPPPooling(nn.Sequential):
         size = x.shape[-2:]
         for mod in self:
             x = mod(x)
-        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
+        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
 
 
 class ASPP(nn.Module):
@@ -190,16 +196,15 @@ class ASPP(nn.Module):
 
 
 class SeparableConv2d(nn.Sequential):
-
     def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride=1,
-            padding=0,
-            dilation=1,
-            bias=True,
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        bias=True,
     ):
         dephtwise_conv = nn.Conv2d(
             in_channels,
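The SeparableConv2d being reflowed here factors a dense convolution into a depthwise conv followed by a 1x1 pointwise conv, which is where the decoder's efficiency comes from. A standalone parameter-count comparison (plain PyTorch, not the library's class):

import torch.nn as nn

cin, cout, k = 256, 256, 3
dense = nn.Conv2d(cin, cout, k, padding=1, bias=False)
separable = nn.Sequential(
    nn.Conv2d(cin, cin, k, padding=1, groups=cin, bias=False),  # depthwise
    nn.Conv2d(cin, cout, 1, bias=False),                        # pointwise
)
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(dense))      # 589824 (256 * 256 * 9)
print(count(separable))  # 67840  (256 * 9 + 256 * 256)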
segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/model.py
CHANGED
@@ -1,9 +1,10 @@
+from typing import Optional
+
 import torch.nn as nn
 
-from typing import Optional
-from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
-from ..base import SegmentationModel, SegmentationHead, ClassificationHead
+from ..base import ClassificationHead, SegmentationHead, SegmentationModel
 from ..encoders import get_encoder
+from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
 
 
 class DeepLabV3(SegmentationModel):
@@ -12,11 +13,11 @@ class DeepLabV3(SegmentationModel):
     Args:
         encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
             to extract features of different spatial resolution
-        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
+        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
             two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
             with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
             Default is 5
-        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
         decoder_channels: A number of convolution filters in ASPP module. Default is 256
         in_channels: A number of input channels for the model, default is 3 (RGB images)
@@ -25,7 +26,7 @@ class DeepLabV3(SegmentationModel):
             Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
             Default is **None**
         upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity
-        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
             on top of encoder if **aux_params** is not **None** (default). Supported params:
             - classes (int): A number of classes
             - pooling (str): One of "max", "avg". Default is "avg"
@@ -42,16 +43,16 @@ class DeepLabV3(SegmentationModel):
     """
 
     def __init__(
-            self,
-            encoder_name: str = "resnet34",
-            encoder_depth: int = 5,
-            encoder_weights: Optional[str] = "imagenet",
-            decoder_channels: int = 256,
-            in_channels: int = 3,
-            classes: int = 1,
-            activation: Optional[str] = None,
-            upsampling: int = 8,
-            aux_params: Optional[dict] = None,
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: Optional[str] = "imagenet",
+        decoder_channels: int = 256,
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[str] = None,
+        upsampling: int = 8,
+        aux_params: Optional[dict] = None,
     ):
         super().__init__()
 
@@ -61,10 +62,7 @@ class DeepLabV3(SegmentationModel):
             depth=encoder_depth,
             weights=encoder_weights,
         )
-        self.encoder.make_dilated(
-            stage_list=[4, 5],
-            dilation_list=[2, 4]
-        )
+        self.encoder.make_dilated(stage_list=[4, 5], dilation_list=[2, 4])
 
         self.decoder = DeepLabV3Decoder(
             in_channels=self.encoder.out_channels[-1],
@@ -90,15 +88,15 @@ class DeepLabV3(SegmentationModel):
 class DeepLabV3Plus(SegmentationModel):
     """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable
     Convolution for Semantic Image Segmentation"
-
+
     Args:
         encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
             to extract features of different spatial resolution
-        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
+        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
             two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
             with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
             Default is 5
-        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
         encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
         decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values)
@@ -109,7 +107,7 @@ class DeepLabV3Plus(SegmentationModel):
             Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
             Default is **None**
         upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
-        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
             on top of encoder if **aux_params** is not **None** (default). Supported params:
             - classes (int): A number of classes
             - pooling (str): One of "max", "avg". Default is "avg"
@@ -121,19 +119,20 @@ class DeepLabV3Plus(SegmentationModel):
     Reference:
         https://arxiv.org/abs/1802.02611v3
     """
+
     def __init__(
-            self,
-            encoder_name: str = "resnet34",
-            encoder_depth: int = 5,
-            encoder_weights: Optional[str] = "imagenet",
-            encoder_output_stride: int = 16,
-            decoder_channels: int = 256,
-            decoder_atrous_rates: tuple = (12, 24, 36),
-            in_channels: int = 3,
-            classes: int = 1,
-            activation: Optional[str] = None,
-            upsampling: int = 4,
-            aux_params: Optional[dict] = None,
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: Optional[str] = "imagenet",
+        encoder_output_stride: int = 16,
+        decoder_channels: int = 256,
+        decoder_atrous_rates: tuple = (12, 24, 36),
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[str] = None,
+        upsampling: int = 4,
+        aux_params: Optional[dict] = None,
     ):
         super().__init__()
 
@@ -145,19 +144,15 @@ class DeepLabV3Plus(SegmentationModel):
         )
 
         if encoder_output_stride == 8:
-            self.encoder.make_dilated(
-                stage_list=[4, 5],
-                dilation_list=[2, 4]
-            )
+            self.encoder.make_dilated(stage_list=[4, 5], dilation_list=[2, 4])
 
         elif encoder_output_stride == 16:
-            self.encoder.make_dilated(
-                stage_list=[5],
-                dilation_list=[2]
-            )
+            self.encoder.make_dilated(stage_list=[5], dilation_list=[2])
         else:
             raise ValueError(
-                "Encoder output stride should be 8 or 16, got {}".format(encoder_output_stride)
+                "Encoder output stride should be 8 or 16, got {}".format(
+                    encoder_output_stride
+                )
             )
 
         self.decoder = DeepLabV3PlusDecoder(
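For context, a minimal construction sketch for the reflowed signature (assumes the vendored package is importable as segmentation_models_pytorch and exports DeepLabV3Plus at the top level):

import segmentation_models_pytorch as smp

model = smp.DeepLabV3Plus(
    encoder_name="resnet34",
    encoder_weights=None,      # skip the ImageNet download for a quick smoke test
    encoder_output_stride=16,  # must be 8 or 16; anything else hits the ValueError above
    classes=2,
)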
segmentation_models_pytorch/segmentation_models_pytorch/efficientunetplusplus/decoder.py
CHANGED
@@ -1,76 +1,92 @@
 import torch
-from torch.functional import norm
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.functional import norm
 
 from ..base import modules as md
 
+
 class InvertedResidual(nn.Module):
     """
     Inverted bottleneck residual block with an scSE block embedded into the residual layer, after the
     depthwise convolution. By default, uses batch normalization and Hardswish activation.
     """
 
-    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, expansion_ratio=1, squeeze_ratio=1,
-                 activation=nn.Hardswish(True), normalization=nn.BatchNorm2d):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=3,
+        stride=1,
+        expansion_ratio=1,
+        squeeze_ratio=1,
+        activation=nn.Hardswish(True),
+        normalization=nn.BatchNorm2d,
+    ):
         super().__init__()
         self.same_shape = in_channels == out_channels
-        self.mid_channels = expansion_ratio*in_channels
+        self.mid_channels = expansion_ratio * in_channels
         self.block = nn.Sequential(
             md.PointWiseConv2d(in_channels, self.mid_channels),
             normalization(self.mid_channels),
             activation,
-            md.DepthWiseConv2d(self.mid_channels, kernel_size=kernel_size, stride=stride),
+            md.DepthWiseConv2d(
+                self.mid_channels, kernel_size=kernel_size, stride=stride
+            ),
             normalization(self.mid_channels),
             activation,
-            #md.sSEModule(self.mid_channels),
-            md.SCSEModule(self.mid_channels, reduction = squeeze_ratio),
-            #md.SEModule(self.mid_channels, reduction = squeeze_ratio),
+            # md.sSEModule(self.mid_channels),
+            md.SCSEModule(self.mid_channels, reduction=squeeze_ratio),
+            # md.SEModule(self.mid_channels, reduction = squeeze_ratio),
             md.PointWiseConv2d(self.mid_channels, out_channels),
-            normalization(out_channels)
+            normalization(out_channels),
         )
 
         if not self.same_shape:
             # 1x1 convolution used to match the number of channels in the skip feature maps with that
             # of the residual feature maps
             self.skip_conv = nn.Sequential(
-                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
-                normalization(out_channels)
+                nn.Conv2d(
+                    in_channels=in_channels, out_channels=out_channels, kernel_size=1
+                ),
+                normalization(out_channels),
             )
 
     def forward(self, x):
         residual = self.block(x)
 
         if not self.same_shape:
             x = self.skip_conv(x)
         return x + residual
 
+
 class DecoderBlock(nn.Module):
     def __init__(
-            self,
-            in_channels,
-            skip_channels,
-            out_channels,
-            squeeze_ratio=1,
-            expansion_ratio=1,
+        self,
+        in_channels,
+        skip_channels,
+        out_channels,
+        squeeze_ratio=1,
+        expansion_ratio=1,
     ):
         super().__init__()
 
         # Inverted Residual block convolutions
         self.conv1 = InvertedResidual(
-            in_channels=in_channels+skip_channels,
+            in_channels=in_channels + skip_channels,
             out_channels=out_channels,
             kernel_size=3,
             stride=1,
             expansion_ratio=expansion_ratio,
-            squeeze_ratio=squeeze_ratio
+            squeeze_ratio=squeeze_ratio,
         )
         self.conv2 = InvertedResidual(
             in_channels=out_channels,
             out_channels=out_channels,
             kernel_size=3,
             stride=1,
             expansion_ratio=expansion_ratio,
-            squeeze_ratio=squeeze_ratio
+            squeeze_ratio=squeeze_ratio,
         )
 
     def forward(self, x, skip=None):
@@ -82,14 +98,15 @@ class DecoderBlock(nn.Module):
         x = self.conv2(x)
         return x
 
+
 class EfficientUnetPlusPlusDecoder(nn.Module):
     def __init__(
-            self,
-            encoder_channels,
-            decoder_channels,
-            n_blocks=5,
-            squeeze_ratio=1,
-            expansion_ratio=1,
+        self,
+        encoder_channels,
+        decoder_channels,
+        n_blocks=5,
+        squeeze_ratio=1,
+        expansion_ratio=1,
     ):
         super().__init__()
         if n_blocks != len(decoder_channels):
@@ -99,8 +116,12 @@ class EfficientUnetPlusPlusDecoder(nn.Module):
             )
         )
 
-        encoder_channels = encoder_channels[1:]  # remove first skip with same spatial resolution
-        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder
+        encoder_channels = encoder_channels[
+            1:
+        ]  # remove first skip with same spatial resolution
+        encoder_channels = encoder_channels[
+            ::-1
+        ]  # reverse channels to start from head of encoder
         # computing blocks input and output channels
         head_channels = encoder_channels[0]
         self.in_channels = [head_channels] + list(decoder_channels[:-1])
@@ -112,37 +133,51 @@ class EfficientUnetPlusPlusDecoder(nn.Module):
 
         blocks = {}
         for layer_idx in range(len(self.in_channels) - 1):
-            for depth_idx in range(layer_idx+1):
+            for depth_idx in range(layer_idx + 1):
                 if depth_idx == 0:
                     in_ch = self.in_channels[layer_idx]
-                    skip_ch = self.skip_channels[layer_idx] * (layer_idx+1)
+                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1)
                     out_ch = self.out_channels[layer_idx]
                 else:
                     out_ch = self.skip_channels[layer_idx]
-                    skip_ch = self.skip_channels[layer_idx] * (layer_idx+1-depth_idx)
+                    skip_ch = self.skip_channels[layer_idx] * (
+                        layer_idx + 1 - depth_idx
+                    )
                     in_ch = self.skip_channels[layer_idx - 1]
-                blocks[f'x_{depth_idx}_{layer_idx}'] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
-        blocks[f'x_{0}_{len(self.in_channels)-1}'] = DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1], **kwargs)
+                blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock(
+                    in_ch, skip_ch, out_ch, **kwargs
+                )
+        blocks[f"x_{0}_{len(self.in_channels)-1}"] = DecoderBlock(
+            self.in_channels[-1], 0, self.out_channels[-1], **kwargs
+        )
         self.blocks = nn.ModuleDict(blocks)
         self.depth = len(self.in_channels) - 1
 
     def forward(self, *features):
-
         features = features[1:]  # remove first skip with same spatial resolution
         features = features[::-1]  # reverse channels to start from head of encoder
         # start building dense connections
         dense_x = {}
-        for layer_idx in range(len(self.in_channels)-1):
-            for depth_idx in range(self.depth-layer_idx):
+        for layer_idx in range(len(self.in_channels) - 1):
+            for depth_idx in range(self.depth - layer_idx):
                 if layer_idx == 0:
-                    output = self.blocks[f'x_{depth_idx}_{depth_idx}'](features[depth_idx], features[depth_idx+1])
-                    dense_x[f'x_{depth_idx}_{depth_idx}'] = output
+                    output = self.blocks[f"x_{depth_idx}_{depth_idx}"](
+                        features[depth_idx], features[depth_idx + 1]
+                    )
+                    dense_x[f"x_{depth_idx}_{depth_idx}"] = output
                 else:
                     dense_l_i = depth_idx + layer_idx
-                    cat_features = [dense_x[f'x_{idx}_{dense_l_i}'] for idx in range(depth_idx+1, dense_l_i+1)]
-                    cat_features = torch.cat(cat_features + [features[dense_l_i+1]], dim=1)
-                    dense_x[f'x_{depth_idx}_{dense_l_i}'] = self.blocks[f'x_{depth_idx}_{dense_l_i}'](dense_x[f'x_{depth_idx}_{dense_l_i-1}'], cat_features)
-        dense_x[f'x_{0}_{self.depth}'] = self.blocks[f'x_{0}_{self.depth}'](dense_x[f'x_{0}_{self.depth-1}'])
-        return dense_x[f'x_{0}_{self.depth}']
+                    cat_features = [
+                        dense_x[f"x_{idx}_{dense_l_i}"]
+                        for idx in range(depth_idx + 1, dense_l_i + 1)
+                    ]
+                    cat_features = torch.cat(
+                        cat_features + [features[dense_l_i + 1]], dim=1
+                    )
+                    dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[
+                        f"x_{depth_idx}_{dense_l_i}"
+                    ](dense_x[f"x_{depth_idx}_{dense_l_i-1}"], cat_features)
+        dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"](
+            dense_x[f"x_{0}_{self.depth-1}"]
+        )
+        return dense_x[f"x_{0}_{self.depth}"]
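The two reflowed loops build one DecoderBlock per node x_{depth_idx}_{layer_idx} of the UNet++-style dense grid, plus a final x_0_{depth} block. A standalone sketch of the keys generated for a 4-entry in_channels list (it mirrors only the loop bounds above; the channel values themselves don't matter here):

in_channels = ["head", "d0", "d1", "d2"]  # stand-ins; only len() is used
keys = [
    f"x_{depth_idx}_{layer_idx}"
    for layer_idx in range(len(in_channels) - 1)
    for depth_idx in range(layer_idx + 1)
]
keys.append(f"x_{0}_{len(in_channels) - 1}")
print(keys)  # ['x_0_0', 'x_0_1', 'x_1_1', 'x_0_2', 'x_1_2', 'x_2_2', 'x_0_3']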
segmentation_models_pytorch/segmentation_models_pytorch/efficientunetplusplus/model.py
CHANGED
@@ -1,28 +1,30 @@
-from typing import Optional, Union
+from typing import List, Optional, Union
 
-from .decoder import EfficientUnetPlusPlusDecoder
-from ..encoders import get_encoder
-from ..base import SegmentationModel
-from ..base import SegmentationHead, ClassificationHead
 from torchvision import transforms
 
+from ..base import ClassificationHead, SegmentationHead, SegmentationModel
+from ..encoders import get_encoder
+from .decoder import EfficientUnetPlusPlusDecoder
+
+
 class EfficientUnetPlusPlus(SegmentationModel):
     """The EfficientUNet++ is a fully convolutional neural network for ordinary and medical image semantic segmentation.
     Consists of an *encoder* and a *decoder*, connected by *skip connections*. The encoder extracts features of
     different spatial resolutions, which are fed to the decoder through skip connections. The decoder combines its
     own feature maps with the ones from skip connections to produce accurate segmentations masks. The EfficientUNet++
     decoder architecture is based on the UNet++, a model composed of nested U-Net-like decoder sub-networks. To
     increase performance and computational efficiency, the EfficientUNet++ replaces the UNet++'s blocks with
     inverted residual blocks with depthwise convolutions and embedded spatial and channel attention mechanisms.
     Synergizes well with EfficientNet encoders. Due to their efficient visual representations (i.e., using few channels
     to represent extracted features), EfficientNet encoders require few computation from the decoder.
 
     Args:
         encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) to extract features
-        encoder_depth: Number of stages of the encoder, in range [3 ,5]. Each stage generate features two times smaller,
+        encoder_depth: Number of stages of the encoder, in range [3 ,5]. Each stage generate features two times smaller,
             in spatial dimensions, than the previous one (e.g., for depth=0 features will haves shapes [(N, C, H, W)]),
             for depth 1 features will have shapes [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
             Default is 5
-        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
         decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in the decoder.
             Length of the list should be the same as **encoder_depth**
@@ -31,7 +33,7 @@ class EfficientUnetPlusPlus(SegmentationModel):
         activation: An activation function to apply after the final convolution layer.
             Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
             Default is **None**
-        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
             on top of encoder if **aux_params** is not **None** (default). Supported params:
             - classes (int): A number of classes
             - pooling (str): One of "max", "avg". Default is "avg"
@@ -71,7 +73,7 @@ class EfficientUnetPlusPlus(SegmentationModel):
             decoder_channels=decoder_channels,
             n_blocks=encoder_depth,
             squeeze_ratio=squeeze_ratio,
-            expansion_ratio=expansion_ratio
+            expansion_ratio=expansion_ratio,
         )
 
         self.segmentation_head = SegmentationHead(
@@ -117,9 +119,9 @@ class EfficientUnetPlusPlus(SegmentationModel):
             [
                 transforms.ToPILImage(),
                 transforms.Resize(x.size[1]),
-                transforms.ToTensor()
+                transforms.ToTensor(),
             ]
         )
-        full_mask = tf(probs.cpu())
+        full_mask = tf(probs.cpu())
 
-        return full_mask
+        return full_mask
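A minimal instantiation sketch (assumes the vendored package is importable as segmentation_models_pytorch and exports EfficientUnetPlusPlus at the top level; "efficientnet-b0" is one of the encoders registered above):

import segmentation_models_pytorch as smp

model = smp.EfficientUnetPlusPlus(
    encoder_name="efficientnet-b0",
    encoder_weights=None,  # avoid the pretrained-weights download in a smoke test
    classes=1,
)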
segmentation_models_pytorch/segmentation_models_pytorch/encoders/__init__.py
CHANGED
@@ -1,22 +1,24 @@
 import functools
+
 import torch.utils.model_zoo as model_zoo
 
-from .resnet import resnet_encoders
-from .dpn import dpn_encoders
-from .vgg import vgg_encoders
-from .senet import senet_encoders
+from ._preprocessing import preprocess_input
 from .densenet import densenet_encoders
+from .dpn import dpn_encoders
+from .efficientnet import efficient_net_encoders
 from .inceptionresnetv2 import inceptionresnetv2_encoders
 from .inceptionv4 import inceptionv4_encoders
-from .efficientnet import efficient_net_encoders
 from .mobilenet import mobilenet_encoders
-from .xception import xception_encoders
+from .resnet import resnet_encoders
+from .senet import senet_encoders
+from .timm_regnet import timm_regnet_encoders
+from .timm_res2net import timm_res2net_encoders
+
 # from .timm_efficientnet import timm_efficientnet_encoders
 from .timm_resnest import timm_resnest_encoders
-from .timm_res2net import timm_res2net_encoders
-from .timm_regnet import timm_regnet_encoders
 from .timm_sknet import timm_sknet_encoders
-from ._preprocessing import preprocess_input
+from .vgg import vgg_encoders
+from .xception import xception_encoders
 
 encoders = {}
 encoders.update(resnet_encoders)
@@ -37,11 +39,14 @@ encoders.update(timm_sknet_encoders)
 
 
 def get_encoder(name, in_channels=3, depth=5, weights=None):
-
     try:
         Encoder = encoders[name]["encoder"]
     except KeyError:
-        raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))
+        raise KeyError(
+            "Wrong encoder name `{}`, supported encoders: {}".format(
+                name, list(encoders.keys())
+            )
+        )
 
     params = encoders[name]["params"]
     params.update(depth=depth)
@@ -51,9 +56,13 @@ def get_encoder(name, in_channels=3, depth=5, weights=None):
     try:
         settings = encoders[name]["pretrained_settings"][weights]
     except KeyError:
-            raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
-                weights, name, list(encoders[name]["pretrained_settings"].keys()),
-            ))
+            raise KeyError(
+                "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
+                    weights,
+                    name,
+                    list(encoders[name]["pretrained_settings"].keys()),
+                )
+            )
         encoder.load_state_dict(model_zoo.load_url(settings["url"]))
 
     encoder.set_in_channels(in_channels)
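A short usage sketch of get_encoder after the import reordering (assuming the vendored package is on the path):

from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder("resnet34", in_channels=3, depth=5, weights=None)
print(encoder.out_channels)  # per-stage feature channels, (3, 64, 64, 128, 256, 512) for resnet34
# get_encoder("no-such-encoder") raises the KeyError above, listing the supported encoders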
segmentation_models_pytorch/segmentation_models_pytorch/encoders/_base.py
CHANGED
@@ -1,15 +1,16 @@
+from collections import OrderedDict
+from typing import List
+
 import torch
 import torch.nn as nn
-from typing import List
-from collections import OrderedDict
 
 from . import _utils as utils
 
 
 class EncoderMixin:
     """Add encoder functionality such as:
-        - output channels specification of feature tensors (produced by encoder)
-        - patching first convolution for arbitrary input channels
+    - output channels specification of feature tensors (produced by encoder)
+    - patching first convolution for arbitrary input channels
     """
 
     @property
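The docstring now spells out what EncoderMixin provides. As an illustration of the second bullet, a minimal sketch of what "patching first convolution for arbitrary input channels" typically means; this is not the library's _utils code, just the general technique of reusing pretrained RGB filters for a new channel count:

import torch
import torch.nn as nn

def patch_first_conv(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    # Rebuild the stem conv with the new in_channels count.
    new_conv = nn.Conv2d(
        in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        # Average the pretrained RGB filters and broadcast across the new
        # channels; the 3/in_channels factor is a heuristic that keeps the
        # expected activation magnitude roughly unchanged.
        weight = conv.weight.mean(dim=1, keepdim=True)
        new_conv.weight.copy_(weight.repeat(1, in_channels, 1, 1) * (3.0 / in_channels))
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv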
segmentation_models_pytorch/segmentation_models_pytorch/encoders/_preprocessing.py
CHANGED
@@ -4,7 +4,6 @@ import numpy as np
 
 def preprocess_input(
     x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs
 ):
-
     if input_space == "BGR":
         x = x[..., ::-1].copy()
 
segmentation_models_pytorch/segmentation_models_pytorch/encoders/densenet.py
CHANGED
@@ -24,8 +24,8 @@ Methods:
 """
 
 import re
-import torch.nn as nn
 
+import torch.nn as nn
 from pretrainedmodels.models.torchvision_models import pretrained_settings
 from torchvision.models.densenet import DenseNet
 
@@ -33,7 +33,6 @@ from ._base import EncoderMixin
 
 
 class TransitionWithSkip(nn.Module):
-
     def __init__(self, module):
         super().__init__()
         self.module = module
@@ -55,22 +54,32 @@ class DenseNetEncoder(DenseNet, EncoderMixin):
         del self.classifier
 
     def make_dilated(self, stage_list, dilation_list):
-        raise ValueError("DenseNet encoders do not support dilated mode "
-                         "due to pooling operation for downsampling!")
+        raise ValueError(
+            "DenseNet encoders do not support dilated mode "
+            "due to pooling operation for downsampling!"
+        )
 
     def get_stages(self):
         return [
             nn.Identity(),
-            nn.Sequential(self.features.conv0, self.features.norm0, self.features.relu0),
-            nn.Sequential(self.features.pool0, self.features.denseblock1, TransitionWithSkip(self.features.transition1)),
-            nn.Sequential(self.features.denseblock2, TransitionWithSkip(self.features.transition2)),
-            nn.Sequential(self.features.denseblock3, TransitionWithSkip(self.features.transition3)),
-            nn.Sequential(self.features.denseblock4, self.features.norm5)
+            nn.Sequential(
+                self.features.conv0, self.features.norm0, self.features.relu0
+            ),
+            nn.Sequential(
+                self.features.pool0,
+                self.features.denseblock1,
+                TransitionWithSkip(self.features.transition1),
+            ),
+            nn.Sequential(
+                self.features.denseblock2, TransitionWithSkip(self.features.transition2)
+            ),
+            nn.Sequential(
+                self.features.denseblock3, TransitionWithSkip(self.features.transition3)
+            ),
+            nn.Sequential(self.features.denseblock4, self.features.norm5),
         ]
 
     def forward(self, x):
-
         stages = self.get_stages()
 
         features = []
segmentation_models_pytorch/segmentation_models_pytorch/encoders/dpn.py
CHANGED
@@ -26,9 +26,7 @@ Methods:
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-from pretrainedmodels.models.dpn import DPN
-from pretrainedmodels.models.dpn import pretrained_settings
+from pretrainedmodels.models.dpn import DPN, pretrained_settings
 
 from ._base import EncoderMixin
 
@@ -46,15 +44,18 @@ class DPNEncoder(DPN, EncoderMixin):
     def get_stages(self):
         return [
             nn.Identity(),
-            nn.Sequential(self.features[0].conv, self.features[0].bn, self.features[0].act),
-            nn.Sequential(self.features[0].pool, self.features[1:self._stage_idxs[0]]),
+            nn.Sequential(
+                self.features[0].conv, self.features[0].bn, self.features[0].act
+            ),
+            nn.Sequential(
+                self.features[0].pool, self.features[1 : self._stage_idxs[0]]
+            ),
             self.features[self._stage_idxs[0] : self._stage_idxs[1]],
             self.features[self._stage_idxs[1] : self._stage_idxs[2]],
             self.features[self._stage_idxs[2] : self._stage_idxs[3]],
         ]
 
     def forward(self, x):
-
         stages = self.get_stages()
 
         features = []
segmentation_models_pytorch/segmentation_models_pytorch/encoders/efficientnet.py
CHANGED
@@ -24,14 +24,13 @@ Methods:
 """
 import torch.nn as nn
 from efficientnet_pytorch import EfficientNet
-from efficientnet_pytorch.utils import url_map, url_map_advprop
+from efficientnet_pytorch.utils import get_model_params, url_map, url_map_advprop
 
 from ._base import EncoderMixin
 
 
 class EfficientNetEncoder(EfficientNet, EncoderMixin):
     def __init__(self, stage_idxs, out_channels, model_name, depth=5):
-
         blocks_args, global_params = get_model_params(model_name, override_params=None)
         super().__init__(blocks_args, global_params)
 
@@ -46,21 +45,20 @@ class EfficientNetEncoder(EfficientNet, EncoderMixin):
         return [
             nn.Identity(),
             nn.Sequential(self._conv_stem, self._bn0, self._swish),
-            self._blocks[:self._stage_idxs[0]],
-            self._blocks[self._stage_idxs[0]:self._stage_idxs[1]],
-            self._blocks[self._stage_idxs[1]:self._stage_idxs[2]],
-            self._blocks[self._stage_idxs[2]:],
+            self._blocks[: self._stage_idxs[0]],
+            self._blocks[self._stage_idxs[0] : self._stage_idxs[1]],
+            self._blocks[self._stage_idxs[1] : self._stage_idxs[2]],
+            self._blocks[self._stage_idxs[2] :],
         ]
 
     def forward(self, x):
         stages = self.get_stages()
 
-        block_number = 0.
+        block_number = 0.0
         drop_connect_rate = self._global_params.drop_connect_rate
 
         features = []
         for i in range(self._depth + 1):
-
             # Identity and Sequential stages
             if i < 2:
                 x = stages[i](x)
@@ -69,7 +67,7 @@ class EfficientNetEncoder(EfficientNet, EncoderMixin):
             else:
                 for module in stages[i]:
                     drop_connect = drop_connect_rate * block_number / len(self._blocks)
-                    block_number += 1.
+                    block_number += 1.0
                     x = module(x, drop_connect)
 
             features.append(x)
@@ -97,7 +95,7 @@ def _get_pretrained_settings(encoder):
             "url": url_map_advprop[encoder],
             "input_space": "RGB",
             "input_range": [0, 1],
-        }
+        },
     }
     return pretrained_settings
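The forward pass above anneals drop connect linearly over the block sequence. A minimal, self-contained sketch of that schedule (the rate and block count below are illustrative values, not taken from the diff):

# Linear drop-connect schedule from the forward loop above: block i of N
# receives drop_connect_rate * i / N, so deeper blocks are dropped more often.
def drop_connect_schedule(drop_connect_rate, num_blocks):
    return [drop_connect_rate * i / num_blocks for i in range(num_blocks)]

print(drop_connect_schedule(0.2, 16))  # 0.0, 0.0125, ..., 0.1875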
segmentation_models_pytorch/segmentation_models_pytorch/encoders/inceptionresnetv2.py
CHANGED
@@ -24,8 +24,10 @@ Methods:
 """
 
 import torch.nn as nn
-from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2
-from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings
+from pretrainedmodels.models.inceptionresnetv2 import (
+    InceptionResNetV2,
+    pretrained_settings,
+)
 
 from ._base import EncoderMixin
 
@@ -51,8 +53,10 @@ class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin):
         del self.last_linear
 
     def make_dilated(self, stage_list, dilation_list):
-        raise ValueError("InceptionResnetV2 encoder does not support dilated mode "
-                         "due to pooling operation for downsampling!")
+        raise ValueError(
+            "InceptionResnetV2 encoder does not support dilated mode "
+            "due to pooling operation for downsampling!"
+        )
 
     def get_stages(self):
         return [
@@ -65,7 +69,6 @@ class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin):
         ]
 
     def forward(self, x):
-
         stages = self.get_stages()
 
         features = []
segmentation_models_pytorch/segmentation_models_pytorch/encoders/inceptionv4.py
CHANGED
@@ -24,8 +24,11 @@ Methods:
 """
 
 import torch.nn as nn
-from pretrainedmodels.models.inceptionv4 import InceptionV4, BasicConv2d
-from pretrainedmodels.models.inceptionv4 import pretrained_settings
+from pretrainedmodels.models.inceptionv4 import (
+    BasicConv2d,
+    InceptionV4,
+    pretrained_settings,
+)
 
 from ._base import EncoderMixin
 
@@ -50,21 +53,22 @@ class InceptionV4Encoder(InceptionV4, EncoderMixin):
         del self.last_linear
 
     def make_dilated(self, stage_list, dilation_list):
-        raise ValueError("InceptionV4 encoder does not support dilated mode "
-                         "due to pooling operation for downsampling!")
+        raise ValueError(
+            "InceptionV4 encoder does not support dilated mode "
+            "due to pooling operation for downsampling!"
+        )
 
     def get_stages(self):
         return [
             nn.Identity(),
             self.features[: self._stage_idxs[0]],
-            self.features[self._stage_idxs[0]: self._stage_idxs[1]],
-            self.features[self._stage_idxs[1]: self._stage_idxs[2]],
-            self.features[self._stage_idxs[2]: self._stage_idxs[3]],
-            self.features[self._stage_idxs[3]:],
+            self.features[self._stage_idxs[0] : self._stage_idxs[1]],
+            self.features[self._stage_idxs[1] : self._stage_idxs[2]],
+            self.features[self._stage_idxs[2] : self._stage_idxs[3]],
+            self.features[self._stage_idxs[3] :],
         ]
 
     def forward(self, x):
-
         stages = self.get_stages()
 
         features = []
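Both Inception encoders above stub out make_dilated with an immediate ValueError rather than silently producing wrong output strides. A self-contained sketch of that guard pattern (the stub class is hypothetical; only the raised message mirrors the diff):

class PoolingEncoderStub:
    # Hypothetical stand-in reproducing the guard from InceptionV4Encoder above.
    def make_dilated(self, stage_list, dilation_list):
        raise ValueError(
            "InceptionV4 encoder does not support dilated mode "
            "due to pooling operation for downsampling!"
        )

try:
    PoolingEncoderStub().make_dilated(stage_list=[4, 5], dilation_list=[2, 4])
except ValueError as err:
    print(err)  # callers get a clear failure instead of a broken model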
segmentation_models_pytorch/segmentation_models_pytorch/encoders/mobilenet.py
CHANGED
@@ -23,14 +23,13 @@ Methods:
     depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
 """
 
-import torchvision
 import torch.nn as nn
+import torchvision
 
 from ._base import EncoderMixin
 
 
 class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin):
-
     def __init__(self, out_channels, depth=5, **kwargs):
         super().__init__(**kwargs)
         self._depth = depth
segmentation_models_pytorch/segmentation_models_pytorch/encoders/resnet.py
CHANGED
@@ -25,11 +25,8 @@ Methods:
 from copy import deepcopy
 
 import torch.nn as nn
-
-from torchvision.models.resnet import ResNet
-from torchvision.models.resnet import BasicBlock
-from torchvision.models.resnet import Bottleneck
 from pretrainedmodels.models.torchvision_models import pretrained_settings
+from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
 
 from ._base import EncoderMixin
 
@@ -73,11 +70,11 @@ class ResNetEncoder(ResNet, EncoderMixin):
 new_settings = {
     "resnet18": {
         "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth",
-        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth"
+        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth",
     },
     "resnet50": {
         "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth",
-        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth"
+        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth",
     },
     "resnext50_32x4d": {
         "imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
@@ -86,7 +83,7 @@ new_settings = {
     },
     "resnext101_32x4d": {
         "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth",
-        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth"
+        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth",
     },
     "resnext101_32x8d": {
         "imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
@@ -104,7 +101,7 @@ new_settings = {
     },
     "resnext101_32x48d": {
         "instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth",
-    }
+    },
 }
 
 pretrained_settings = deepcopy(pretrained_settings)
@@ -115,11 +112,11 @@ for model_name, sources in new_settings.items():
     for source_name, source_url in sources.items():
         pretrained_settings[model_name][source_name] = {
             "url": source_url,
+            "input_size": [3, 224, 224],
+            "input_range": [0, 1],
+            "mean": [0.485, 0.456, 0.406],
+            "std": [0.229, 0.224, 0.225],
+            "num_classes": 1000,
         }
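The loop at the bottom of resnet.py grafts the extra ssl/swsl checkpoints onto the settings inherited from pretrainedmodels. A minimal, self-contained sketch of that merge (the URLs are placeholders, not real checkpoints):

from copy import deepcopy

pretrained_settings = {"resnet18": {"imagenet": {"url": "https://example.com/resnet18.pth"}}}
new_settings = {"resnet18": {"ssl": "https://example.com/resnet18_ssl.pth"}}

# Deep-copy so the upstream dict is untouched, then register each new
# source with the standard ImageNet preprocessing constants.
pretrained_settings = deepcopy(pretrained_settings)
for model_name, sources in new_settings.items():
    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }

assert set(pretrained_settings["resnet18"]) == {"imagenet", "ssl"}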
segmentation_models_pytorch/segmentation_models_pytorch/encoders/senet.py
CHANGED
@@ -24,14 +24,14 @@ Methods:
 """
 
 import torch.nn as nn
-
 from pretrainedmodels.models.senet import (
-    SENet,
     SEBottleneck,
+    SENet,
     SEResNetBottleneck,
     SEResNeXtBottleneck,
     pretrained_settings,
 )
+
 from ._base import EncoderMixin
segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_regnet.py
CHANGED
@@ -1,6 +1,7 @@
-from ._base import EncoderMixin
-from timm.models.regnet import RegNet
 import torch.nn as nn
+from timm.models.regnet import RegNet
+
+from ._base import EncoderMixin
 
 
 class RegNetEncoder(RegNet, EncoderMixin):
@@ -39,78 +40,78 @@ class RegNetEncoder(RegNet, EncoderMixin):
 
 
 regnet_weights = {
+    "timm-regnetx_002": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth",
+    },
+    "timm-regnetx_004": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth",
     },
+    "timm-regnetx_006": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth",
     },
+    "timm-regnetx_008": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth",
     },
+    "timm-regnetx_016": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth",
     },
+    "timm-regnetx_032": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth",
     },
+    "timm-regnetx_040": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth",
     },
+    "timm-regnetx_064": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth",
     },
+    "timm-regnetx_080": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth",
     },
+    "timm-regnetx_120": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth",
     },
+    "timm-regnetx_160": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth",
     },
+    "timm-regnetx_320": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth",
     },
+    "timm-regnety_002": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth",
     },
+    "timm-regnety_004": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth",
     },
+    "timm-regnety_006": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth",
     },
+    "timm-regnety_008": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth",
     },
+    "timm-regnety_016": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth",
     },
+    "timm-regnety_032": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth"
     },
+    "timm-regnety_040": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth"
     },
+    "timm-regnety_064": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth"
     },
+    "timm-regnety_080": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth",
     },
+    "timm-regnety_120": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth",
     },
+    "timm-regnety_160": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth",
     },
-    'timm-regnety_320': {
-        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'
-    }
+    "timm-regnety_320": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth"
+    },
 }
 
 pretrained_settings = {}
@@ -119,214 +120,224 @@ for model_name, sources in regnet_weights.items():
     for source_name, source_url in sources.items():
         pretrained_settings[model_name][source_name] = {
             "url": source_url,
+            "input_size": [3, 224, 224],
+            "input_range": [0, 1],
+            "mean": [0.485, 0.456, 0.406],
+            "std": [0.229, 0.224, 0.225],
+            "num_classes": 1000,
         }
 
 # at this point I am too lazy to copy configs, so I just used the same configs from timm's repo
 
 
 def _mcfg(**kwargs):
-    cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32)
+    cfg = dict(se_ratio=0.0, bottle_ratio=1.0, stem_width=32)
     cfg.update(**kwargs)
     return cfg
 
 
 timm_regnet_encoders = {
+    "timm-regnetx_002": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnetx_002"],
+        "params": {
+            "out_channels": (3, 32, 24, 56, 152, 368),
+            "cfg": _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13),
         },
     },
+    "timm-regnetx_004": {
+        "encoder": RegNetEncoder,
        "pretrained_settings": pretrained_settings["timm-regnetx_004"],
+        "params": {
+            "out_channels": (3, 32, 32, 64, 160, 384),
+            "cfg": _mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22),
        },
     },
+    "timm-regnetx_006": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnetx_006"],
+        "params": {
+            "out_channels": (3, 32, 48, 96, 240, 528),
+            "cfg": _mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16),
         },
     },
+    "timm-regnetx_008": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnetx_008"],
+        "params": {
+            "out_channels": (3, 32, 64, 128, 288, 672),
+            "cfg": _mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16),
         },
     },
+    "timm-regnetx_016": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnetx_016"],
+        "params": {
+            "out_channels": (3, 32, 72, 168, 408, 912),
+            "cfg": _mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18),
         },
     },
+    "timm-regnetx_032": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnetx_032"],
+        "params": {
+            "out_channels": (3, 32, 96, 192, 432, 1008),
+            "cfg": _mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25),
         },
     },
+    "timm-regnetx_040": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnetx_040"],
+        "params": {
+            "out_channels": (3, 32, 80, 240, 560, 1360),
+            "cfg": _mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23),
         },
     },
+    "timm-regnetx_064": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnetx_064"],
+        "params": {
+            "out_channels": (3, 32, 168, 392, 784, 1624),
+            "cfg": _mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17),
         },
     },
+    "timm-regnetx_080": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnetx_080"],
+        "params": {
+            "out_channels": (3, 32, 80, 240, 720, 1920),
+            "cfg": _mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23),
         },
     },
+    "timm-regnetx_120": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnetx_120"],
+        "params": {
+            "out_channels": (3, 32, 224, 448, 896, 2240),
+            "cfg": _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19),
         },
     },
+    "timm-regnetx_160": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnetx_160"],
+        "params": {
+            "out_channels": (3, 32, 256, 512, 896, 2048),
+            "cfg": _mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22),
         },
     },
+    "timm-regnetx_320": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnetx_320"],
+        "params": {
+            "out_channels": (3, 32, 336, 672, 1344, 2520),
+            "cfg": _mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23),
         },
     },
+    # regnety
+    "timm-regnety_002": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_002"],
+        "params": {
+            "out_channels": (3, 32, 24, 56, 152, 368),
+            "cfg": _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25),
         },
     },
+    "timm-regnety_004": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_004"],
+        "params": {
+            "out_channels": (3, 32, 48, 104, 208, 440),
+            "cfg": _mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25),
         },
     },
+    "timm-regnety_006": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_006"],
+        "params": {
+            "out_channels": (3, 32, 48, 112, 256, 608),
+            "cfg": _mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25),
         },
     },
+    "timm-regnety_008": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_008"],
+        "params": {
+            "out_channels": (3, 32, 64, 128, 320, 768),
+            "cfg": _mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25),
         },
     },
+    "timm-regnety_016": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_016"],
+        "params": {
+            "out_channels": (3, 32, 48, 120, 336, 888),
+            "cfg": _mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25),
         },
     },
+    "timm-regnety_032": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_032"],
+        "params": {
+            "out_channels": (3, 32, 72, 216, 576, 1512),
+            "cfg": _mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25),
         },
     },
+    "timm-regnety_040": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_040"],
+        "params": {
+            "out_channels": (3, 32, 128, 192, 512, 1088),
+            "cfg": _mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25),
         },
     },
+    "timm-regnety_064": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_064"],
+        "params": {
+            "out_channels": (3, 32, 144, 288, 576, 1296),
+            "cfg": _mcfg(
+                w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25
+            ),
         },
     },
+    "timm-regnety_080": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_080"],
+        "params": {
+            "out_channels": (3, 32, 168, 448, 896, 2016),
+            "cfg": _mcfg(
+                w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25
+            ),
         },
     },
+    "timm-regnety_120": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_120"],
+        "params": {
+            "out_channels": (3, 32, 224, 448, 896, 2240),
+            "cfg": _mcfg(
+                w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25
+            ),
         },
     },
+    "timm-regnety_160": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_160"],
+        "params": {
+            "out_channels": (3, 32, 224, 448, 1232, 3024),
+            "cfg": _mcfg(
+                w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25
+            ),
         },
     },
+    "timm-regnety_320": {
+        "encoder": RegNetEncoder,
         "pretrained_settings": pretrained_settings["timm-regnety_320"],
+        "params": {
+            "out_channels": (3, 32, 232, 696, 1392, 3712),
+            "cfg": _mcfg(
+                w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25
+            ),
         },
     },
 }
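The _mcfg helper above is just the shared RegNet defaults plus per-variant overrides. One call from the registry expands as follows (the keyword values are copied from the timm-regnetx_002 entry):

def _mcfg(**kwargs):
    # Shared RegNet defaults; every keyword argument overrides or extends them.
    cfg = dict(se_ratio=0.0, bottle_ratio=1.0, stem_width=32)
    cfg.update(**kwargs)
    return cfg

cfg = _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13)
print(cfg)
# {'se_ratio': 0.0, 'bottle_ratio': 1.0, 'stem_width': 32,
#  'w0': 24, 'wa': 36.44, 'wm': 2.49, 'group_w': 8, 'depth': 13}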
segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_res2net.py
CHANGED
@@ -1,7 +1,8 @@
-from ._base import EncoderMixin
-from timm.models.resnet import ResNet
-from timm.models.res2net import Bottle2neck
 import torch.nn as nn
+from timm.models.res2net import Bottle2neck
+from timm.models.resnet import ResNet
+
+from ._base import EncoderMixin
 
 
 class Res2NetEncoder(ResNet, EncoderMixin):
@@ -44,27 +45,27 @@ class Res2NetEncoder(ResNet, EncoderMixin):
 
 
 res2net_weights = {
+    "timm-res2net50_26w_4s": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth"
     },
+    "timm-res2net50_48w_2s": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth"
     },
+    "timm-res2net50_14w_8s": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth",
     },
+    "timm-res2net50_26w_6s": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth",
     },
+    "timm-res2net50_26w_8s": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth",
     },
+    "timm-res2net101_26w_4s": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth",
+    },
-    'timm-res2next50': {
-        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth',
-    }
+    "timm-res2next50": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth",
     },
 }
 
 pretrained_settings = {}
@@ -73,91 +74,91 @@ for model_name, sources in res2net_weights.items():
     for source_name, source_url in sources.items():
         pretrained_settings[model_name][source_name] = {
             "url": source_url,
+            "input_size": [3, 224, 224],
+            "input_range": [0, 1],
+            "mean": [0.485, 0.456, 0.406],
+            "std": [0.229, 0.224, 0.225],
+            "num_classes": 1000,
         }
 
 
 timm_res2net_encoders = {
+    "timm-res2net50_26w_4s": {
+        "encoder": Res2NetEncoder,
         "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": Bottle2neck,
+            "layers": [3, 4, 6, 3],
+            "base_width": 26,
+            "block_args": {"scale": 4},
         },
     },
+    "timm-res2net101_26w_4s": {
+        "encoder": Res2NetEncoder,
         "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": Bottle2neck,
+            "layers": [3, 4, 23, 3],
+            "base_width": 26,
+            "block_args": {"scale": 4},
         },
     },
+    "timm-res2net50_26w_6s": {
+        "encoder": Res2NetEncoder,
         "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": Bottle2neck,
+            "layers": [3, 4, 6, 3],
+            "base_width": 26,
+            "block_args": {"scale": 6},
         },
     },
+    "timm-res2net50_26w_8s": {
+        "encoder": Res2NetEncoder,
         "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": Bottle2neck,
+            "layers": [3, 4, 6, 3],
+            "base_width": 26,
+            "block_args": {"scale": 8},
         },
     },
+    "timm-res2net50_48w_2s": {
+        "encoder": Res2NetEncoder,
         "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": Bottle2neck,
+            "layers": [3, 4, 6, 3],
+            "base_width": 48,
+            "block_args": {"scale": 2},
         },
     },
+    "timm-res2net50_14w_8s": {
+        "encoder": Res2NetEncoder,
         "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": Bottle2neck,
+            "layers": [3, 4, 6, 3],
+            "base_width": 14,
+            "block_args": {"scale": 8},
         },
     },
+    "timm-res2next50": {
+        "encoder": Res2NetEncoder,
         "pretrained_settings": pretrained_settings["timm-res2next50"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": Bottle2neck,
+            "layers": [3, 4, 6, 3],
+            "base_width": 4,
+            "cardinality": 8,
+            "block_args": {"scale": 4},
         },
+    },
 }
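Each timm_*_encoders dict in this commit is a registry keyed by encoder name. A hedged sketch of how such an entry is typically consumed (segmentation_models_pytorch's get_encoder does roughly this internally, but the helper below is an illustration, not the library's API):

def build_encoder(registry, name, weights="imagenet"):
    # Instantiate the registered encoder class with its constructor params
    # and return the matching pretrained-weights settings alongside it.
    entry = registry[name]
    encoder = entry["encoder"](**entry["params"])
    settings = entry["pretrained_settings"][weights]  # {"url", "mean", "std", ...}
    return encoder, settings

# Usage (assumes the registries defined above are importable):
# encoder, settings = build_encoder(timm_res2net_encoders, "timm-res2net50_26w_4s")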
segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_resnest.py
CHANGED
@@ -1,7 +1,8 @@
-from ._base import EncoderMixin
-from timm.models.resnet import ResNet
-from timm.models.resnest import ResNestBottleneck
 import torch.nn as nn
+from timm.models.resnest import ResNestBottleneck
+from timm.models.resnet import ResNet
+
+from ._base import EncoderMixin
 
 
 class ResNestEncoder(ResNet, EncoderMixin):
@@ -44,30 +45,30 @@ class ResNestEncoder(ResNet, EncoderMixin):
 
 
 resnest_weights = {
+    "timm-resnest14d": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth"
+    },
+    "timm-resnest26d": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth"
     },
+    "timm-resnest50d": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth",
     },
+    "timm-resnest101e": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth",
     },
+    "timm-resnest200e": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth",
     },
+    "timm-resnest269e": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth",
     },
+    "timm-resnest50d_4s2x40d": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth",
     },
-    'timm-resnest50d_1s4x24d': {
-        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth',
-    }
+    "timm-resnest50d_1s4x24d": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth",
+    },
 }
 
 pretrained_settings = {}
@@ -76,133 +77,133 @@ for model_name, sources in resnest_weights.items():
     for source_name, source_url in sources.items():
         pretrained_settings[model_name][source_name] = {
             "url": source_url,
+            "input_size": [3, 224, 224],
+            "input_range": [0, 1],
+            "mean": [0.485, 0.456, 0.406],
+            "std": [0.229, 0.224, 0.225],
+            "num_classes": 1000,
         }
 
 
 timm_resnest_encoders = {
+    "timm-resnest14d": {
+        "encoder": ResNestEncoder,
         "pretrained_settings": pretrained_settings["timm-resnest14d"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": ResNestBottleneck,
+            "layers": [1, 1, 1, 1],
+            "stem_type": "deep",
+            "stem_width": 32,
+            "avg_down": True,
+            "base_width": 64,
+            "cardinality": 1,
+            "block_args": {"radix": 2, "avd": True, "avd_first": False},
+        },
     },
+    "timm-resnest26d": {
+        "encoder": ResNestEncoder,
         "pretrained_settings": pretrained_settings["timm-resnest26d"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": ResNestBottleneck,
+            "layers": [2, 2, 2, 2],
+            "stem_type": "deep",
+            "stem_width": 32,
+            "avg_down": True,
+            "base_width": 64,
+            "cardinality": 1,
+            "block_args": {"radix": 2, "avd": True, "avd_first": False},
+        },
     },
+    "timm-resnest50d": {
+        "encoder": ResNestEncoder,
         "pretrained_settings": pretrained_settings["timm-resnest50d"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": ResNestBottleneck,
+            "layers": [3, 4, 6, 3],
+            "stem_type": "deep",
+            "stem_width": 32,
+            "avg_down": True,
+            "base_width": 64,
+            "cardinality": 1,
+            "block_args": {"radix": 2, "avd": True, "avd_first": False},
+        },
     },
+    "timm-resnest101e": {
+        "encoder": ResNestEncoder,
         "pretrained_settings": pretrained_settings["timm-resnest101e"],
+        "params": {
+            "out_channels": (3, 128, 256, 512, 1024, 2048),
+            "block": ResNestBottleneck,
+            "layers": [3, 4, 23, 3],
+            "stem_type": "deep",
+            "stem_width": 64,
+            "avg_down": True,
+            "base_width": 64,
+            "cardinality": 1,
+            "block_args": {"radix": 2, "avd": True, "avd_first": False},
+        },
     },
+    "timm-resnest200e": {
+        "encoder": ResNestEncoder,
         "pretrained_settings": pretrained_settings["timm-resnest200e"],
+        "params": {
+            "out_channels": (3, 128, 256, 512, 1024, 2048),
+            "block": ResNestBottleneck,
+            "layers": [3, 24, 36, 3],
+            "stem_type": "deep",
+            "stem_width": 64,
+            "avg_down": True,
+            "base_width": 64,
+            "cardinality": 1,
+            "block_args": {"radix": 2, "avd": True, "avd_first": False},
+        },
     },
+    "timm-resnest269e": {
+        "encoder": ResNestEncoder,
         "pretrained_settings": pretrained_settings["timm-resnest269e"],
+        "params": {
+            "out_channels": (3, 128, 256, 512, 1024, 2048),
+            "block": ResNestBottleneck,
+            "layers": [3, 30, 48, 8],
+            "stem_type": "deep",
+            "stem_width": 64,
+            "avg_down": True,
+            "base_width": 64,
+            "cardinality": 1,
+            "block_args": {"radix": 2, "avd": True, "avd_first": False},
        },
     },
+    "timm-resnest50d_4s2x40d": {
+        "encoder": ResNestEncoder,
         "pretrained_settings": pretrained_settings["timm-resnest50d_4s2x40d"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": ResNestBottleneck,
+            "layers": [3, 4, 6, 3],
+            "stem_type": "deep",
+            "stem_width": 32,
+            "avg_down": True,
+            "base_width": 40,
+            "cardinality": 2,
+            "block_args": {"radix": 4, "avd": True, "avd_first": True},
+        },
     },
+    "timm-resnest50d_1s4x24d": {
+        "encoder": ResNestEncoder,
         "pretrained_settings": pretrained_settings["timm-resnest50d_1s4x24d"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": ResNestBottleneck,
+            "layers": [3, 4, 6, 3],
+            "stem_type": "deep",
+            "stem_width": 32,
+            "avg_down": True,
+            "base_width": 24,
+            "cardinality": 4,
+            "block_args": {"radix": 1, "avd": True, "avd_first": True},
+        },
+    },
 }
segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_sknet.py
CHANGED
@@ -1,7 +1,8 @@
-from ._base import EncoderMixin
-from timm.models.resnet import ResNet
-from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic
 import torch.nn as nn
+from timm.models.resnet import ResNet
+from timm.models.sknet import SelectiveKernelBasic, SelectiveKernelBottleneck
+
+from ._base import EncoderMixin
 
 
 class SkNetEncoder(ResNet, EncoderMixin):
@@ -41,15 +42,15 @@ class SkNetEncoder(ResNet, EncoderMixin):
 
 
 sknet_weights = {
+    "timm-skresnet18": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth"
     },
+    "timm-skresnet34": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth"
+    },
-    'timm-skresnext50_32x4d': {
-        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth',
-    }
+    "timm-skresnext50_32x4d": {
+        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth",
+    },
 }
 
 pretrained_settings = {}
@@ -58,46 +59,58 @@ for model_name, sources in sknet_weights.items():
     for source_name, source_url in sources.items():
         pretrained_settings[model_name][source_name] = {
             "url": source_url,
+            "input_size": [3, 224, 224],
+            "input_range": [0, 1],
+            "mean": [0.485, 0.456, 0.406],
+            "std": [0.229, 0.224, 0.225],
+            "num_classes": 1000,
         }
 
 timm_sknet_encoders = {
+    "timm-skresnet18": {
+        "encoder": SkNetEncoder,
         "pretrained_settings": pretrained_settings["timm-skresnet18"],
+        "params": {
+            "out_channels": (3, 64, 64, 128, 256, 512),
+            "block": SelectiveKernelBasic,
+            "layers": [2, 2, 2, 2],
+            "zero_init_last_bn": False,
+            "block_args": {
+                "sk_kwargs": {
+                    "min_attn_channels": 16,
+                    "attn_reduction": 8,
+                    "split_input": True,
+                }
+            },
+        },
     },
+    "timm-skresnet34": {
+        "encoder": SkNetEncoder,
         "pretrained_settings": pretrained_settings["timm-skresnet34"],
+        "params": {
+            "out_channels": (3, 64, 64, 128, 256, 512),
+            "block": SelectiveKernelBasic,
+            "layers": [3, 4, 6, 3],
+            "zero_init_last_bn": False,
+            "block_args": {
+                "sk_kwargs": {
+                    "min_attn_channels": 16,
+                    "attn_reduction": 8,
+                    "split_input": True,
+                }
+            },
+        },
     },
+    "timm-skresnext50_32x4d": {
+        "encoder": SkNetEncoder,
         "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"],
+        "params": {
+            "out_channels": (3, 64, 256, 512, 1024, 2048),
+            "block": SelectiveKernelBottleneck,
+            "layers": [3, 4, 6, 3],
+            "zero_init_last_bn": False,
+            "cardinality": 32,
+            "base_width": 4,
+        },
+    },
 }
segmentation_models_pytorch/segmentation_models_pytorch/encoders/vgg.py
CHANGED
@@ -24,9 +24,8 @@ Methods:
 """
 
 import torch.nn as nn
-from torchvision.models.vgg import VGG
-from torchvision.models.vgg import make_layers
 from pretrainedmodels.models.torchvision_models import pretrained_settings
+from torchvision.models.vgg import VGG, make_layers
 
 from ._base import EncoderMixin
 
@@ -49,8 +48,10 @@ class VGGEncoder(VGG, EncoderMixin):
         del self.classifier
 
     def make_dilated(self, stage_list, dilation_list):
-        raise ValueError("'VGG' models do not support dilated mode due to Max Pooling"
-                         " operations for downsampling!")
+        raise ValueError(
+            "'VGG' models do not support dilated mode due to Max Pooling"
+            " operations for downsampling!"
+        )
 
     def get_stages(self):
         stages = []
segmentation_models_pytorch/segmentation_models_pytorch/encoders/xception.py
CHANGED
@@ -1,14 +1,12 @@
 import re
-import torch.nn as nn
 
-from pretrainedmodels.models.xception import pretrained_settings
-from pretrainedmodels.models.xception import Xception
+import torch.nn as nn
+from pretrainedmodels.models.xception import Xception, pretrained_settings
 
 from ._base import EncoderMixin
 
 
 class XceptionEncoder(Xception, EncoderMixin):
-
     def __init__(self, out_channels, *args, depth=5, **kwargs):
         super().__init__(*args, **kwargs)
 
@@ -23,18 +21,33 @@ class XceptionEncoder(Xception, EncoderMixin):
         del self.fc
 
     def make_dilated(self, stage_list, dilation_list):
-        raise ValueError("Xception encoder does not support dilated mode "
-                         "due to pooling operation for downsampling!")
+        raise ValueError(
+            "Xception encoder does not support dilated mode "
+            "due to pooling operation for downsampling!"
+        )
 
     def get_stages(self):
         return [
             nn.Identity(),
-            nn.Sequential(self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu),
+            nn.Sequential(
+                self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu
+            ),
             self.block1,
             self.block2,
-            nn.Sequential(self.block3, self.block4, self.block5, self.block6, self.block7,
-                          self.block8, self.block9, self.block10, self.block11),
-            nn.Sequential(self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4),
+            nn.Sequential(
+                self.block3,
+                self.block4,
+                self.block5,
+                self.block6,
+                self.block7,
+                self.block8,
+                self.block9,
+                self.block10,
+                self.block11,
+            ),
+            nn.Sequential(
+                self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4
+            ),
         ]
 
     def forward(self, x):
@@ -49,18 +62,18 @@ class XceptionEncoder(Xception, EncoderMixin):
 
     def load_state_dict(self, state_dict):
         # remove linear
-        state_dict.pop('fc.bias')
-        state_dict.pop('fc.weight')
+        state_dict.pop("fc.bias")
+        state_dict.pop("fc.weight")
 
         super().load_state_dict(state_dict)
 
 
 xception_encoders = {
-    'xception': {
-        'encoder': XceptionEncoder,
-        'pretrained_settings': pretrained_settings['xception'],
-        'params': {
-            'out_channels': (3, 64, 128, 256, 728, 2048),
-        }
+    "xception": {
+        "encoder": XceptionEncoder,
+        "pretrained_settings": pretrained_settings["xception"],
+        "params": {
+            "out_channels": (3, 64, 128, 256, 728, 2048),
+        },
     },
 }
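The regrouped get_stages() above is what downstream code iterates to collect one feature map per resolution. A sketch of that consumption pattern (extract_features is a hypothetical helper; the loop mirrors the encoder forward() convention used throughout this package):

def extract_features(encoder, x, depth=5):
    # Run each stage in sequence, keeping every intermediate feature map;
    # for "xception" the channel counts follow the registered out_channels
    # (3, 64, 128, 256, 728, 2048) from xception_encoders above.
    features = []
    for stage in encoder.get_stages()[: depth + 1]:
        x = stage(x)
        features.append(x)
    return features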
segmentation_models_pytorch/segmentation_models_pytorch/fpn/__init__.py
CHANGED
@@ -1 +1 @@
-from .model import FPN
+from .model import FPN

(The removed and added lines are textually identical; the change is whitespace-only, most likely the end-of-file newline added by pre-commit.)
segmentation_models_pytorch/segmentation_models_pytorch/fpn/decoder.py
CHANGED
@@ -55,51 +55,63 @@ class MergeBlock(nn.Module):
         super().__init__()
         if policy not in ["add", "cat"]:
             raise ValueError(
-                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(
-                    policy
-                )
+                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
             )
         self.policy = policy
 
     def forward(self, x):
-        if self.policy == 'add':
+        if self.policy == "add":
             return sum(x)
-        elif self.policy == 'cat':
+        elif self.policy == "cat":
             return torch.cat(x, dim=1)
         else:
             raise ValueError(
-                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy)
+                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(
+                    self.policy
+                )
             )
 
 
 class FPNDecoder(nn.Module):
     def __init__(
-            self,
-            encoder_channels,
-            encoder_depth=5,
-            pyramid_channels=256,
-            segmentation_channels=128,
-            dropout=0.2,
-            merge_policy="add",
+        self,
+        encoder_channels,
+        encoder_depth=5,
+        pyramid_channels=256,
+        segmentation_channels=128,
+        dropout=0.2,
+        merge_policy="add",
     ):
         super().__init__()
 
-        self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4
+        self.out_channels = (
+            segmentation_channels
+            if merge_policy == "add"
+            else segmentation_channels * 4
+        )
         if encoder_depth < 3:
-            raise ValueError("Encoder depth for FPN decoder cannot be less than 3, got {}.".format(encoder_depth))
+            raise ValueError(
+                "Encoder depth for FPN decoder cannot be less than 3, got {}.".format(
+                    encoder_depth
+                )
+            )
 
         encoder_channels = encoder_channels[::-1]
-        encoder_channels = encoder_channels[:encoder_depth + 1]
+        encoder_channels = encoder_channels[: encoder_depth + 1]
 
         self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
         self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
         self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
         self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])
 
-        self.seg_blocks = nn.ModuleList([
-            SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples)
-            for n_upsamples in [3, 2, 1, 0]
-        ])
+        self.seg_blocks = nn.ModuleList(
+            [
+                SegmentationBlock(
+                    pyramid_channels, segmentation_channels, n_upsamples=n_upsamples
+                )
+                for n_upsamples in [3, 2, 1, 0]
+            ]
+        )
 
         self.merge = MergeBlock(merge_policy)
         self.dropout = nn.Dropout2d(p=dropout, inplace=True)
@@ -112,7 +124,9 @@ class FPNDecoder(nn.Module):
         p3 = self.p3(p4, c3)
         p2 = self.p2(p3, c2)
 
-        feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])]
+        feature_pyramid = [
+            seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])
+        ]
         x = self.merge(feature_pyramid)
         x = self.dropout(x)
 
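One behavioural detail worth spelling out from the reformatted __init__: the merge policy decides the decoder's out_channels. A self-contained sketch (plain PyTorch; the shapes are illustrative, not from the diff):

import torch

# Four pyramid levels, upsampled by the segmentation blocks to a common
# resolution, each carrying segmentation_channels (here 128) channels.
feature_pyramid = [torch.randn(1, 128, 64, 64) for _ in range(4)]

merged_add = sum(feature_pyramid)               # (1, 128, 64, 64) -> out_channels = 128
merged_cat = torch.cat(feature_pyramid, dim=1)  # (1, 512, 64, 64) -> out_channels = 128 * 4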
segmentation_models_pytorch/segmentation_models_pytorch/fpn/model.py
CHANGED
@@ -1,7 +1,8 @@
 from typing import Optional, Union
-from .decoder import FPNDecoder
-from ..base import SegmentationModel, SegmentationHead, ClassificationHead
+
+from ..base import ClassificationHead, SegmentationHead, SegmentationModel
 from ..encoders import get_encoder
+from .decoder import FPNDecoder
 
 
 class FPN(SegmentationModel):
@@ -10,11 +11,11 @@ class FPN(SegmentationModel):
     Args:
         encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
             to extract features of different spatial resolution
-        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
+        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
             two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
             with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
             Default is 5
-        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
         decoder_pyramid_channels: A number of convolution filters in Feature Pyramid of FPN_
         decoder_segmentation_channels: A number of convolution filters in segmentation blocks of FPN_
@@ -26,7 +27,7 @@ class FPN(SegmentationModel):
             Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
             Default is **None**
         upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
-        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
             on top of encoder if **aux_params** is not **None** (default). Supported params:
                 - classes (int): A number of classes
                 - pooling (str): One of "max", "avg". Default is "avg"
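The -/+ docstring pairs above are textually identical: the change is trailing whitespace stripped by pre-commit. Since the docstring enumerates the constructor arguments, here is a minimal usage sketch consistent with it (encoder choice, class count, and the upstream segmentation_models_pytorch import are illustrative assumptions):

import segmentation_models_pytorch as smp

# Build an FPN with the documented arguments; activation=None leaves raw logits.
model = smp.FPN(
    encoder_name="resnet34",
    encoder_depth=5,
    encoder_weights="imagenet",
    decoder_pyramid_channels=256,
    decoder_segmentation_channels=128,
    classes=2,
    activation=None,
)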
|