FatemehT committed · Commit 8e6512c · 1 Parent(s): a8ab7ac

style: run pre-commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. README.md +9 -0
  2. angioPyFunctions.py +70 -61
  3. normalize_k1.py +3 -1
  4. predict.py +30 -14
  5. segmentation_models_pytorch/.github/workflows/tests.yml +2 -2
  6. segmentation_models_pytorch/.gitignore +1 -1
  7. segmentation_models_pytorch/HALLOFFAME.md +30 -31
  8. segmentation_models_pytorch/README.md +8 -8
  9. segmentation_models_pytorch/__init__.py +1 -1
  10. segmentation_models_pytorch/docker/Dockerfile +1 -1
  11. segmentation_models_pytorch/docs/conf.py +35 -26
  12. segmentation_models_pytorch/docs/insights.rst +8 -8
  13. segmentation_models_pytorch/docs/install.rst +1 -1
  14. segmentation_models_pytorch/docs/losses.rst +1 -1
  15. segmentation_models_pytorch/docs/models.rst +0 -2
  16. segmentation_models_pytorch/docs/quickstart.rst +1 -1
  17. segmentation_models_pytorch/docs/requirements.txt +1 -1
  18. segmentation_models_pytorch/misc/generate_table.py +6 -2
  19. segmentation_models_pytorch/segmentation_models_pytorch/__init__.py +34 -23
  20. segmentation_models_pytorch/segmentation_models_pytorch/__version__.py +1 -1
  21. segmentation_models_pytorch/segmentation_models_pytorch/base/__init__.py +2 -11
  22. segmentation_models_pytorch/segmentation_models_pytorch/base/heads.py +20 -9
  23. segmentation_models_pytorch/segmentation_models_pytorch/base/initialization.py +0 -1
  24. segmentation_models_pytorch/segmentation_models_pytorch/base/model.py +1 -1
  25. segmentation_models_pytorch/segmentation_models_pytorch/base/modules.py +55 -34
  26. segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/__init__.py +1 -1
  27. segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/decoder.py +19 -14
  28. segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/model.py +40 -45
  29. segmentation_models_pytorch/segmentation_models_pytorch/efficientunetplusplus/decoder.py +96 -61
  30. segmentation_models_pytorch/segmentation_models_pytorch/efficientunetplusplus/model.py +21 -19
  31. segmentation_models_pytorch/segmentation_models_pytorch/encoders/__init__.py +23 -14
  32. segmentation_models_pytorch/segmentation_models_pytorch/encoders/_base.py +5 -4
  33. segmentation_models_pytorch/segmentation_models_pytorch/encoders/_preprocessing.py +0 -1
  34. segmentation_models_pytorch/segmentation_models_pytorch/encoders/densenet.py +20 -11
  35. segmentation_models_pytorch/segmentation_models_pytorch/encoders/dpn.py +7 -6
  36. segmentation_models_pytorch/segmentation_models_pytorch/encoders/efficientnet.py +8 -10
  37. segmentation_models_pytorch/segmentation_models_pytorch/encoders/inceptionresnetv2.py +8 -5
  38. segmentation_models_pytorch/segmentation_models_pytorch/encoders/inceptionv4.py +13 -9
  39. segmentation_models_pytorch/segmentation_models_pytorch/encoders/mobilenet.py +1 -2
  40. segmentation_models_pytorch/segmentation_models_pytorch/encoders/resnet.py +10 -13
  41. segmentation_models_pytorch/segmentation_models_pytorch/encoders/senet.py +2 -2
  42. segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_regnet.py +189 -178
  43. segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_res2net.py +82 -81
  44. segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_resnest.py +130 -129
  45. segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_sknet.py +57 -44
  46. segmentation_models_pytorch/segmentation_models_pytorch/encoders/vgg.py +5 -4
  47. segmentation_models_pytorch/segmentation_models_pytorch/encoders/xception.py +31 -18
  48. segmentation_models_pytorch/segmentation_models_pytorch/fpn/__init__.py +1 -1
  49. segmentation_models_pytorch/segmentation_models_pytorch/fpn/decoder.py +35 -21
  50. segmentation_models_pytorch/segmentation_models_pytorch/fpn/model.py +6 -5
README.md CHANGED
@@ -24,3 +24,12 @@ This software allows single arteries to be segmented given a few clicks on a sin
 ...a website should pop up in your browser!
 
 You need to create a /Dicom folder and put some angiography DICOMs in there
+
+
+# How to run the project
+## Create virtual environment and activate it
+```bash
+uv venv
+source .venv/bin/activate
+uv pip install -r requirements.txt
+```
angioPyFunctions.py CHANGED
@@ -1,90 +1,98 @@
 import os
+
 os.environ.setdefault("ASTROPY_SKIP_CONFIG_UPDATE", "1")
 
+import astropy.config.configuration as _astro_config
 import numpy
 import scipy.interpolate
-import skimage.filters
-import skimage.morphology
 import scipy.ndimage
 import scipy.optimize
-import predict
+import skimage.filters
+import skimage.morphology
 from PIL import Image
-import astropy.config.configuration as _astro_config
+
+import predict
 
 if not hasattr(_astro_config, "update_default_config"):
+
     def _noop_update_default_config(*args, **kwargs):
         return None
+
     _astro_config.update_default_config = _noop_update_default_config
 
-from fil_finder import FilFinder2D
 import astropy.units as u
-from tqdm import tqdm
-import pooch
-import utils.dataset
 import cv2
+import pooch
+from fil_finder import FilFinder2D
+from tqdm import tqdm
 
+import utils.dataset
 
 colourTableHex = {
-    'LAD': "#f03b20",
-    'D': "#fd8d3c",
-    'CX': "#31a354",
-    'OM': "#74c476",
-    'RCA': "#08519c",
-    'AM': "#3182bd",
-    'LM': "#984ea3",
-    }
+    "LAD": "#f03b20",
+    "D": "#fd8d3c",
+    "CX": "#31a354",
+    "OM": "#74c476",
+    "RCA": "#08519c",
+    "AM": "#3182bd",
+    "LM": "#984ea3",
+}
 
 colourTableList = {}
 
 for item in colourTableHex.keys():
     ### WARNING HACK: The colours go in backwards here for some reason perhaps related to RGBA?
-    colourTableList[item] = [int(colourTableHex[item][5:7], 16),
-                             int(colourTableHex[item][3:5], 16),
-                             int(colourTableHex[item][1:3], 16)]
+    colourTableList[item] = [
+        int(colourTableHex[item][5:7], 16),
+        int(colourTableHex[item][3:5], 16),
+        int(colourTableHex[item][1:3], 16),
+    ]
 
 
 def skeletonise(maskArray):
     # if len(maskArray.shape) == 3:
     maskArray = cv2.cvtColor(maskArray, cv2.COLOR_BGR2GRAY)
 
-    skeleton = skimage.morphology.skeletonize(maskArray.astype('bool'))
+    skeleton = skimage.morphology.skeletonize(maskArray.astype("bool"))
 
     # Process the skeleton and find the longest path
-    fil = FilFinder2D(skeleton.astype('uint8'),
-                      distance=250 * u.pc, mask=skeleton, beamwidth=10.0*u.pix)
+    fil = FilFinder2D(
+        skeleton.astype("uint8"),
+        distance=250 * u.pc,
+        mask=skeleton,
+        beamwidth=10.0 * u.pix,
+    )
     fil.preprocess_image(flatten_percent=85)
-    fil.create_mask(border_masking=True, verbose=False,
-                    use_existing_mask=True)
+    fil.create_mask(border_masking=True, verbose=False, use_existing_mask=True)
     fil.medskel(verbose=False)
-    fil.analyze_skeletons(branch_thresh=400 * u.pix,
-                          skel_thresh=10 * u.pix, prune_criteria='length')
+    fil.analyze_skeletons(
+        branch_thresh=400 * u.pix, skel_thresh=10 * u.pix, prune_criteria="length"
+    )
 
     # add image arrays dictionary
     # tifffile.imwrite(os.path.join(arteryFolder, "skel.tif"), fil.skeleton.astype('<u1')*255)
 
-    skel = fil.skeleton.astype('<u1')*255
+    skel = fil.skeleton.astype("<u1") * 255
 
     return skel
 
 
 def skelEndpoints(skel):
-    #skel[skel!=0] = 1
-    skel = numpy.uint8(skel>0)
+    # skel[skel!=0] = 1
+    skel = numpy.uint8(skel > 0)
 
     # Apply the convolution.
-    kernel = numpy.uint8([[1, 1, 1],
-                          [1, 10, 1],
-                          [1, 1, 1]])
+    kernel = numpy.uint8([[1, 1, 1], [1, 10, 1], [1, 1, 1]])
     src_depth = -1
-    filtered = cv2.filter2D(skel,src_depth,kernel)
+    filtered = cv2.filter2D(skel, src_depth, kernel)
 
     # Look through to find the value of 11.
     # This returns a mask of the endpoints, but if you
     # just want the coordinates, you could simply
     # return np.where(filtered==11)
     out = numpy.zeros_like(skel)
-    out[numpy.where(filtered==11)] = 1
-    endCoords = numpy.where(filtered==11)
+    out[numpy.where(filtered == 11)] = 1
+    endCoords = numpy.where(filtered == 11)
     endCoords = list(zip(*endCoords))
     startPoint = endCoords[0]
     endPoint = endCoords[1]
@@ -109,16 +117,15 @@ def skelPointsInOrder(skel, startPoint=None):
     skelLength = len(skelPoints)
 
     # Loop through the skeleton starting with startPoint, deleting the starting point from the skelPoints list, and finding the closest pixel. This is appended to orderedPoints. startPoint now becomes the last point to be appended.
-    startPointCopy = startPoint # copied as we are going to loop and overwrite, but want to also keep the original startPoint
+    startPointCopy = startPoint  # copied as we are going to loop and overwrite, but want to also keep the original startPoint
     orderedPoints = []
 
     while len(skelPoints) > 1:
-
         skelPoints.remove(startPointCopy)
 
         # Calculate the point that is closest to the start point
-        diffs = numpy.abs(numpy.array(skelPoints)-numpy.array(startPointCopy))
-        dists = numpy.sum(diffs,axis=1) #l1-distance
+        diffs = numpy.abs(numpy.array(skelPoints) - numpy.array(startPointCopy))
+        dists = numpy.sum(diffs, axis=1)  # l1-distance
        closest_point_index = numpy.argmin(dists)
        closestPoint = skelPoints[closest_point_index]
        orderedPoints.append(closestPoint)
@@ -145,7 +152,7 @@ def skelSplinerWithThickness(skel, EDT, smoothing=50, order=3, decimation=2):
     x = x[::decimation]
     y = y[::decimation]
 
-    #NOTE: Should the EDT be median filtered? I wonder in fact if doing so will reduce the accuracy of the model.
+    # NOTE: Should the EDT be median filtered? I wonder in fact if doing so will reduce the accuracy of the model.
     # EDT = skimage.filters.median(EDT)
 
     t = EDT[y, x]
@@ -156,8 +163,7 @@
 
     print(x.shape, y.shape, t.shape)
 
-    tcko, uo = scipy.interpolate.splprep(
-        [y, x, t], s=smoothing, k=order, per=False)
+    tcko, uo = scipy.interpolate.splprep([y, x, t], s=smoothing, k=order, per=False)
 
     return tcko
 
@@ -192,8 +198,12 @@ def arterySegmentation(inputImage, groundTruthPoints, segmentationModelWeights=N
     )
 
     if inputImage.shape[0] != 512 and inputImage.shape[1] != 512:
-        ratioYX = numpy.array([512./inputImage.shape[0], 512./inputImage.shape[1]])
-        print(f"arterySegmentation(): Rescaling image to 512x512 by {ratioYX=}, and also applying this to input points")
+        ratioYX = numpy.array(
+            [512.0 / inputImage.shape[0], 512.0 / inputImage.shape[1]]
+        )
+        print(
+            f"arterySegmentation(): Rescaling image to 512x512 by {ratioYX=}, and also applying this to input points"
+        )
         inputImage = scipy.ndimage.zoom(inputImage, ratioYX)
         points = groundTruthPoints.copy() * ratioYX
         print(inputImage.shape)
@@ -202,33 +212,32 @@ def arterySegmentation(inputImage, groundTruthPoints, segmentationModelWeights=N
 
     imageSize = inputImage.shape
 
-    n_classes = 2 # binary output
+    n_classes = 2  # binary output
 
     net = predict.smp.Unet(
-        encoder_name='inceptionresnetv2',
+        encoder_name="inceptionresnetv2",
         encoder_weights="imagenet",
         in_channels=3,
-        classes=n_classes
+        classes=n_classes,
     )
 
     net = predict.nn.DataParallel(net)
 
-    device = predict.torch.device('cuda' if predict.torch.cuda.is_available() else 'cpu')
+    device = predict.torch.device(
+        "cuda" if predict.torch.cuda.is_available() else "cpu"
+    )
     net.to(device=device)
 
     net.load_state_dict(
-        predict.torch.load(
-            segmentationModelWeights,
-            map_location=device
-        )
+        predict.torch.load(segmentationModelWeights, map_location=device)
     )
 
     orig_image = Image.fromarray(inputImage)
 
-    image = predict.Image.new('RGB', imageSize, (0, 0, 0))
+    image = predict.Image.new("RGB", imageSize, (0, 0, 0))
     image.paste(orig_image, (0, 0))
 
-    imageArray = numpy.array(image).astype('uint8')
+    imageArray = numpy.array(image).astype("uint8")
 
     # Clear last channels
     imageArray[:, :, -1] = 0
@@ -242,13 +251,13 @@ def arterySegmentation(inputImage, groundTruthPoints, segmentationModelWeights=N
     for y, x in [startPoint, endPoint]:
         y = int(numpy.round(y))
         x = int(numpy.round(x))
-        imageArray[y-2:y+2, x-2:x+2, 1] = 255
+        imageArray[y - 2 : y + 2, x - 2 : x + 2, 1] = 255
 
     # All other points on Channel 2
     for y, x in points[1:-1]:
         y = int(numpy.round(y))
         x = int(numpy.round(x))
-        imageArray[y-2:y+ 2, x-2:x+2, 2] = 255
+        imageArray[y - 2 : y + 2, x - 2 : x + 2, 2] = 255
 
     image = Image.fromarray(imageArray.astype(numpy.uint8))
 
@@ -257,19 +266,19 @@ def arterySegmentation(inputImage, groundTruthPoints, segmentationModelWeights=N
         dataset_class=utils.dataset.CoronaryDataset,
         full_img=image,
         scale_factor=1,
-        device=device
+        device=device,
     )
 
     return mask
 
 
-
 def maskOutliner(labelledArtery, outlineThickness=3):
-
     # Compute the boundary of the mask
-    contours, _ = cv2.findContours(labelledArtery, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+    contours, _ = cv2.findContours(
+        labelledArtery, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
+    )
     tmp = numpy.zeros_like(labelledArtery)
-    boundary = cv2.drawContours(tmp, contours, -1, (255,255,255), outlineThickness)
+    boundary = cv2.drawContours(tmp, contours, -1, (255, 255, 255), outlineThickness)
    boundary = boundary > 0

    return boundary
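For orientation, a minimal usage sketch of the centreline helpers reformatted above (not part of this commit). The synthetic mask, the EDT construction and the `skelEndpoints()` return signature are assumptions based only on the code shown in this diff.

```python
# Sketch: chain the angioPyFunctions helpers on a synthetic vessel-like mask.
import cv2
import numpy
import scipy.ndimage

import angioPyFunctions

# Synthetic 3-channel mask containing one thick, curved "vessel"
mask = numpy.zeros((512, 512, 3), dtype="uint8")
pts = numpy.array([[60, 60], [200, 150], [300, 350], [450, 420]], dtype=numpy.int32)
cv2.polylines(mask, [pts.reshape(-1, 1, 2)], False, (255, 255, 255), 9)

skel = angioPyFunctions.skeletonise(mask)  # pruned skeleton image, values {0, 255}
startPoint, endPoint = angioPyFunctions.skelEndpoints(skel)  # assumed (y, x) endpoint pair

# Assumption: the Euclidean distance transform of the mask is used as the thickness map
EDT = scipy.ndimage.distance_transform_edt(mask[:, :, 0] > 0)
tck = angioPyFunctions.skelSplinerWithThickness(skel, EDT)  # spline over (y, x, thickness)
```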
normalize_k1.py CHANGED
@@ -17,7 +17,9 @@ def normalize_image(
     img = img.resize(target_size, Image.Resampling.BICUBIC)
 
     arr = np.array(img, dtype=np.float32)
-    arr = exposure.rescale_intensity(arr, in_range="image", out_range=(png_low, png_high))
+    arr = exposure.rescale_intensity(
+        arr, in_range="image", out_range=(png_low, png_high)
+    )
     arr = np.clip(arr, png_low, png_high)
     arr = ((arr - png_low) / (png_high - png_low) * 255.0).astype(np.uint8)
 
predict.py CHANGED
@@ -5,17 +5,17 @@ import os
 import torch
 import torch.nn as nn
 from PIL import Image
+from torch.backends import cudnn
 from torchvision import transforms
 
-from utils.dataset import CoronaryDataset
 import segmentation_models_pytorch.segmentation_models_pytorch as smp
+from utils.dataset import CoronaryDataset
 
-from torch.backends import cudnn
-
-'''
+"""
 This uses a pytorch coronary segmentation model (EfficientNetPLusPlus) that has been trained using a freely available dataset of labelled coronary angiograms from: http://personal.cimat.mx:8181/~ivan.cruz/DB_Angiograms.html
-The input is a raw angiogram image, and the output is a segmentation mask of all the arteries. This output will be used as the 'first guess' to speed up artery annotation.
-'''
+The input is a raw angiogram image, and the output is a segmentation mask of all the arteries. This output will be used as the 'first guess' to speed up artery annotation.
+"""
+
 
 def predict_img(net, dataset_class, full_img, device, scale_factor=1, n_classes=3):
     # NOTE n_classes is the number of possible values that can be predicted for a given pixel. In a standard binary segmentation task, this will be 2 i.e. black or white
@@ -41,11 +41,11 @@ def predict_img(net, dataset_class, full_img, device, scale_factor=1, n_classes=
         [
             transforms.ToPILImage(),
             transforms.Resize(full_img.size[1]),
-            transforms.ToTensor()
+            transforms.ToTensor(),
         ]
     )
 
-    full_mask = tf(probs.cpu())
+    full_mask = tf(probs.cpu())
 
     if n_classes > 1:
         return dataset_class.one_hot2mask(full_mask)
@@ -54,12 +54,28 @@ def predict_img(net, dataset_class, full_img, device, scale_factor=1, n_classes=
 
 
 def get_args():
-    parser = argparse.ArgumentParser(description='Predict masks from input images', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser = argparse.ArgumentParser(
+        description="Predict masks from input images",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
     # parser.add_argument('-d', '--dataset', type=str, help='Specifies the dataset to be used', dest='dataset', required=True)
-    parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE', help="Specify the file in which the model is stored")
-    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='filenames of input images', required=True)
-    parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', help='Filenames of output images')
+    parser.add_argument(
+        "--model",
+        "-m",
+        default="MODEL.pth",
+        metavar="FILE",
+        help="Specify the file in which the model is stored",
+    )
+    parser.add_argument(
+        "--input",
+        "-i",
+        metavar="INPUT",
+        nargs="+",
+        help="filenames of input images",
+        required=True,
+    )
+    parser.add_argument(
+        "--output", "-o", metavar="INPUT", nargs="+", help="Filenames of output images"
+    )
 
     return parser.parse_args()
-
-
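For reference, a sketch of driving `predict_img()` the same way `angioPyFunctions.arterySegmentation()` does (not part of this commit); the checkpoint path and input image path are placeholders.

```python
# Sketch: load the vendored SMP Unet and run a single-frame prediction.
import torch
import torch.nn as nn
from PIL import Image

import predict
import segmentation_models_pytorch.segmentation_models_pytorch as smp
from utils.dataset import CoronaryDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = smp.Unet(
    encoder_name="inceptionresnetv2",
    encoder_weights="imagenet",
    in_channels=3,
    classes=2,
)
net = nn.DataParallel(net)
net.to(device=device)
net.load_state_dict(torch.load("MODEL.pth", map_location=device))  # placeholder checkpoint

img = Image.open("angiogram.png").convert("RGB")  # placeholder input frame

mask = predict.predict_img(
    net,
    dataset_class=CoronaryDataset,
    full_img=img,
    device=device,
    scale_factor=1,
)
```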
 
 
segmentation_models_pytorch/.github/workflows/tests.yml CHANGED
@@ -17,12 +17,12 @@ jobs:
 
     steps:
     - uses: actions/checkout@v2
-
+
     - name: Set up Python ${{ matrix.python-version }}
       uses: actions/setup-python@v2
       with:
         python-version: 3.6
-
+
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
segmentation_models_pytorch/.gitignore CHANGED
@@ -102,4 +102,4 @@ venv.bak/
 /site
 
 # mypy
-.mypy_cache/
+.mypy_cache/
segmentation_models_pytorch/HALLOFFAME.md CHANGED
@@ -5,7 +5,7 @@ Here you can find competitions, names of the winners and links to their solution
 
 Please, follow these rules, when adding a solution to the "Hall of Fame":
 
-1. Solution should be high rated (e.g. for Kaggle gold or silver medal)
+1. Solution should be high rated (e.g. for Kaggle gold or silver medal)
 2. There should be a description of the solution (post at the forum / code / blog post / paper / pre-print)
 
 
@@ -13,78 +13,77 @@ Please, follow these rules, when adding a solution to the "Hall of Fame":
 
 ### [Severstal: Steel Defect Detection](https://www.kaggle.com/c/severstal-steel-defect-detection)
 
-- 1st place.
-[Wuxi Jiangsu](https://www.kaggle.com/rguo97),
-[Hongbo Zhu](https://www.kaggle.com/zhuhongbo),
-[Yizhuo Yu](https://www.kaggle.com/paffpaffyu)
+- 1st place.
+[Wuxi Jiangsu](https://www.kaggle.com/rguo97),
+[Hongbo Zhu](https://www.kaggle.com/zhuhongbo),
+[Yizhuo Yu](https://www.kaggle.com/paffpaffyu)
 [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254#latest-675874)]
 
-- 5th place.
-[Guanshuo Xu](https://www.kaggle.com/wowfattie)
+- 5th place.
+[Guanshuo Xu](https://www.kaggle.com/wowfattie)
 [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/117208#latest-675385)]
 
-- 9th place.
-[Jacek Poplawski](https://www.linkedin.com/in/jacekpoplawski/)
+- 9th place.
+[Jacek Poplawski](https://www.linkedin.com/in/jacekpoplawski/)
 [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114297#latest-660842)]
 
 - 10th place.
-[Alexey Rozhkov](https://www.linkedin.com/in/alexisrozhkov)
+[Alexey Rozhkov](https://www.linkedin.com/in/alexisrozhkov)
 [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114465#latest-659615)]
 
-- 12th place.
-[Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/),
-[Ilya Dobrynin](https://www.linkedin.com/in/ilya-dobrynin-79a89b106/),
-[Denis Kolpakov](https://www.linkedin.com/in/denis-kolpakov-ab3137197/)
+- 12th place.
+[Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/),
+[Ilya Dobrynin](https://www.linkedin.com/in/ilya-dobrynin-79a89b106/),
+[Denis Kolpakov](https://www.linkedin.com/in/denis-kolpakov-ab3137197/)
 [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114309#latest-661404)]
 
-- 31st place.
-[Insaf Ashrapov](https://www.linkedin.com/in/iashrapov/),
-[Igor Krashenyi](https://www.linkedin.com/in/igor-krashenyi-38b89b98),
-[Pavel Pleskov](https://www.linkedin.com/in/ppleskov),
-[Anton Zakharenkov](https://www.linkedin.com/in/anton-zakharenkov/),
-[Nikolai Popov](https://www.linkedin.com/in/nikolai-popov-b2157370/)
+- 31st place.
+[Insaf Ashrapov](https://www.linkedin.com/in/iashrapov/),
+[Igor Krashenyi](https://www.linkedin.com/in/igor-krashenyi-38b89b98),
+[Pavel Pleskov](https://www.linkedin.com/in/ppleskov),
+[Anton Zakharenkov](https://www.linkedin.com/in/anton-zakharenkov/),
+[Nikolai Popov](https://www.linkedin.com/in/nikolai-popov-b2157370/)
 [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114383#latest-658438)]
 [[code](https://github.com/Diyago/Severstal-Steel-Defect-Detection)]
 
-- 55th place.
-[Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/)
+- 55th place.
+[Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/)
 [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114410#latest-672682)]
 [[code](https://github.com/khornlund/severstal-steel-defect-detection)]
 
 - Efficiency round 1st place.
-[Stefan Stefanov](https://www.linkedin.com/in/stefan-stefanov-63a77b1)
+[Stefan Stefanov](https://www.linkedin.com/in/stefan-stefanov-63a77b1)
 [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/117486#latest-674229)]
 
 
 ### [Understanding Clouds from Satellite Images](https://www.kaggle.com/c/understanding_cloud_organization)
 
 - 2nd place.
-[Andrey Kiryasov](https://www.kaggle.com/ekydna)
+[Andrey Kiryasov](https://www.kaggle.com/ekydna)
 [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118255#latest-678189)]
 
 - 4th place.
-[Ching-Loong Seow](https://www.linkedin.com/in/clseow/)
+[Ching-Loong Seow](https://www.linkedin.com/in/clseow/)
 [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118016#latest-677333)]
 
 - 34th place.
-[Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/)
+[Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/)
 [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118250#latest-678176)]
 [[code](https://github.com/khornlund/understanding-cloud-organization)]
 
 - 55th place.
-[Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/)
+[Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/)
 [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118019#latest-678626)]
 
 ## Other platforms
 
-### [MICCAI 2020 TN-SCUI challenge](https://tn-scui2020.grand-challenge.org/Home/)
+### [MICCAI 2020 TN-SCUI challenge](https://tn-scui2020.grand-challenge.org/Home/)
 - 1st place.
-[Mingyu Wang](https://github.com/WAMAWAMA)
+[Mingyu Wang](https://github.com/WAMAWAMA)
 [[description](https://github.com/WAMAWAMA/TNSCUI2020-Seg-Rank1st)]
 [[code](https://github.com/WAMAWAMA/TNSCUI2020-Seg-Rank1st)]
 
 ### [Open Cities AI Challenge: Segmenting Buildings for Disaster Resilience](https://www.drivendata.org/competitions/60/building-segmentation-disaster-resilience/)
 - 1st place.
-[Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/).
+[Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/).
 [[code and description](https://github.com/qubvel/open-cities-challenge)]
-
segmentation_models_pytorch/README.md CHANGED
@@ -1,8 +1,8 @@
 <div align="center">
-
-![logo](https://i.ibb.co/dc1XdhT/Segmentation-Models-V2-Side-1-1.png)
-**Python library with Neural Networks for Image
-Segmentation based on [PyTorch](https://pytorch.org/).**
+
+![logo](https://i.ibb.co/dc1XdhT/Segmentation-Models-V2-Side-1-1.png)
+**Python library with Neural Networks for Image
+Segmentation based on [PyTorch](https://pytorch.org/).**
 
 [![Documentation Status](https://readthedocs.org/projects/smp/badge/?version=latest)](https://segmentation-models-pytorch.readthedocs.io/en/latest/?badge=latest) <br> [![Generic badge](https://img.shields.io/badge/License-MIT-<COLOR>.svg)](https://shields.io/)
 
@@ -14,7 +14,7 @@ The main features of this library are:
 - 12 models architectures for binary and multi class segmentation (including legendary Unet)
 - 104 available encoders
 - All encoders have pre-trained weights for faster and better convergence
-
+
 ### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
 
 Visit [Read The Docs Project Page](https://segmentation-models-pytorch.readthedocs.io/en/latest/) or read following README to know more about Segmentation Models Pytorch (SMP for short) library
@@ -346,11 +346,11 @@ model = smp.FPN('resnet34', in_channels=1)
 mask = model(torch.ones([1, 1, 64, 64]))
 ```
 
-##### Auxiliary classification output
-All models support `aux_params` parameters, which is default set to `None`.
+##### Auxiliary classification output
+All models support `aux_params` parameters, which is default set to `None`.
 If `aux_params = None` then classification auxiliary output is not created, else
 model produce not only `mask`, but also `label` output with shape `NC`.
-Classification head consists of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be
+Classification head consists of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be
 configured by `aux_params` as follows:
 ```python
 aux_params=dict(
segmentation_models_pytorch/__init__.py CHANGED
@@ -1 +1 @@
-from segmentation_models_pytorch import *
+from segmentation_models_pytorch import *
segmentation_models_pytorch/docker/Dockerfile CHANGED
@@ -1,3 +1,3 @@
 FROM anibali/pytorch:cuda-9.0
 
-RUN pip install segmentation-models-pytorch
+RUN pip install segmentation-models-pytorch
segmentation_models_pytorch/docs/conf.py CHANGED
@@ -14,24 +14,28 @@
 # import sys
 # sys.path.insert(0, os.path.abspath('.'))
 
+import datetime
 import os
 import re
 import sys
-import datetime
-sys.path.append('..')
+
+sys.path.append("..")
 
 # -- Project information -----------------------------------------------------
 
-project = 'Segmentation Models'
-copyright = '{}, Pavel Yakubovskiy'.format(datetime.datetime.now().year)
-author = 'Pavel Yakubovskiy'
+project = "Segmentation Models"
+copyright = "{}, Pavel Yakubovskiy".format(datetime.datetime.now().year)
+author = "Pavel Yakubovskiy"
+
 
 def get_version():
-    sys.path.append('../segmentation_models_pytorch')
+    sys.path.append("../segmentation_models_pytorch")
     from __version__ import __version__ as version
+
     sys.path.pop(-1)
     return version
 
+
 version = get_version()
 
 # -- General configuration ---------------------------------------------------
@@ -41,15 +45,15 @@ version = get_version()
 # ones.
 
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.coverage',
-    'sphinx.ext.napoleon',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.mathjax',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.coverage",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.mathjax",
 ]
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
@@ -64,12 +68,14 @@ exclude_patterns = []
 #
 
 import sphinx_rtd_theme
+
 html_theme = "sphinx_rtd_theme"
 html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
 
 # import karma_sphinx_theme
 # html_theme = "karma_sphinx_theme"
 import faculty_sphinx_theme
+
 html_theme = "faculty_sphinx_theme"
 
 # import catalyst_sphinx_theme
@@ -81,7 +87,7 @@ html_logo = "logo.png"
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
 
 # -- Extension configuration -------------------------------------------------
 
@@ -91,30 +97,33 @@ napoleon_include_init_with_doc = True
 napoleon_numpy_docstring = False
 
 autodoc_mock_imports = [
-    'torch',
-    'tqdm',
-    'numpy',
-    'timm',
-    'pretrainedmodels',
-    'torchvision',
-    'efficientnet-pytorch',
-    'segmentation_models_pytorch.encoders',
-    'segmentation_models_pytorch.utils',
+    "torch",
+    "tqdm",
+    "numpy",
+    "timm",
+    "pretrainedmodels",
+    "torchvision",
+    "efficientnet-pytorch",
+    "segmentation_models_pytorch.encoders",
+    "segmentation_models_pytorch.utils",
     # 'segmentation_models_pytorch.base',
 ]
 
-autoclass_content = 'both'
-autodoc_typehints = 'description'
+autoclass_content = "both"
+autodoc_typehints = "description"
 
 # --- Work around to make autoclass signatures not (*args, **kwargs) ----------
 
-class FakeSignature():
+
+class FakeSignature:
     def __getattribute__(self, *args):
         raise ValueError
 
+
 def f(app, obj, bound_method):
     if "__new__" in obj.__name__:
         obj.__signature__ = FakeSignature()
 
+
 def setup(app):
-    app.connect('autodoc-before-process-signature', f)
+    app.connect("autodoc-before-process-signature", f)
segmentation_models_pytorch/docs/insights.rst CHANGED
@@ -21,20 +21,20 @@ Each encoder should have following attributes and methods and be inherited from
 .. code-block:: python
 
     class MyEncoder(torch.nn.Module, EncoderMixin):
-
+
         def __init__(self, **kwargs):
             super().__init__()
-
+
             # A number of channels for each encoder feature tensor, list of integers
             self._out_channels: List[int] = [3, 16, 64, 128, 256, 512]
 
             # A number of stages in decoder (in other words number of downsampling operations), integer
             # use in in forward pass to reduce number of returning features
-            self._depth: int = 5
+            self._depth: int = 5
 
             # Default number of input channels in first Conv2d layer for encoder (usually 3)
-            self._in_channels: int = 3
-
+            self._in_channels: int = 3
+
             # Define encoder modules below
             ...
 
@@ -90,12 +90,12 @@ For better understanding see more examples of encoder in smp.encoders module.
 3. Aux classification output
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-All models support ``aux_params`` parameter, which is default set to ``None``.
+All models support ``aux_params`` parameter, which is default set to ``None``.
 If ``aux_params = None`` than classification auxiliary output is not created, else
 model produce not only ``mask``, but also ``label`` output with shape ``(N, C)``.
 
 Classification head consist of following layers:
-
+
 1. GlobalPooling
 2. Dropout (optional)
 3. Linear
@@ -104,7 +104,7 @@ Classification head consist of following layers:
 Example:
 
 .. code-block:: python
-
+
     aux_params=dict(
         pooling='avg',             # one of 'avg', 'max'
         dropout=0.5,               # dropout ratio, default is None
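To make the custom-encoder contract from `docs/insights.rst` concrete, a minimal sketch (not part of this commit). It assumes `EncoderMixin` lives in `encoders/_base.py` as in upstream SMP; the conv stages are purely illustrative.

```python
# Sketch: an encoder that exposes the attributes SMP expects and returns one
# feature map per stage (resolutions 1, 1/2, 1/4, ... of the input).
from typing import List

import torch
import torch.nn as nn

from segmentation_models_pytorch.segmentation_models_pytorch.encoders._base import EncoderMixin


class MyEncoder(nn.Module, EncoderMixin):
    def __init__(self, **kwargs):
        super().__init__()
        self._out_channels: List[int] = [3, 16, 64, 128, 256, 512]
        self._depth: int = 5
        self._in_channels: int = 3
        # Illustrative downsampling stages; a real encoder defines its own blocks
        self.stages = nn.ModuleList(
            nn.Conv2d(self._out_channels[i], self._out_channels[i + 1], 3, stride=2, padding=1)
            for i in range(self._depth)
        )

    def forward(self, x):
        features = [x]
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        return features
```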
segmentation_models_pytorch/docs/install.rst CHANGED
@@ -5,4 +5,4 @@ Latest version from source:
 
 .. code-block:: bash
 
-    $ pip install -U git+https://github.com/jlcsilva/segmentation_models.pytorch
+    $ pip install -U git+https://github.com/jlcsilva/segmentation_models.pytorch
segmentation_models_pytorch/docs/losses.rst CHANGED
@@ -1,7 +1,7 @@
 📉 Losses
 =========
 
-Collection of popular semantic segmentation losses. Adapted from
+Collection of popular semantic segmentation losses. Adapted from
 an awesome repo with pytorch utils https://github.com/BloodAxe/pytorch-toolbelt
 
 Constants
segmentation_models_pytorch/docs/models.rst CHANGED
@@ -48,5 +48,3 @@ DeepLabV3
 DeepLabV3+
 ~~~~~~~~~~
 .. autoclass:: segmentation_models_pytorch.DeepLabV3Plus
-
-
segmentation_models_pytorch/docs/quickstart.rst CHANGED
@@ -6,7 +6,7 @@
 Segmentation model is just a PyTorch nn.Module, which can be created as easy as:
 
 .. code-block:: python
-
+
     import segmentation_models_pytorch as smp
 
     model = smp.Unet(
segmentation_models_pytorch/docs/requirements.txt CHANGED
@@ -1,2 +1,2 @@
 faculty-sphinx-theme==0.2.2
-six==1.15.0
+six==1.15.0
segmentation_models_pytorch/misc/generate_table.py CHANGED
@@ -10,11 +10,15 @@ COLUMNS = [
     "Params, M",
 ]
 
+
 def wrap_row(r):
     return "|{}|".format(r)
 
-header = "|".join([column.ljust(WIDTH, ' ') for column in COLUMNS])
-separator = "|".join(["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1))
+
+header = "|".join([column.ljust(WIDTH, " ") for column in COLUMNS])
+separator = "|".join(
+    ["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1)
+)
 
 print(wrap_row(header))
 print(wrap_row(separator))
segmentation_models_pytorch/segmentation_models_pytorch/__init__.py CHANGED
@@ -1,23 +1,20 @@
-from .unet import Unet
-from .unetplusplus import UnetPlusPlus
-from .manet import MAnet
-from .linknet import Linknet
-from .fpn import FPN
-from .pspnet import PSPNet
+from typing import Optional
+
+import torch
+
+from . import encoders, losses, utils
+from .__version__ import __version__
 from .deeplabv3 import DeepLabV3, DeepLabV3Plus
+from .efficientunetplusplus import EfficientUnetPlusPlus
+from .fpn import FPN
+from .linknet import Linknet
+from .manet import MAnet
 from .pan import PAN
+from .pspnet import PSPNet
 from .resunet import ResUnet
 from .resunetplusplus import ResUnetPlusPlus
-from .efficientunetplusplus import EfficientUnetPlusPlus
-
-from . import encoders
-from . import utils
-from . import losses
-
-from .__version__ import __version__
-
-from typing import Optional
-import torch
+from .unet import Unet
+from .unetplusplus import UnetPlusPlus
 
 
 def create_model(
@@ -28,18 +25,32 @@ def create_model(
     classes: int = 1,
     **kwargs,
 ) -> torch.nn.Module:
-    """Models wrapper. Allows to create any model just with parametes
-
-    """
+    """Models wrapper. Allows to create any model just with parametes"""
 
-    archs = [Unet, UnetPlusPlus, MAnet, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN, ResUnet, EfficientUnetPlusPlus, ResUnetPlusPlus]
+    archs = [
+        Unet,
+        UnetPlusPlus,
+        MAnet,
+        Linknet,
+        FPN,
+        PSPNet,
+        DeepLabV3,
+        DeepLabV3Plus,
+        PAN,
+        ResUnet,
+        EfficientUnetPlusPlus,
+        ResUnetPlusPlus,
+    ]
     archs_dict = {a.__name__.lower(): a for a in archs}
     try:
         model_class = archs_dict[arch.lower()]
     except KeyError:
-        raise KeyError("Wrong architecture type `{}`. Avalibale options are: {}".format(
-            arch, list(archs_dict.keys()),
-        ))
+        raise KeyError(
+            "Wrong architecture type `{}`. Avalibale options are: {}".format(
+                arch,
+                list(archs_dict.keys()),
+            )
+        )
     return model_class(
         encoder_name=encoder_name,
         encoder_weights=encoder_weights,
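A quick, hedged usage sketch for the reorganised `create_model()` wrapper (not part of this commit); `unet`, `resnet34` and `imagenet` are just the usual SMP example values.

```python
# Sketch: build a model by architecture name and run a dummy forward pass.
import torch

import segmentation_models_pytorch.segmentation_models_pytorch as smp

model = smp.create_model(
    arch="unet",              # any archs_dict key, matched case-insensitively
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=2,
)

with torch.no_grad():
    out = model(torch.ones([1, 3, 64, 64]))
print(out.shape)  # torch.Size([1, 2, 64, 64])
```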
segmentation_models_pytorch/segmentation_models_pytorch/__version__.py CHANGED
@@ -1,3 +1,3 @@
 VERSION = (0, 1, 3)
 
-__version__ = '.'.join(map(str, VERSION))
+__version__ = ".".join(map(str, VERSION))
segmentation_models_pytorch/segmentation_models_pytorch/base/__init__.py CHANGED
@@ -1,12 +1,3 @@
+from .heads import ClassificationHead, SegmentationHead
 from .model import SegmentationModel
-
-from .modules import (
-    PreActivatedConv2dReLU,
-    Conv2dReLU,
-    Attention,
-)
-
-from .heads import (
-    SegmentationHead,
-    ClassificationHead,
-)
+from .modules import Attention, Conv2dReLU, PreActivatedConv2dReLU
segmentation_models_pytorch/segmentation_models_pytorch/base/heads.py CHANGED
@@ -1,22 +1,33 @@
 import torch.nn as nn
-from .modules import Flatten, Activation
+
+from .modules import Activation, Flatten
 
 
-class SegmentationHead(nn.Sequential):
-
-    def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
-        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
-        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
+class SegmentationHead(nn.Sequential):
+    def __init__(
+        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
+    ):
+        conv2d = nn.Conv2d(
+            in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
+        )
+        upsampling = (
+            nn.UpsamplingBilinear2d(scale_factor=upsampling)
+            if upsampling > 1
+            else nn.Identity()
+        )
         activation = Activation(activation)
         super().__init__(conv2d, upsampling, activation)
 
 
 class ClassificationHead(nn.Sequential):
-
-    def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
+    def __init__(
+        self, in_channels, classes, pooling="avg", dropout=0.2, activation=None
+    ):
         if pooling not in ("max", "avg"):
-            raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))
-        pool = nn.AdaptiveAvgPool2d(1) if pooling == 'avg' else nn.AdaptiveMaxPool2d(1)
+            raise ValueError(
+                "Pooling should be one of ('max', 'avg'), got {}.".format(pooling)
+            )
+        pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1)
         flatten = Flatten()
         dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
         linear = nn.Linear(in_channels, classes, bias=True)
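For context, a small sketch exercising the two heads reformatted above in isolation (not part of this commit); the channel sizes are arbitrary examples.

```python
# Sketch: run dummy feature maps through SegmentationHead and ClassificationHead.
import torch

from segmentation_models_pytorch.segmentation_models_pytorch.base import (
    ClassificationHead,
    SegmentationHead,
)

decoder_output = torch.randn(2, 16, 128, 128)   # N, C, H, W decoder feature map
encoder_output = torch.randn(2, 512, 4, 4)      # deepest encoder feature map

seg_head = SegmentationHead(in_channels=16, out_channels=1, kernel_size=3, upsampling=2)
cls_head = ClassificationHead(in_channels=512, classes=3, pooling="avg", dropout=0.2)

print(seg_head(decoder_output).shape)  # torch.Size([2, 1, 256, 256])
print(cls_head(encoder_output).shape)  # torch.Size([2, 3])
```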
segmentation_models_pytorch/segmentation_models_pytorch/base/initialization.py CHANGED
@@ -3,7 +3,6 @@ import torch.nn as nn
 
 def initialize_decoder(module):
     for m in module.modules():
-
         if isinstance(m, nn.Conv2d):
             nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
             if m.bias is not None:
segmentation_models_pytorch/segmentation_models_pytorch/base/model.py CHANGED
@@ -1,9 +1,9 @@
 import torch
+
 from . import initialization as init
 
 
 class SegmentationModel(torch.nn.Module):
-
     def initialize(self):
         init.initialize_decoder(self.decoder)
         init.initialize_head(self.segmentation_head)
segmentation_models_pytorch/segmentation_models_pytorch/base/modules.py CHANGED
@@ -6,22 +6,23 @@ try:
 except ImportError:
     InPlaceABN = None
 
+
 class PreActivatedConv2dReLU(nn.Sequential):
     """
-    Pre-activated 2D convolution, as proposed in https://arxiv.org/pdf/1603.05027.pdf. Feature maps are processed by a normalization layer,
+    Pre-activated 2D convolution, as proposed in https://arxiv.org/pdf/1603.05027.pdf. Feature maps are processed by a normalization layer,
     followed by a ReLU activation and a 3x3 convolution.
     normalization
     """
+
     def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        padding=0,
-        stride=1,
-        use_batchnorm=True,
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        padding=0,
+        stride=1,
+        use_batchnorm=True,
     ):
-
         if use_batchnorm == "inplace" and InPlaceABN is None:
             raise RuntimeError(
                 "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
@@ -47,20 +48,21 @@ class PreActivatedConv2dReLU(nn.Sequential):
         )
         super(PreActivatedConv2dReLU, self).__init__(conv, bn, relu)
 
+
 class Conv2dReLU(nn.Sequential):
     """
     Block composed of a 3x3 convolution followed by a normalization layer and ReLU activation.
     """
     def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        padding=0,
-        stride=1,
-        use_batchnorm=True,
     ):
-
         if use_batchnorm == "inplace" and InPlaceABN is None:
             raise RuntimeError(
                 "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
@@ -87,20 +89,33 @@ class Conv2dReLU(nn.Sequential):
 
         super(Conv2dReLU, self).__init__(conv, bn, relu)
 
 class DepthWiseConv2d(nn.Conv2d):
     "Depth-wise convolution operation"
     def __init__(self, channels, kernel_size=3, stride=1):
-        super().__init__(channels, channels, kernel_size, stride=stride, padding=kernel_size//2, groups=channels)
 
 class PointWiseConv2d(nn.Conv2d):
     "Point-wise (1x1) convolution operation"
     def __init__(self, in_channels, out_channels):
         super().__init__(in_channels, out_channels, kernel_size=1, stride=1)
 
 class SEModule(nn.Module):
     """
     Spatial squeeze & channel excitation attention module, as proposed in https://arxiv.org/abs/1709.01507.
     """
     def __init__(self, in_channels, reduction=16):
         super().__init__()
         self.cSE = nn.Sequential(
@@ -114,10 +129,12 @@ class SEModule(nn.Module):
     def forward(self, x):
         return x * self.cSE(x)
 
 class sSEModule(nn.Module):
     """
     Channel squeeze & spatial excitation attention module, as proposed in https://arxiv.org/abs/1808.08127.
     """
     def __init__(self, in_channels):
         super().__init__()
         self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
@@ -125,10 +142,12 @@ class sSEModule(nn.Module):
     def forward(self, x):
         return x * self.sSE(x)
 
 class SCSEModule(nn.Module):
     """
     Concurrent spatial and channel squeeze & excitation attention module, as proposed in https://arxiv.org/pdf/1803.02579.pdf.
     """
     def __init__(self, in_channels, reduction=16):
         super().__init__()
         self.cSE = nn.Sequential(
@@ -143,8 +162,8 @@ class SCSEModule(nn.Module):
     def forward(self, x):
         return x * self.cSE(x) + x * self.sSE(x)
 
-class ArgMax(nn.Module):
 
     def __init__(self, dim=None):
         super().__init__()
         self.dim = dim
@@ -154,46 +173,47 @@ class ArgMax(nn.Module):
 
 
 class Activation(nn.Module):
-
     def __init__(self, name, **params):
-
         super().__init__()
 
-        if name is None or name == 'identity':
             self.activation = nn.Identity(**params)
-        elif name == 'sigmoid':
             self.activation = nn.Sigmoid()
-        elif name == 'softmax2d':
             self.activation = nn.Softmax(dim=1, **params)
-        elif name == 'softmax':
             self.activation = nn.Softmax(**params)
-        elif name == 'logsoftmax':
             self.activation = nn.LogSoftmax(**params)
-        elif name == 'tanh':
             self.activation = nn.Tanh()
-        elif name == 'argmax':
             self.activation = ArgMax(**params)
-        elif name == 'argmax2d':
             self.activation = ArgMax(dim=1, **params)
         elif callable(name):
             self.activation = name(**params)
         else:
-            raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/tanh/None; got {}'.format(name))
 
     def forward(self, x):
         return self.activation(x)
 
 
 class Attention(nn.Module):
-
     def __init__(self, name, **params):
         super().__init__()
 
         if name is None:
             self.attention = nn.Identity(**params)
-        elif name == 'scse':
             self.attention = SCSEModule(**params)
-        elif name == 'se':
             self.attention = SEModule(**params)
         else:
             raise ValueError("Attention {} is not implemented".format(name))
@@ -201,6 +221,7 @@ class Attention(nn.Module):
     def forward(self, x):
         return self.attention(x)
 
 class Flatten(nn.Module):
     def forward(self, x):
-        return x.view(x.shape[0], -1)
55
  """
56
+
57
  def __init__(
58
+ self,
59
+ in_channels,
60
+ out_channels,
61
+ kernel_size,
62
+ padding=0,
63
+ stride=1,
64
+ use_batchnorm=True,
65
  ):
 
66
  if use_batchnorm == "inplace" and InPlaceABN is None:
67
  raise RuntimeError(
68
  "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
 
89
 
90
  super(Conv2dReLU, self).__init__(conv, bn, relu)
91
 
92
+
93
  class DepthWiseConv2d(nn.Conv2d):
94
  "Depth-wise convolution operation"
95
+
96
  def __init__(self, channels, kernel_size=3, stride=1):
97
+ super().__init__(
98
+ channels,
99
+ channels,
100
+ kernel_size,
101
+ stride=stride,
102
+ padding=kernel_size // 2,
103
+ groups=channels,
104
+ )
105
+
106
 
107
  class PointWiseConv2d(nn.Conv2d):
108
  "Point-wise (1x1) convolution operation"
109
+
110
  def __init__(self, in_channels, out_channels):
111
  super().__init__(in_channels, out_channels, kernel_size=1, stride=1)
112
 
113
+
114
  class SEModule(nn.Module):
115
  """
116
  Spatial squeeze & channel excitation attention module, as proposed in https://arxiv.org/abs/1709.01507.
117
  """
118
+
119
  def __init__(self, in_channels, reduction=16):
120
  super().__init__()
121
  self.cSE = nn.Sequential(
 
129
  def forward(self, x):
130
  return x * self.cSE(x)
131
 
132
+
133
  class sSEModule(nn.Module):
134
  """
135
  Channel squeeze & spatial excitation attention module, as proposed in https://arxiv.org/abs/1808.08127.
136
  """
137
+
138
  def __init__(self, in_channels):
139
  super().__init__()
140
  self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
 
142
  def forward(self, x):
143
  return x * self.sSE(x)
144
 
145
+
146
  class SCSEModule(nn.Module):
147
  """
148
  Concurrent spatial and channel squeeze & excitation attention module, as proposed in https://arxiv.org/pdf/1803.02579.pdf.
149
  """
150
+
151
  def __init__(self, in_channels, reduction=16):
152
  super().__init__()
153
  self.cSE = nn.Sequential(
 
162
  def forward(self, x):
163
  return x * self.cSE(x) + x * self.sSE(x)
164
 
 
165
 
166
+ class ArgMax(nn.Module):
167
  def __init__(self, dim=None):
168
  super().__init__()
169
  self.dim = dim
 
173
 
174
 
175
  class Activation(nn.Module):
 
176
  def __init__(self, name, **params):
 
177
  super().__init__()
178
 
179
+ if name is None or name == "identity":
180
  self.activation = nn.Identity(**params)
181
+ elif name == "sigmoid":
182
  self.activation = nn.Sigmoid()
183
+ elif name == "softmax2d":
184
  self.activation = nn.Softmax(dim=1, **params)
185
+ elif name == "softmax":
186
  self.activation = nn.Softmax(**params)
187
+ elif name == "logsoftmax":
188
  self.activation = nn.LogSoftmax(**params)
189
+ elif name == "tanh":
190
  self.activation = nn.Tanh()
191
+ elif name == "argmax":
192
  self.activation = ArgMax(**params)
193
+ elif name == "argmax2d":
194
  self.activation = ArgMax(dim=1, **params)
195
  elif callable(name):
196
  self.activation = name(**params)
197
  else:
198
+ raise ValueError(
199
+ "Activation should be callable/sigmoid/softmax/logsoftmax/tanh/None; got {}".format(
200
+ name
201
+ )
202
+ )
203
 
204
  def forward(self, x):
205
  return self.activation(x)
206
 
207
 
208
  class Attention(nn.Module):
 
209
  def __init__(self, name, **params):
210
  super().__init__()
211
 
212
  if name is None:
213
  self.attention = nn.Identity(**params)
214
+ elif name == "scse":
215
  self.attention = SCSEModule(**params)
216
+ elif name == "se":
217
  self.attention = SEModule(**params)
218
  else:
219
  raise ValueError("Attention {} is not implemented".format(name))
 
221
  def forward(self, x):
222
  return self.attention(x)
223
 
224
+
225
  class Flatten(nn.Module):
226
  def forward(self, x):
227
+ return x.view(x.shape[0], -1)
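
To make the attention and activation factories above concrete, a short usage sketch (it assumes the vendored segmentation_models_pytorch package in this repository is importable; the class names are taken directly from the file above):

import torch
from segmentation_models_pytorch.base.modules import SCSEModule, Activation, Attention

x = torch.randn(2, 64, 32, 32)

# scSE gates the input along channels (cSE) and spatial positions (sSE) and
# sums the two re-weighted maps, so the output keeps the input shape.
scse = SCSEModule(in_channels=64, reduction=16)
assert scse(x).shape == x.shape

# The factories map a string name to a module; unknown names raise ValueError.
softmax2d = Activation("softmax2d")            # resolves to nn.Softmax over the channel dim
attention = Attention("scse", in_channels=64)  # resolves to SCSEModule
assert attention(x).shape == x.shape
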
segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/__init__.py CHANGED
@@ -1 +1 @@
1
- from .model import DeepLabV3, DeepLabV3Plus
 
1
+ from .model import DeepLabV3, DeepLabV3Plus
segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/decoder.py CHANGED
@@ -61,14 +61,18 @@ class DeepLabV3PlusDecoder(nn.Module):
61
  ):
62
  super().__init__()
63
  if output_stride not in {8, 16}:
64
- raise ValueError("Output stride should be 8 or 16, got {}.".format(output_stride))
 
 
65
 
66
  self.out_channels = out_channels
67
  self.output_stride = output_stride
68
 
69
  self.aspp = nn.Sequential(
70
  ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
71
- SeparableConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
 
 
72
  nn.BatchNorm2d(out_channels),
73
  nn.ReLU(),
74
  )
@@ -77,9 +81,11 @@ class DeepLabV3PlusDecoder(nn.Module):
77
  self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
78
 
79
  highres_in_channels = encoder_channels[-4]
80
- highres_out_channels = 48 # proposed by authors of paper
81
  self.block1 = nn.Sequential(
82
- nn.Conv2d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False),
 
 
83
  nn.BatchNorm2d(highres_out_channels),
84
  nn.ReLU(),
85
  )
@@ -149,7 +155,7 @@ class ASPPPooling(nn.Sequential):
149
  size = x.shape[-2:]
150
  for mod in self:
151
  x = mod(x)
152
- return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
153
 
154
 
155
  class ASPP(nn.Module):
@@ -190,16 +196,15 @@ class ASPP(nn.Module):
190
 
191
 
192
  class SeparableConv2d(nn.Sequential):
193
-
194
  def __init__(
195
- self,
196
- in_channels,
197
- out_channels,
198
- kernel_size,
199
- stride=1,
200
- padding=0,
201
- dilation=1,
202
- bias=True,
203
  ):
204
  dephtwise_conv = nn.Conv2d(
205
  in_channels,
 
61
  ):
62
  super().__init__()
63
  if output_stride not in {8, 16}:
64
+ raise ValueError(
65
+ "Output stride should be 8 or 16, got {}.".format(output_stride)
66
+ )
67
 
68
  self.out_channels = out_channels
69
  self.output_stride = output_stride
70
 
71
  self.aspp = nn.Sequential(
72
  ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
73
+ SeparableConv2d(
74
+ out_channels, out_channels, kernel_size=3, padding=1, bias=False
75
+ ),
76
  nn.BatchNorm2d(out_channels),
77
  nn.ReLU(),
78
  )
 
81
  self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
82
 
83
  highres_in_channels = encoder_channels[-4]
84
+ highres_out_channels = 48 # proposed by authors of paper
85
  self.block1 = nn.Sequential(
86
+ nn.Conv2d(
87
+ highres_in_channels, highres_out_channels, kernel_size=1, bias=False
88
+ ),
89
  nn.BatchNorm2d(highres_out_channels),
90
  nn.ReLU(),
91
  )
 
155
  size = x.shape[-2:]
156
  for mod in self:
157
  x = mod(x)
158
+ return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
159
 
160
 
161
  class ASPP(nn.Module):
 
196
 
197
 
198
  class SeparableConv2d(nn.Sequential):
 
199
  def __init__(
200
+ self,
201
+ in_channels,
202
+ out_channels,
203
+ kernel_size,
204
+ stride=1,
205
+ padding=0,
206
+ dilation=1,
207
+ bias=True,
208
  ):
209
  dephtwise_conv = nn.Conv2d(
210
  in_channels,
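
For reference, the separable convolution this decoder relies on can be sketched with two plain Conv2d layers: a depthwise 3x3 with groups equal to the channel count, followed by a 1x1 pointwise projection. This is an illustration of the pattern, not the file's exact class:

import torch
import torch.nn as nn

in_ch, out_ch = 256, 256
separable = nn.Sequential(
    # depthwise: one 3x3 filter per channel (groups == in_channels)
    nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, groups=in_ch, bias=False),
    # pointwise: 1x1 convolution mixes channels
    nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
)
x = torch.randn(1, in_ch, 16, 16)
assert separable(x).shape == (1, out_ch, 16, 16)
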
segmentation_models_pytorch/segmentation_models_pytorch/deeplabv3/model.py CHANGED
@@ -1,9 +1,10 @@
 
 
1
  import torch.nn as nn
2
 
3
- from typing import Optional
4
- from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
5
- from ..base import SegmentationModel, SegmentationHead, ClassificationHead
6
  from ..encoders import get_encoder
 
7
 
8
 
9
  class DeepLabV3(SegmentationModel):
@@ -12,11 +13,11 @@ class DeepLabV3(SegmentationModel):
12
  Args:
13
  encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
14
  to extract features of different spatial resolution
15
- encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
16
  two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
17
  with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
18
  Default is 5
19
- encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
20
  other pretrained weights (see table with available weights for each encoder_name)
21
  decoder_channels: A number of convolution filters in ASPP module. Default is 256
22
  in_channels: A number of input channels for the model, default is 3 (RGB images)
@@ -25,7 +26,7 @@ class DeepLabV3(SegmentationModel):
25
  Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
26
  Default is **None**
27
  upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity
28
- aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
29
  on top of encoder if **aux_params** is not **None** (default). Supported params:
30
  - classes (int): A number of classes
31
  - pooling (str): One of "max", "avg". Default is "avg"
@@ -42,16 +43,16 @@ class DeepLabV3(SegmentationModel):
42
  """
43
 
44
  def __init__(
45
- self,
46
- encoder_name: str = "resnet34",
47
- encoder_depth: int = 5,
48
- encoder_weights: Optional[str] = "imagenet",
49
- decoder_channels: int = 256,
50
- in_channels: int = 3,
51
- classes: int = 1,
52
- activation: Optional[str] = None,
53
- upsampling: int = 8,
54
- aux_params: Optional[dict] = None,
55
  ):
56
  super().__init__()
57
 
@@ -61,10 +62,7 @@ class DeepLabV3(SegmentationModel):
61
  depth=encoder_depth,
62
  weights=encoder_weights,
63
  )
64
- self.encoder.make_dilated(
65
- stage_list=[4, 5],
66
- dilation_list=[2, 4]
67
- )
68
 
69
  self.decoder = DeepLabV3Decoder(
70
  in_channels=self.encoder.out_channels[-1],
@@ -90,15 +88,15 @@ class DeepLabV3(SegmentationModel):
90
  class DeepLabV3Plus(SegmentationModel):
91
  """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable
92
  Convolution for Semantic Image Segmentation"
93
-
94
  Args:
95
  encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
96
  to extract features of different spatial resolution
97
- encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
98
  two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
99
  with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
100
  Default is 5
101
- encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
102
  other pretrained weights (see table with available weights for each encoder_name)
103
  encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
104
  decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values)
@@ -109,7 +107,7 @@ class DeepLabV3Plus(SegmentationModel):
109
  Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
110
  Default is **None**
111
  upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
112
- aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
113
  on top of encoder if **aux_params** is not **None** (default). Supported params:
114
  - classes (int): A number of classes
115
  - pooling (str): One of "max", "avg". Default is "avg"
@@ -121,19 +119,20 @@ class DeepLabV3Plus(SegmentationModel):
121
  Reference:
122
  https://arxiv.org/abs/1802.02611v3
123
  """
 
124
  def __init__(
125
- self,
126
- encoder_name: str = "resnet34",
127
- encoder_depth: int = 5,
128
- encoder_weights: Optional[str] = "imagenet",
129
- encoder_output_stride: int = 16,
130
- decoder_channels: int = 256,
131
- decoder_atrous_rates: tuple = (12, 24, 36),
132
- in_channels: int = 3,
133
- classes: int = 1,
134
- activation: Optional[str] = None,
135
- upsampling: int = 4,
136
- aux_params: Optional[dict] = None,
137
  ):
138
  super().__init__()
139
 
@@ -145,19 +144,15 @@ class DeepLabV3Plus(SegmentationModel):
145
  )
146
 
147
  if encoder_output_stride == 8:
148
- self.encoder.make_dilated(
149
- stage_list=[4, 5],
150
- dilation_list=[2, 4]
151
- )
152
 
153
  elif encoder_output_stride == 16:
154
- self.encoder.make_dilated(
155
- stage_list=[5],
156
- dilation_list=[2]
157
- )
158
  else:
159
  raise ValueError(
160
- "Encoder output stride should be 8 or 16, got {}".format(encoder_output_stride)
 
 
161
  )
162
 
163
  self.decoder = DeepLabV3PlusDecoder(
 
1
+ from typing import Optional
2
+
3
  import torch.nn as nn
4
 
5
+ from ..base import ClassificationHead, SegmentationHead, SegmentationModel
 
 
6
  from ..encoders import get_encoder
7
+ from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
8
 
9
 
10
  class DeepLabV3(SegmentationModel):
 
13
  Args:
14
  encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
15
  to extract features of different spatial resolution
16
+ encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
17
  two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
18
  with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
19
  Default is 5
20
+ encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
21
  other pretrained weights (see table with available weights for each encoder_name)
22
  decoder_channels: A number of convolution filters in ASPP module. Default is 256
23
  in_channels: A number of input channels for the model, default is 3 (RGB images)
 
26
  Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
27
  Default is **None**
28
  upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity
29
+ aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
30
  on top of encoder if **aux_params** is not **None** (default). Supported params:
31
  - classes (int): A number of classes
32
  - pooling (str): One of "max", "avg". Default is "avg"
 
43
  """
44
 
45
  def __init__(
46
+ self,
47
+ encoder_name: str = "resnet34",
48
+ encoder_depth: int = 5,
49
+ encoder_weights: Optional[str] = "imagenet",
50
+ decoder_channels: int = 256,
51
+ in_channels: int = 3,
52
+ classes: int = 1,
53
+ activation: Optional[str] = None,
54
+ upsampling: int = 8,
55
+ aux_params: Optional[dict] = None,
56
  ):
57
  super().__init__()
58
 
 
62
  depth=encoder_depth,
63
  weights=encoder_weights,
64
  )
65
+ self.encoder.make_dilated(stage_list=[4, 5], dilation_list=[2, 4])
 
 
 
66
 
67
  self.decoder = DeepLabV3Decoder(
68
  in_channels=self.encoder.out_channels[-1],
 
88
  class DeepLabV3Plus(SegmentationModel):
89
  """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable
90
  Convolution for Semantic Image Segmentation"
91
+
92
  Args:
93
  encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
94
  to extract features of different spatial resolution
95
+ encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
96
  two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
97
  with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
98
  Default is 5
99
+ encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
100
  other pretrained weights (see table with available weights for each encoder_name)
101
  encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
102
  decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values)
 
107
  Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
108
  Default is **None**
109
  upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
110
+ aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
111
  on top of encoder if **aux_params** is not **None** (default). Supported params:
112
  - classes (int): A number of classes
113
  - pooling (str): One of "max", "avg". Default is "avg"
 
119
  Reference:
120
  https://arxiv.org/abs/1802.02611v3
121
  """
122
+
123
  def __init__(
124
+ self,
125
+ encoder_name: str = "resnet34",
126
+ encoder_depth: int = 5,
127
+ encoder_weights: Optional[str] = "imagenet",
128
+ encoder_output_stride: int = 16,
129
+ decoder_channels: int = 256,
130
+ decoder_atrous_rates: tuple = (12, 24, 36),
131
+ in_channels: int = 3,
132
+ classes: int = 1,
133
+ activation: Optional[str] = None,
134
+ upsampling: int = 4,
135
+ aux_params: Optional[dict] = None,
136
  ):
137
  super().__init__()
138
 
 
144
  )
145
 
146
  if encoder_output_stride == 8:
147
+ self.encoder.make_dilated(stage_list=[4, 5], dilation_list=[2, 4])
 
 
 
148
 
149
  elif encoder_output_stride == 16:
150
+ self.encoder.make_dilated(stage_list=[5], dilation_list=[2])
 
 
 
151
  else:
152
  raise ValueError(
153
+ "Encoder output stride should be 8 or 16, got {}".format(
154
+ encoder_output_stride
155
+ )
156
  )
157
 
158
  self.decoder = DeepLabV3PlusDecoder(
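
A hedged construction example for the model defined above (encoder weights set to None to avoid a download; passing aux_params enables the optional classification head, in which case the forward pass returns a mask/label pair):

import torch
import segmentation_models_pytorch as smp  # assumes the vendored package is importable

model = smp.DeepLabV3Plus(
    encoder_name="resnet34",
    encoder_weights=None,        # random init, no download
    encoder_output_stride=16,
    classes=2,
    aux_params={"classes": 2, "pooling": "avg"},
)
with torch.no_grad():
    mask, label = model(torch.randn(1, 3, 256, 256))
print(mask.shape, label.shape)   # expected: (1, 2, 256, 256) and (1, 2)
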
segmentation_models_pytorch/segmentation_models_pytorch/efficientunetplusplus/decoder.py CHANGED
@@ -1,76 +1,92 @@
1
  import torch
2
- from torch.functional import norm
3
  import torch.nn as nn
4
  import torch.nn.functional as F
 
5
 
6
  from ..base import modules as md
7
 
 
8
  class InvertedResidual(nn.Module):
9
  """
10
- Inverted bottleneck residual block with an scSE block embedded into the residual layer, after the
11
  depthwise convolution. By default, uses batch normalization and Hardswish activation.
12
  """
13
- def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 1, expansion_ratio = 1, squeeze_ratio = 1, \
14
- activation = nn.Hardswish(True), normalization = nn.BatchNorm2d):

15
  super().__init__()
16
  self.same_shape = in_channels == out_channels
17
- self.mid_channels = expansion_ratio*in_channels
18
  self.block = nn.Sequential(
19
  md.PointWiseConv2d(in_channels, self.mid_channels),
20
  normalization(self.mid_channels),
21
  activation,
22
- md.DepthWiseConv2d(self.mid_channels, kernel_size=kernel_size, stride=stride),
 
 
23
  normalization(self.mid_channels),
24
  activation,
25
- #md.sSEModule(self.mid_channels),
26
- md.SCSEModule(self.mid_channels, reduction = squeeze_ratio),
27
- #md.SEModule(self.mid_channels, reduction = squeeze_ratio),
28
  md.PointWiseConv2d(self.mid_channels, out_channels),
29
- normalization(out_channels)
30
  )
31
-
32
  if not self.same_shape:
33
- # 1x1 convolution used to match the number of channels in the skip feature maps with that
34
  # of the residual feature maps
35
  self.skip_conv = nn.Sequential(
36
- nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
37
- normalization(out_channels)
 
 
38
  )
39
-
40
  def forward(self, x):
41
  residual = self.block(x)
42
-
43
  if not self.same_shape:
44
  x = self.skip_conv(x)
45
  return x + residual
46
-
 
47
  class DecoderBlock(nn.Module):
48
  def __init__(
49
- self,
50
- in_channels,
51
- skip_channels,
52
- out_channels,
53
- squeeze_ratio=1,
54
- expansion_ratio=1
55
  ):
56
  super().__init__()
57
 
58
  # Inverted Residual block convolutions
59
  self.conv1 = InvertedResidual(
60
- in_channels=in_channels+skip_channels,
61
- out_channels=out_channels,
62
- kernel_size=3,
63
- stride=1,
64
- expansion_ratio=expansion_ratio,
65
- squeeze_ratio=squeeze_ratio
66
  )
67
  self.conv2 = InvertedResidual(
68
- in_channels=out_channels,
69
- out_channels=out_channels,
70
- kernel_size=3,
71
- stride=1,
72
- expansion_ratio=expansion_ratio,
73
- squeeze_ratio=squeeze_ratio
74
  )
75
 
76
  def forward(self, x, skip=None):
@@ -82,14 +98,15 @@ class DecoderBlock(nn.Module):
82
  x = self.conv2(x)
83
  return x
84
 
 
85
  class EfficientUnetPlusPlusDecoder(nn.Module):
86
  def __init__(
87
- self,
88
- encoder_channels,
89
- decoder_channels,
90
- n_blocks=5,
91
- squeeze_ratio=1,
92
- expansion_ratio=1
93
  ):
94
  super().__init__()
95
  if n_blocks != len(decoder_channels):
@@ -99,8 +116,12 @@ class EfficientUnetPlusPlusDecoder(nn.Module):
99
  )
100
  )
101
 
102
- encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution
103
- encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder
 
 
 
 
104
  # computing blocks input and output channels
105
  head_channels = encoder_channels[0]
106
  self.in_channels = [head_channels] + list(decoder_channels[:-1])
@@ -112,37 +133,51 @@ class EfficientUnetPlusPlusDecoder(nn.Module):
112
 
113
  blocks = {}
114
  for layer_idx in range(len(self.in_channels) - 1):
115
- for depth_idx in range(layer_idx+1):
116
  if depth_idx == 0:
117
  in_ch = self.in_channels[layer_idx]
118
- skip_ch = self.skip_channels[layer_idx] * (layer_idx+1)
119
  out_ch = self.out_channels[layer_idx]
120
  else:
121
  out_ch = self.skip_channels[layer_idx]
122
- skip_ch = self.skip_channels[layer_idx] * (layer_idx+1-depth_idx)
 
 
123
  in_ch = self.skip_channels[layer_idx - 1]
124
- blocks[f'x_{depth_idx}_{layer_idx}'] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
125
- blocks[f'x_{0}_{len(self.in_channels)-1}'] =\
126
- DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1], **kwargs)
 
 
 
127
  self.blocks = nn.ModuleDict(blocks)
128
  self.depth = len(self.in_channels) - 1
129
 
130
  def forward(self, *features):
131
-
132
- features = features[1:] # remove first skip with same spatial resolution
133
  features = features[::-1] # reverse channels to start from head of encoder
134
  # start building dense connections
135
  dense_x = {}
136
- for layer_idx in range(len(self.in_channels)-1):
137
- for depth_idx in range(self.depth-layer_idx):
138
  if layer_idx == 0:
139
- output = self.blocks[f'x_{depth_idx}_{depth_idx}'](features[depth_idx], features[depth_idx+1])
140
- dense_x[f'x_{depth_idx}_{depth_idx}'] = output
 
 
141
  else:
142
  dense_l_i = depth_idx + layer_idx
143
- cat_features = [dense_x[f'x_{idx}_{dense_l_i}'] for idx in range(depth_idx+1, dense_l_i+1)]
144
- cat_features = torch.cat(cat_features + [features[dense_l_i+1]], dim=1)
145
- dense_x[f'x_{depth_idx}_{dense_l_i}'] =\
146
- self.blocks[f'x_{depth_idx}_{dense_l_i}'](dense_x[f'x_{depth_idx}_{dense_l_i-1}'], cat_features)
147
- dense_x[f'x_{0}_{self.depth}'] = self.blocks[f'x_{0}_{self.depth}'](dense_x[f'x_{0}_{self.depth-1}'])
148
- return dense_x[f'x_{0}_{self.depth}']

1
  import torch
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ from torch.functional import norm
5
 
6
  from ..base import modules as md
7
 
8
+
9
  class InvertedResidual(nn.Module):
10
  """
11
+ Inverted bottleneck residual block with an scSE block embedded into the residual layer, after the
12
  depthwise convolution. By default, uses batch normalization and Hardswish activation.
13
  """
14
+
15
+ def __init__(
16
+ self,
17
+ in_channels,
18
+ out_channels,
19
+ kernel_size=3,
20
+ stride=1,
21
+ expansion_ratio=1,
22
+ squeeze_ratio=1,
23
+ activation=nn.Hardswish(True),
24
+ normalization=nn.BatchNorm2d,
25
+ ):
26
  super().__init__()
27
  self.same_shape = in_channels == out_channels
28
+ self.mid_channels = expansion_ratio * in_channels
29
  self.block = nn.Sequential(
30
  md.PointWiseConv2d(in_channels, self.mid_channels),
31
  normalization(self.mid_channels),
32
  activation,
33
+ md.DepthWiseConv2d(
34
+ self.mid_channels, kernel_size=kernel_size, stride=stride
35
+ ),
36
  normalization(self.mid_channels),
37
  activation,
38
+ # md.sSEModule(self.mid_channels),
39
+ md.SCSEModule(self.mid_channels, reduction=squeeze_ratio),
40
+ # md.SEModule(self.mid_channels, reduction = squeeze_ratio),
41
  md.PointWiseConv2d(self.mid_channels, out_channels),
42
+ normalization(out_channels),
43
  )
44
+
45
  if not self.same_shape:
46
+ # 1x1 convolution used to match the number of channels in the skip feature maps with that
47
  # of the residual feature maps
48
  self.skip_conv = nn.Sequential(
49
+ nn.Conv2d(
50
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1
51
+ ),
52
+ normalization(out_channels),
53
  )
54
+
55
  def forward(self, x):
56
  residual = self.block(x)
57
+
58
  if not self.same_shape:
59
  x = self.skip_conv(x)
60
  return x + residual
61
+
62
+
63
  class DecoderBlock(nn.Module):
64
  def __init__(
65
+ self,
66
+ in_channels,
67
+ skip_channels,
68
+ out_channels,
69
+ squeeze_ratio=1,
70
+ expansion_ratio=1,
71
  ):
72
  super().__init__()
73
 
74
  # Inverted Residual block convolutions
75
  self.conv1 = InvertedResidual(
76
+ in_channels=in_channels + skip_channels,
77
+ out_channels=out_channels,
78
+ kernel_size=3,
79
+ stride=1,
80
+ expansion_ratio=expansion_ratio,
81
+ squeeze_ratio=squeeze_ratio,
82
  )
83
  self.conv2 = InvertedResidual(
84
+ in_channels=out_channels,
85
+ out_channels=out_channels,
86
+ kernel_size=3,
87
+ stride=1,
88
+ expansion_ratio=expansion_ratio,
89
+ squeeze_ratio=squeeze_ratio,
90
  )
91
 
92
  def forward(self, x, skip=None):
 
98
  x = self.conv2(x)
99
  return x
100
 
101
+
102
  class EfficientUnetPlusPlusDecoder(nn.Module):
103
  def __init__(
104
+ self,
105
+ encoder_channels,
106
+ decoder_channels,
107
+ n_blocks=5,
108
+ squeeze_ratio=1,
109
+ expansion_ratio=1,
110
  ):
111
  super().__init__()
112
  if n_blocks != len(decoder_channels):
 
116
  )
117
  )
118
 
119
+ encoder_channels = encoder_channels[
120
+ 1:
121
+ ] # remove first skip with same spatial resolution
122
+ encoder_channels = encoder_channels[
123
+ ::-1
124
+ ] # reverse channels to start from head of encoder
125
  # computing blocks input and output channels
126
  head_channels = encoder_channels[0]
127
  self.in_channels = [head_channels] + list(decoder_channels[:-1])
 
133
 
134
  blocks = {}
135
  for layer_idx in range(len(self.in_channels) - 1):
136
+ for depth_idx in range(layer_idx + 1):
137
  if depth_idx == 0:
138
  in_ch = self.in_channels[layer_idx]
139
+ skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1)
140
  out_ch = self.out_channels[layer_idx]
141
  else:
142
  out_ch = self.skip_channels[layer_idx]
143
+ skip_ch = self.skip_channels[layer_idx] * (
144
+ layer_idx + 1 - depth_idx
145
+ )
146
  in_ch = self.skip_channels[layer_idx - 1]
147
+ blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock(
148
+ in_ch, skip_ch, out_ch, **kwargs
149
+ )
150
+ blocks[f"x_{0}_{len(self.in_channels)-1}"] = DecoderBlock(
151
+ self.in_channels[-1], 0, self.out_channels[-1], **kwargs
152
+ )
153
  self.blocks = nn.ModuleDict(blocks)
154
  self.depth = len(self.in_channels) - 1
155
 
156
  def forward(self, *features):
157
+ features = features[1:] # remove first skip with same spatial resolution
 
158
  features = features[::-1] # reverse channels to start from head of encoder
159
  # start building dense connections
160
  dense_x = {}
161
+ for layer_idx in range(len(self.in_channels) - 1):
162
+ for depth_idx in range(self.depth - layer_idx):
163
  if layer_idx == 0:
164
+ output = self.blocks[f"x_{depth_idx}_{depth_idx}"](
165
+ features[depth_idx], features[depth_idx + 1]
166
+ )
167
+ dense_x[f"x_{depth_idx}_{depth_idx}"] = output
168
  else:
169
  dense_l_i = depth_idx + layer_idx
170
+ cat_features = [
171
+ dense_x[f"x_{idx}_{dense_l_i}"]
172
+ for idx in range(depth_idx + 1, dense_l_i + 1)
173
+ ]
174
+ cat_features = torch.cat(
175
+ cat_features + [features[dense_l_i + 1]], dim=1
176
+ )
177
+ dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[
178
+ f"x_{depth_idx}_{dense_l_i}"
179
+ ](dense_x[f"x_{depth_idx}_{dense_l_i-1}"], cat_features)
180
+ dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"](
181
+ dense_x[f"x_{0}_{self.depth-1}"]
182
+ )
183
+ return dense_x[f"x_{0}_{self.depth}"]
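
The decoder block above is easier to follow with a stripped-down sketch of the inverted-residual pattern it uses (pointwise expansion, depthwise 3x3, pointwise projection; the scSE attention is omitted here to keep the shapes in focus):

import torch
import torch.nn as nn

in_ch = out_ch = 32
expansion = 2
mid = expansion * in_ch
block = nn.Sequential(
    nn.Conv2d(in_ch, mid, kernel_size=1),                        # pointwise expansion
    nn.BatchNorm2d(mid), nn.Hardswish(),
    nn.Conv2d(mid, mid, kernel_size=3, padding=1, groups=mid),   # depthwise 3x3
    nn.BatchNorm2d(mid), nn.Hardswish(),
    nn.Conv2d(mid, out_ch, kernel_size=1),                       # pointwise projection
    nn.BatchNorm2d(out_ch),
)
x = torch.randn(1, in_ch, 64, 64)
out = x + block(x)   # residual add, valid because in_ch == out_ch
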
segmentation_models_pytorch/segmentation_models_pytorch/efficientunetplusplus/model.py CHANGED
@@ -1,28 +1,30 @@
1
- from typing import Optional, Union, List
2
- from .decoder import EfficientUnetPlusPlusDecoder
3
- from ..encoders import get_encoder
4
- from ..base import SegmentationModel
5
- from ..base import SegmentationHead, ClassificationHead
6
  from torchvision import transforms
7
 
 
 
 
 
 
8
  class EfficientUnetPlusPlus(SegmentationModel):
9
- """The EfficientUNet++ is a fully convolutional neural network for ordinary and medical image semantic segmentation.
10
- Consists of an *encoder* and a *decoder*, connected by *skip connections*. The encoder extracts features of
11
- different spatial resolutions, which are fed to the decoder through skip connections. The decoder combines its
12
- own feature maps with the ones from skip connections to produce accurate segmentations masks. The EfficientUNet++
13
- decoder architecture is based on the UNet++, a model composed of nested U-Net-like decoder sub-networks. To
14
- increase performance and computational efficiency, the EfficientUNet++ replaces the UNet++'s blocks with
15
  inverted residual blocks with depthwise convolutions and embedded spatial and channel attention mechanisms.
16
  Synergizes well with EfficientNet encoders. Due to their efficient visual representations (i.e., using few channels
17
  to represent extracted features), EfficientNet encoders require few computation from the decoder.
18
 
19
  Args:
20
  encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) to extract features
21
- encoder_depth: Number of stages of the encoder, in range [3 ,5]. Each stage generate features two times smaller,
22
- in spatial dimensions, than the previous one (e.g., for depth=0 features will haves shapes [(N, C, H, W)]),
23
  for depth 1 features will have shapes [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
24
  Default is 5
25
- encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
26
  other pretrained weights (see table with available weights for each encoder_name)
27
  decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in the decoder.
28
  Length of the list should be the same as **encoder_depth**
@@ -31,7 +33,7 @@ class EfficientUnetPlusPlus(SegmentationModel):
31
  activation: An activation function to apply after the final convolution layer.
32
  Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
33
  Default is **None**
34
- aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
35
  on top of encoder if **aux_params** is not **None** (default). Supported params:
36
  - classes (int): A number of classes
37
  - pooling (str): One of "max", "avg". Default is "avg"
@@ -71,7 +73,7 @@ class EfficientUnetPlusPlus(SegmentationModel):
71
  decoder_channels=decoder_channels,
72
  n_blocks=encoder_depth,
73
  squeeze_ratio=squeeze_ratio,
74
- expansion_ratio=expansion_ratio
75
  )
76
 
77
  self.segmentation_head = SegmentationHead(
@@ -117,9 +119,9 @@ class EfficientUnetPlusPlus(SegmentationModel):
117
  [
118
  transforms.ToPILImage(),
119
  transforms.Resize(x.size[1]),
120
- transforms.ToTensor()
121
  ]
122
  )
123
- full_mask = tf(probs.cpu())
124
 
125
- return full_mask
 
1
+ from typing import List, Optional, Union
2
+
 
 
 
3
  from torchvision import transforms
4
 
5
+ from ..base import ClassificationHead, SegmentationHead, SegmentationModel
6
+ from ..encoders import get_encoder
7
+ from .decoder import EfficientUnetPlusPlusDecoder
8
+
9
+
10
  class EfficientUnetPlusPlus(SegmentationModel):
11
+ """The EfficientUNet++ is a fully convolutional neural network for ordinary and medical image semantic segmentation.
12
+ Consists of an *encoder* and a *decoder*, connected by *skip connections*. The encoder extracts features of
13
+ different spatial resolutions, which are fed to the decoder through skip connections. The decoder combines its
14
+ own feature maps with the ones from skip connections to produce accurate segmentations masks. The EfficientUNet++
15
+ decoder architecture is based on the UNet++, a model composed of nested U-Net-like decoder sub-networks. To
16
+ increase performance and computational efficiency, the EfficientUNet++ replaces the UNet++'s blocks with
17
  inverted residual blocks with depthwise convolutions and embedded spatial and channel attention mechanisms.
18
  Synergizes well with EfficientNet encoders. Due to their efficient visual representations (i.e., using few channels
19
  to represent extracted features), EfficientNet encoders require few computation from the decoder.
20
 
21
  Args:
22
  encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) to extract features
23
+ encoder_depth: Number of stages of the encoder, in range [3 ,5]. Each stage generate features two times smaller,
24
+ in spatial dimensions, than the previous one (e.g., for depth=0 features will haves shapes [(N, C, H, W)]),
25
  for depth 1 features will have shapes [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
26
  Default is 5
27
+ encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
28
  other pretrained weights (see table with available weights for each encoder_name)
29
  decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in the decoder.
30
  Length of the list should be the same as **encoder_depth**
 
33
  activation: An activation function to apply after the final convolution layer.
34
  Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
35
  Default is **None**
36
+ aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
37
  on top of encoder if **aux_params** is not **None** (default). Supported params:
38
  - classes (int): A number of classes
39
  - pooling (str): One of "max", "avg". Default is "avg"
 
73
  decoder_channels=decoder_channels,
74
  n_blocks=encoder_depth,
75
  squeeze_ratio=squeeze_ratio,
76
+ expansion_ratio=expansion_ratio,
77
  )
78
 
79
  self.segmentation_head = SegmentationHead(
 
119
  [
120
  transforms.ToPILImage(),
121
  transforms.Resize(x.size[1]),
122
+ transforms.ToTensor(),
123
  ]
124
  )
125
+ full_mask = tf(probs.cpu())
126
 
127
+ return full_mask
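
A minimal, hedged usage sketch for the model above (the encoder name and input size are illustrative choices, and the top-level import assumes the vendored fork exports EfficientUnetPlusPlus from its package __init__):

import torch
from segmentation_models_pytorch import EfficientUnetPlusPlus  # assumed export of the vendored fork

model = EfficientUnetPlusPlus(
    encoder_name="efficientnet-b0",  # illustrative encoder choice
    encoder_weights=None,
    classes=1,
)
with torch.no_grad():
    mask = model(torch.randn(1, 3, 256, 256))
print(mask.shape)   # expected: (1, 1, 256, 256)
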
segmentation_models_pytorch/segmentation_models_pytorch/encoders/__init__.py CHANGED
@@ -1,22 +1,24 @@
1
  import functools
 
2
  import torch.utils.model_zoo as model_zoo
3
 
4
- from .resnet import resnet_encoders
5
- from .dpn import dpn_encoders
6
- from .vgg import vgg_encoders
7
- from .senet import senet_encoders
8
  from .densenet import densenet_encoders
 
 
9
  from .inceptionresnetv2 import inceptionresnetv2_encoders
10
  from .inceptionv4 import inceptionv4_encoders
11
- from .efficientnet import efficient_net_encoders
12
  from .mobilenet import mobilenet_encoders
13
- from .xception import xception_encoders
 
 
 
 
14
  # from .timm_efficientnet import timm_efficientnet_encoders
15
  from .timm_resnest import timm_resnest_encoders
16
- from .timm_res2net import timm_res2net_encoders
17
- from .timm_regnet import timm_regnet_encoders
18
  from .timm_sknet import timm_sknet_encoders
19
- from ._preprocessing import preprocess_input
 
20
 
21
  encoders = {}
22
  encoders.update(resnet_encoders)
@@ -37,11 +39,14 @@ encoders.update(timm_sknet_encoders)
37
 
38
 
39
  def get_encoder(name, in_channels=3, depth=5, weights=None):
40
-
41
  try:
42
  Encoder = encoders[name]["encoder"]
43
  except KeyError:
44
- raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))
 
 
 
 
45
 
46
  params = encoders[name]["params"]
47
  params.update(depth=depth)
@@ -51,9 +56,13 @@ def get_encoder(name, in_channels=3, depth=5, weights=None):
51
  try:
52
  settings = encoders[name]["pretrained_settings"][weights]
53
  except KeyError:
54
- raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
55
- weights, name, list(encoders[name]["pretrained_settings"].keys()),
56
- ))
 
 
 
 
57
  encoder.load_state_dict(model_zoo.load_url(settings["url"]))
58
 
59
  encoder.set_in_channels(in_channels)
 
1
  import functools
2
+
3
  import torch.utils.model_zoo as model_zoo
4
 
5
+ from ._preprocessing import preprocess_input
 
 
 
6
  from .densenet import densenet_encoders
7
+ from .dpn import dpn_encoders
8
+ from .efficientnet import efficient_net_encoders
9
  from .inceptionresnetv2 import inceptionresnetv2_encoders
10
  from .inceptionv4 import inceptionv4_encoders
 
11
  from .mobilenet import mobilenet_encoders
12
+ from .resnet import resnet_encoders
13
+ from .senet import senet_encoders
14
+ from .timm_regnet import timm_regnet_encoders
15
+ from .timm_res2net import timm_res2net_encoders
16
+
17
  # from .timm_efficientnet import timm_efficientnet_encoders
18
  from .timm_resnest import timm_resnest_encoders
 
 
19
  from .timm_sknet import timm_sknet_encoders
20
+ from .vgg import vgg_encoders
21
+ from .xception import xception_encoders
22
 
23
  encoders = {}
24
  encoders.update(resnet_encoders)
 
39
 
40
 
41
  def get_encoder(name, in_channels=3, depth=5, weights=None):
 
42
  try:
43
  Encoder = encoders[name]["encoder"]
44
  except KeyError:
45
+ raise KeyError(
46
+ "Wrong encoder name `{}`, supported encoders: {}".format(
47
+ name, list(encoders.keys())
48
+ )
49
+ )
50
 
51
  params = encoders[name]["params"]
52
  params.update(depth=depth)
 
56
  try:
57
  settings = encoders[name]["pretrained_settings"][weights]
58
  except KeyError:
59
+ raise KeyError(
60
+ "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
61
+ weights,
62
+ name,
63
+ list(encoders[name]["pretrained_settings"].keys()),
64
+ )
65
+ )
66
  encoder.load_state_dict(model_zoo.load_url(settings["url"]))
67
 
68
  encoder.set_in_channels(in_channels)
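
The registry lookup above is used like this (a short sketch; weights=None skips the state-dict download):

import torch
from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder("resnet34", in_channels=3, depth=5, weights=None)
features = encoder(torch.randn(1, 3, 224, 224))
# One feature map per encoder stage, each half the resolution of the previous one.
print([tuple(f.shape) for f in features])

# An unknown name raises the KeyError built above, listing the supported encoders.
try:
    get_encoder("not-an-encoder")
except KeyError as err:
    print(err)
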
segmentation_models_pytorch/segmentation_models_pytorch/encoders/_base.py CHANGED
@@ -1,15 +1,16 @@
 
 
 
1
  import torch
2
  import torch.nn as nn
3
- from typing import List
4
- from collections import OrderedDict
5
 
6
  from . import _utils as utils
7
 
8
 
9
  class EncoderMixin:
10
  """Add encoder functionality such as:
11
- - output channels specification of feature tensors (produced by encoder)
12
- - patching first convolution for arbitrary input channels
13
  """
14
 
15
  @property
 
1
+ from collections import OrderedDict
2
+ from typing import List
3
+
4
  import torch
5
  import torch.nn as nn
 
 
6
 
7
  from . import _utils as utils
8
 
9
 
10
  class EncoderMixin:
11
  """Add encoder functionality such as:
12
+ - output channels specification of feature tensors (produced by encoder)
13
+ - patching first convolution for arbitrary input channels
14
  """
15
 
16
  @property
segmentation_models_pytorch/segmentation_models_pytorch/encoders/_preprocessing.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  def preprocess_input(
5
  x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs
6
  ):
7
-
8
  if input_space == "BGR":
9
  x = x[..., ::-1].copy()
10
 
 
4
  def preprocess_input(
5
  x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs
6
  ):
 
7
  if input_space == "BGR":
8
  x = x[..., ::-1].copy()
9
 
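
For completeness, a sketch of calling the preprocessing function above directly. Only the BGR flip is visible in this hunk; the mean/std values below are the usual ImageNet statistics supplied here as assumptions, whereas in normal use they come from an encoder's pretrained settings:

import numpy as np
from segmentation_models_pytorch.encoders._preprocessing import preprocess_input

image = np.random.randint(0, 255, size=(224, 224, 3)).astype("float32")
x = preprocess_input(
    image,
    input_space="RGB",
    input_range=[0, 1],           # target range; rescaling happens inside the full function
    mean=[0.485, 0.456, 0.406],   # ImageNet statistics (assumed here)
    std=[0.229, 0.224, 0.225],
)
print(x.shape)
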
segmentation_models_pytorch/segmentation_models_pytorch/encoders/densenet.py CHANGED
@@ -24,8 +24,8 @@ Methods:
24
  """
25
 
26
  import re
27
- import torch.nn as nn
28
 
 
29
  from pretrainedmodels.models.torchvision_models import pretrained_settings
30
  from torchvision.models.densenet import DenseNet
31
 
@@ -33,7 +33,6 @@ from ._base import EncoderMixin
33
 
34
 
35
  class TransitionWithSkip(nn.Module):
36
-
37
  def __init__(self, module):
38
  super().__init__()
39
  self.module = module
@@ -55,22 +54,32 @@ class DenseNetEncoder(DenseNet, EncoderMixin):
55
  del self.classifier
56
 
57
  def make_dilated(self, stage_list, dilation_list):
58
- raise ValueError("DenseNet encoders do not support dilated mode "
59
- "due to pooling operation for downsampling!")
 
 
60
 
61
  def get_stages(self):
62
  return [
63
  nn.Identity(),
64
- nn.Sequential(self.features.conv0, self.features.norm0, self.features.relu0),
65
- nn.Sequential(self.features.pool0, self.features.denseblock1,
66
- TransitionWithSkip(self.features.transition1)),
67
- nn.Sequential(self.features.denseblock2, TransitionWithSkip(self.features.transition2)),
68
- nn.Sequential(self.features.denseblock3, TransitionWithSkip(self.features.transition3)),
69
- nn.Sequential(self.features.denseblock4, self.features.norm5)

70
  ]
71
 
72
  def forward(self, x):
73
-
74
  stages = self.get_stages()
75
 
76
  features = []
 
24
  """
25
 
26
  import re
 
27
 
28
+ import torch.nn as nn
29
  from pretrainedmodels.models.torchvision_models import pretrained_settings
30
  from torchvision.models.densenet import DenseNet
31
 
 
33
 
34
 
35
  class TransitionWithSkip(nn.Module):
 
36
  def __init__(self, module):
37
  super().__init__()
38
  self.module = module
 
54
  del self.classifier
55
 
56
  def make_dilated(self, stage_list, dilation_list):
57
+ raise ValueError(
58
+ "DenseNet encoders do not support dilated mode "
59
+ "due to pooling operation for downsampling!"
60
+ )
61
 
62
  def get_stages(self):
63
  return [
64
  nn.Identity(),
65
+ nn.Sequential(
66
+ self.features.conv0, self.features.norm0, self.features.relu0
67
+ ),
68
+ nn.Sequential(
69
+ self.features.pool0,
70
+ self.features.denseblock1,
71
+ TransitionWithSkip(self.features.transition1),
72
+ ),
73
+ nn.Sequential(
74
+ self.features.denseblock2, TransitionWithSkip(self.features.transition2)
75
+ ),
76
+ nn.Sequential(
77
+ self.features.denseblock3, TransitionWithSkip(self.features.transition3)
78
+ ),
79
+ nn.Sequential(self.features.denseblock4, self.features.norm5),
80
  ]
81
 
82
  def forward(self, x):
 
83
  stages = self.get_stages()
84
 
85
  features = []
segmentation_models_pytorch/segmentation_models_pytorch/encoders/dpn.py CHANGED
@@ -26,9 +26,7 @@ Methods:
26
  import torch
27
  import torch.nn as nn
28
  import torch.nn.functional as F
29
-
30
- from pretrainedmodels.models.dpn import DPN
31
- from pretrainedmodels.models.dpn import pretrained_settings
32
 
33
  from ._base import EncoderMixin
34
 
@@ -46,15 +44,18 @@ class DPNEncoder(DPN, EncoderMixin):
46
  def get_stages(self):
47
  return [
48
  nn.Identity(),
49
- nn.Sequential(self.features[0].conv, self.features[0].bn, self.features[0].act),
50
- nn.Sequential(self.features[0].pool, self.features[1 : self._stage_idxs[0]]),
 
 
 
 
51
  self.features[self._stage_idxs[0] : self._stage_idxs[1]],
52
  self.features[self._stage_idxs[1] : self._stage_idxs[2]],
53
  self.features[self._stage_idxs[2] : self._stage_idxs[3]],
54
  ]
55
 
56
  def forward(self, x):
57
-
58
  stages = self.get_stages()
59
 
60
  features = []
 
26
  import torch
27
  import torch.nn as nn
28
  import torch.nn.functional as F
29
+ from pretrainedmodels.models.dpn import DPN, pretrained_settings
 
 
30
 
31
  from ._base import EncoderMixin
32
 
 
44
  def get_stages(self):
45
  return [
46
  nn.Identity(),
47
+ nn.Sequential(
48
+ self.features[0].conv, self.features[0].bn, self.features[0].act
49
+ ),
50
+ nn.Sequential(
51
+ self.features[0].pool, self.features[1 : self._stage_idxs[0]]
52
+ ),
53
  self.features[self._stage_idxs[0] : self._stage_idxs[1]],
54
  self.features[self._stage_idxs[1] : self._stage_idxs[2]],
55
  self.features[self._stage_idxs[2] : self._stage_idxs[3]],
56
  ]
57
 
58
  def forward(self, x):
 
59
  stages = self.get_stages()
60
 
61
  features = []
segmentation_models_pytorch/segmentation_models_pytorch/encoders/efficientnet.py CHANGED
@@ -24,14 +24,13 @@ Methods:
24
  """
25
  import torch.nn as nn
26
  from efficientnet_pytorch import EfficientNet
27
- from efficientnet_pytorch.utils import url_map, url_map_advprop, get_model_params
28
 
29
  from ._base import EncoderMixin
30
 
31
 
32
  class EfficientNetEncoder(EfficientNet, EncoderMixin):
33
  def __init__(self, stage_idxs, out_channels, model_name, depth=5):
34
-
35
  blocks_args, global_params = get_model_params(model_name, override_params=None)
36
  super().__init__(blocks_args, global_params)
37
 
@@ -46,21 +45,20 @@ class EfficientNetEncoder(EfficientNet, EncoderMixin):
46
  return [
47
  nn.Identity(),
48
  nn.Sequential(self._conv_stem, self._bn0, self._swish),
49
- self._blocks[:self._stage_idxs[0]],
50
- self._blocks[self._stage_idxs[0]:self._stage_idxs[1]],
51
- self._blocks[self._stage_idxs[1]:self._stage_idxs[2]],
52
- self._blocks[self._stage_idxs[2]:],
53
  ]
54
 
55
  def forward(self, x):
56
  stages = self.get_stages()
57
 
58
- block_number = 0.
59
  drop_connect_rate = self._global_params.drop_connect_rate
60
 
61
  features = []
62
  for i in range(self._depth + 1):
63
-
64
  # Identity and Sequential stages
65
  if i < 2:
66
  x = stages[i](x)
@@ -69,7 +67,7 @@ class EfficientNetEncoder(EfficientNet, EncoderMixin):
69
  else:
70
  for module in stages[i]:
71
  drop_connect = drop_connect_rate * block_number / len(self._blocks)
72
- block_number += 1.
73
  x = module(x, drop_connect)
74
 
75
  features.append(x)
@@ -97,7 +95,7 @@ def _get_pretrained_settings(encoder):
97
  "url": url_map_advprop[encoder],
98
  "input_space": "RGB",
99
  "input_range": [0, 1],
100
- }
101
  }
102
  return pretrained_settings
103
 
 
24
  """
25
  import torch.nn as nn
26
  from efficientnet_pytorch import EfficientNet
27
+ from efficientnet_pytorch.utils import get_model_params, url_map, url_map_advprop
28
 
29
  from ._base import EncoderMixin
30
 
31
 
32
  class EfficientNetEncoder(EfficientNet, EncoderMixin):
33
  def __init__(self, stage_idxs, out_channels, model_name, depth=5):
 
34
  blocks_args, global_params = get_model_params(model_name, override_params=None)
35
  super().__init__(blocks_args, global_params)
36
 
 
45
  return [
46
  nn.Identity(),
47
  nn.Sequential(self._conv_stem, self._bn0, self._swish),
48
+ self._blocks[: self._stage_idxs[0]],
49
+ self._blocks[self._stage_idxs[0] : self._stage_idxs[1]],
50
+ self._blocks[self._stage_idxs[1] : self._stage_idxs[2]],
51
+ self._blocks[self._stage_idxs[2] :],
52
  ]
53
 
54
  def forward(self, x):
55
  stages = self.get_stages()
56
 
57
+ block_number = 0.0
58
  drop_connect_rate = self._global_params.drop_connect_rate
59
 
60
  features = []
61
  for i in range(self._depth + 1):
 
62
  # Identity and Sequential stages
63
  if i < 2:
64
  x = stages[i](x)
 
67
  else:
68
  for module in stages[i]:
69
  drop_connect = drop_connect_rate * block_number / len(self._blocks)
70
+ block_number += 1.0
71
  x = module(x, drop_connect)
72
 
73
  features.append(x)
 
95
  "url": url_map_advprop[encoder],
96
  "input_space": "RGB",
97
  "input_range": [0, 1],
98
+ },
99
  }
100
  return pretrained_settings
101
 
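
The drop-connect bookkeeping in the forward pass above follows a simple linear schedule; a tiny standalone sketch of that arithmetic:

# Linear drop-connect schedule, as used in the encoder forward above:
# block i gets rate * i / num_blocks, so later blocks are dropped more often.
drop_connect_rate = 0.2
num_blocks = 16
schedule = [drop_connect_rate * i / num_blocks for i in range(num_blocks)]
print(schedule[0], schedule[-1])   # 0.0 for the first block, 0.1875 for the last
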
segmentation_models_pytorch/segmentation_models_pytorch/encoders/inceptionresnetv2.py CHANGED
@@ -24,8 +24,10 @@ Methods:
24
  """
25
 
26
  import torch.nn as nn
27
- from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2
28
- from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings
 
 
29
 
30
  from ._base import EncoderMixin
31
 
@@ -51,8 +53,10 @@ class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin):
51
  del self.last_linear
52
 
53
  def make_dilated(self, stage_list, dilation_list):
54
- raise ValueError("InceptionResnetV2 encoder does not support dilated mode "
55
- "due to pooling operation for downsampling!")
 
 
56
 
57
  def get_stages(self):
58
  return [
@@ -65,7 +69,6 @@ class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin):
65
  ]
66
 
67
  def forward(self, x):
68
-
69
  stages = self.get_stages()
70
 
71
  features = []
 
24
  """
25
 
26
  import torch.nn as nn
27
+ from pretrainedmodels.models.inceptionresnetv2 import (
28
+ InceptionResNetV2,
29
+ pretrained_settings,
30
+ )
31
 
32
  from ._base import EncoderMixin
33
 
 
53
  del self.last_linear
54
 
55
  def make_dilated(self, stage_list, dilation_list):
56
+ raise ValueError(
57
+ "InceptionResnetV2 encoder does not support dilated mode "
58
+ "due to pooling operation for downsampling!"
59
+ )
60
 
61
  def get_stages(self):
62
  return [
 
69
  ]
70
 
71
  def forward(self, x):
 
72
  stages = self.get_stages()
73
 
74
  features = []
segmentation_models_pytorch/segmentation_models_pytorch/encoders/inceptionv4.py CHANGED
@@ -24,8 +24,11 @@ Methods:
24
  """
25
 
26
  import torch.nn as nn
27
- from pretrainedmodels.models.inceptionv4 import InceptionV4, BasicConv2d
28
- from pretrainedmodels.models.inceptionv4 import pretrained_settings
 
 
 
29
 
30
  from ._base import EncoderMixin
31
 
@@ -50,21 +53,22 @@ class InceptionV4Encoder(InceptionV4, EncoderMixin):
50
  del self.last_linear
51
 
52
  def make_dilated(self, stage_list, dilation_list):
53
- raise ValueError("InceptionV4 encoder does not support dilated mode "
54
- "due to pooling operation for downsampling!")
 
 
55
 
56
  def get_stages(self):
57
  return [
58
  nn.Identity(),
59
  self.features[: self._stage_idxs[0]],
60
- self.features[self._stage_idxs[0]: self._stage_idxs[1]],
61
- self.features[self._stage_idxs[1]: self._stage_idxs[2]],
62
- self.features[self._stage_idxs[2]: self._stage_idxs[3]],
63
- self.features[self._stage_idxs[3]:],
64
  ]
65
 
66
  def forward(self, x):
67
-
68
  stages = self.get_stages()
69
 
70
  features = []
 
24
  """
25
 
26
  import torch.nn as nn
27
+ from pretrainedmodels.models.inceptionv4 import (
28
+ BasicConv2d,
29
+ InceptionV4,
30
+ pretrained_settings,
31
+ )
32
 
33
  from ._base import EncoderMixin
34
 
 
53
  del self.last_linear
54
 
55
  def make_dilated(self, stage_list, dilation_list):
56
+ raise ValueError(
57
+ "InceptionV4 encoder does not support dilated mode "
58
+ "due to pooling operation for downsampling!"
59
+ )
60
 
61
  def get_stages(self):
62
  return [
63
  nn.Identity(),
64
  self.features[: self._stage_idxs[0]],
65
+ self.features[self._stage_idxs[0] : self._stage_idxs[1]],
66
+ self.features[self._stage_idxs[1] : self._stage_idxs[2]],
67
+ self.features[self._stage_idxs[2] : self._stage_idxs[3]],
68
+ self.features[self._stage_idxs[3] :],
69
  ]
70
 
71
  def forward(self, x):
 
72
  stages = self.get_stages()
73
 
74
  features = []
segmentation_models_pytorch/segmentation_models_pytorch/encoders/mobilenet.py CHANGED
@@ -23,14 +23,13 @@ Methods:
23
  depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24
  """
25
 
26
- import torchvision
27
  import torch.nn as nn
 
28
 
29
  from ._base import EncoderMixin
30
 
31
 
32
  class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin):
33
-
34
  def __init__(self, out_channels, depth=5, **kwargs):
35
  super().__init__(**kwargs)
36
  self._depth = depth
 
23
  depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24
  """
25
 
 
26
  import torch.nn as nn
27
+ import torchvision
28
 
29
  from ._base import EncoderMixin
30
 
31
 
32
  class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin):
 
33
  def __init__(self, out_channels, depth=5, **kwargs):
34
  super().__init__(**kwargs)
35
  self._depth = depth
segmentation_models_pytorch/segmentation_models_pytorch/encoders/resnet.py CHANGED
@@ -25,11 +25,8 @@ Methods:
  from copy import deepcopy

  import torch.nn as nn
-
- from torchvision.models.resnet import ResNet
- from torchvision.models.resnet import BasicBlock
- from torchvision.models.resnet import Bottleneck
  from pretrainedmodels.models.torchvision_models import pretrained_settings
+ from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet

  from ._base import EncoderMixin

@@ -73,11 +70,11 @@ class ResNetEncoder(ResNet, EncoderMixin):
  new_settings = {
  "resnet18": {
  "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth",
- "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth"
+ "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth",
  },
  "resnet50": {
  "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth",
- "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth"
+ "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth",
  },
  "resnext50_32x4d": {
  "imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
@@ -86,7 +83,7 @@ new_settings = {
  },
  "resnext101_32x4d": {
  "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth",
- "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth"
+ "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth",
  },
  "resnext101_32x8d": {
  "imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
@@ -104,7 +101,7 @@ new_settings = {
  },
  "resnext101_32x48d": {
  "instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth",
- }
+ },
  }

  pretrained_settings = deepcopy(pretrained_settings)
@@ -115,11 +112,11 @@ for model_name, sources in new_settings.items():
  for source_name, source_url in sources.items():
  pretrained_settings[model_name][source_name] = {
  "url": source_url,
- 'input_size': [3, 224, 224],
- 'input_range': [0, 1],
- 'mean': [0.485, 0.456, 0.406],
- 'std': [0.229, 0.224, 0.225],
- 'num_classes': 1000
+ "input_size": [3, 224, 224],
+ "input_range": [0, 1],
+ "mean": [0.485, 0.456, 0.406],
+ "std": [0.229, 0.224, 0.225],
+ "num_classes": 1000,
  }

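
The resnet.py hunks above are quoting and trailing-comma changes only; the merge loop still stamps the same ImageNet normalization metadata onto every extra weight source. A standalone sketch of what that loop produces (the `pretrained_settings` stand-in and the `setdefault` guard are illustrative, not the package's actual table or code):

from copy import deepcopy

# stand-in for the table imported from pretrainedmodels; only its shape matters here
pretrained_settings = {"resnet18": {"imagenet": {"url": "..."}}}

new_settings = {
    "resnet18": {
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth",
    }
}

pretrained_settings = deepcopy(pretrained_settings)
for model_name, sources in new_settings.items():
    pretrained_settings.setdefault(model_name, {})  # guard added so the sketch runs standalone
    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }

print(sorted(pretrained_settings["resnet18"]))  # -> ['imagenet', 'ssl']
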
 
segmentation_models_pytorch/segmentation_models_pytorch/encoders/senet.py CHANGED
@@ -24,14 +24,14 @@ Methods:
  """

  import torch.nn as nn
-
  from pretrainedmodels.models.senet import (
- SENet,
  SEBottleneck,
+ SENet,
  SEResNetBottleneck,
  SEResNeXtBottleneck,
  pretrained_settings,
  )
+
  from ._base import EncoderMixin


segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_regnet.py CHANGED
@@ -1,6 +1,7 @@
1
- from ._base import EncoderMixin
2
- from timm.models.regnet import RegNet
3
  import torch.nn as nn
 
 
 
4
 
5
 
6
  class RegNetEncoder(RegNet, EncoderMixin):
@@ -39,78 +40,78 @@ class RegNetEncoder(RegNet, EncoderMixin):
39
 
40
 
41
  regnet_weights = {
42
- 'timm-regnetx_002': {
43
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth',
 
 
 
44
  },
45
- 'timm-regnetx_004': {
46
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth',
47
  },
48
- 'timm-regnetx_006': {
49
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth',
50
  },
51
- 'timm-regnetx_008': {
52
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth',
53
  },
54
- 'timm-regnetx_016': {
55
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth',
56
  },
57
- 'timm-regnetx_032': {
58
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth',
59
  },
60
- 'timm-regnetx_040': {
61
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth',
62
  },
63
- 'timm-regnetx_064': {
64
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth',
65
  },
66
- 'timm-regnetx_080': {
67
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth',
68
  },
69
- 'timm-regnetx_120': {
70
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth',
71
  },
72
- 'timm-regnetx_160': {
73
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth',
74
  },
75
- 'timm-regnetx_320': {
76
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth',
77
  },
78
- 'timm-regnety_002': {
79
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth',
80
  },
81
- 'timm-regnety_004': {
82
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth',
83
  },
84
- 'timm-regnety_006': {
85
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth',
86
  },
87
- 'timm-regnety_008': {
88
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth',
89
  },
90
- 'timm-regnety_016': {
91
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth',
92
  },
93
- 'timm-regnety_032': {
94
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'
95
  },
96
- 'timm-regnety_040': {
97
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'
98
  },
99
- 'timm-regnety_064': {
100
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'
101
  },
102
- 'timm-regnety_080': {
103
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth',
104
  },
105
- 'timm-regnety_120': {
106
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth',
107
  },
108
- 'timm-regnety_160': {
109
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth',
110
  },
111
- 'timm-regnety_320': {
112
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'
113
- }
114
  }
115
 
116
  pretrained_settings = {}
@@ -119,214 +120,224 @@ for model_name, sources in regnet_weights.items():
119
  for source_name, source_url in sources.items():
120
  pretrained_settings[model_name][source_name] = {
121
  "url": source_url,
122
- 'input_size': [3, 224, 224],
123
- 'input_range': [0, 1],
124
- 'mean': [0.485, 0.456, 0.406],
125
- 'std': [0.229, 0.224, 0.225],
126
- 'num_classes': 1000
127
  }
128
 
129
  # at this point I am too lazy to copy configs, so I just used the same configs from timm's repo
130
 
131
 
132
  def _mcfg(**kwargs):
133
- cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32)
134
  cfg.update(**kwargs)
135
  return cfg
136
 
137
 
138
  timm_regnet_encoders = {
139
- 'timm-regnetx_002': {
140
- 'encoder': RegNetEncoder,
141
  "pretrained_settings": pretrained_settings["timm-regnetx_002"],
142
- 'params': {
143
- 'out_channels': (3, 32, 24, 56, 152, 368),
144
- 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13)
145
  },
146
  },
147
- 'timm-regnetx_004': {
148
- 'encoder': RegNetEncoder,
149
  "pretrained_settings": pretrained_settings["timm-regnetx_004"],
150
- 'params': {
151
- 'out_channels': (3, 32, 32, 64, 160, 384),
152
- 'cfg': _mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22)
153
  },
154
  },
155
- 'timm-regnetx_006': {
156
- 'encoder': RegNetEncoder,
157
  "pretrained_settings": pretrained_settings["timm-regnetx_006"],
158
- 'params': {
159
- 'out_channels': (3, 32, 48, 96, 240, 528),
160
- 'cfg': _mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16)
161
  },
162
  },
163
- 'timm-regnetx_008': {
164
- 'encoder': RegNetEncoder,
165
  "pretrained_settings": pretrained_settings["timm-regnetx_008"],
166
- 'params': {
167
- 'out_channels': (3, 32, 64, 128, 288, 672),
168
- 'cfg': _mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16)
169
  },
170
  },
171
- 'timm-regnetx_016': {
172
- 'encoder': RegNetEncoder,
173
  "pretrained_settings": pretrained_settings["timm-regnetx_016"],
174
- 'params': {
175
- 'out_channels': (3, 32, 72, 168, 408, 912),
176
- 'cfg': _mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18)
177
  },
178
  },
179
- 'timm-regnetx_032': {
180
- 'encoder': RegNetEncoder,
181
  "pretrained_settings": pretrained_settings["timm-regnetx_032"],
182
- 'params': {
183
- 'out_channels': (3, 32, 96, 192, 432, 1008),
184
- 'cfg': _mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25)
185
  },
186
  },
187
- 'timm-regnetx_040': {
188
- 'encoder': RegNetEncoder,
189
  "pretrained_settings": pretrained_settings["timm-regnetx_040"],
190
- 'params': {
191
- 'out_channels': (3, 32, 80, 240, 560, 1360),
192
- 'cfg': _mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23)
193
  },
194
  },
195
- 'timm-regnetx_064': {
196
- 'encoder': RegNetEncoder,
197
  "pretrained_settings": pretrained_settings["timm-regnetx_064"],
198
- 'params': {
199
- 'out_channels': (3, 32, 168, 392, 784, 1624),
200
- 'cfg': _mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17)
201
  },
202
  },
203
- 'timm-regnetx_080': {
204
- 'encoder': RegNetEncoder,
205
  "pretrained_settings": pretrained_settings["timm-regnetx_080"],
206
- 'params': {
207
- 'out_channels': (3, 32, 80, 240, 720, 1920),
208
- 'cfg': _mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23)
209
  },
210
  },
211
- 'timm-regnetx_120': {
212
- 'encoder': RegNetEncoder,
213
  "pretrained_settings": pretrained_settings["timm-regnetx_120"],
214
- 'params': {
215
- 'out_channels': (3, 32, 224, 448, 896, 2240),
216
- 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19)
217
  },
218
  },
219
- 'timm-regnetx_160': {
220
- 'encoder': RegNetEncoder,
221
  "pretrained_settings": pretrained_settings["timm-regnetx_160"],
222
- 'params': {
223
- 'out_channels': (3, 32, 256, 512, 896, 2048),
224
- 'cfg': _mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22)
225
  },
226
  },
227
- 'timm-regnetx_320': {
228
- 'encoder': RegNetEncoder,
229
  "pretrained_settings": pretrained_settings["timm-regnetx_320"],
230
- 'params': {
231
- 'out_channels': (3, 32, 336, 672, 1344, 2520),
232
- 'cfg': _mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23)
233
  },
234
  },
235
- #regnety
236
- 'timm-regnety_002': {
237
- 'encoder': RegNetEncoder,
238
  "pretrained_settings": pretrained_settings["timm-regnety_002"],
239
- 'params': {
240
- 'out_channels': (3, 32, 24, 56, 152, 368),
241
- 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25)
242
  },
243
  },
244
- 'timm-regnety_004': {
245
- 'encoder': RegNetEncoder,
246
  "pretrained_settings": pretrained_settings["timm-regnety_004"],
247
- 'params': {
248
- 'out_channels': (3, 32, 48, 104, 208, 440),
249
- 'cfg': _mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25)
250
  },
251
  },
252
- 'timm-regnety_006': {
253
- 'encoder': RegNetEncoder,
254
  "pretrained_settings": pretrained_settings["timm-regnety_006"],
255
- 'params': {
256
- 'out_channels': (3, 32, 48, 112, 256, 608),
257
- 'cfg': _mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25)
258
  },
259
  },
260
- 'timm-regnety_008': {
261
- 'encoder': RegNetEncoder,
262
  "pretrained_settings": pretrained_settings["timm-regnety_008"],
263
- 'params': {
264
- 'out_channels': (3, 32, 64, 128, 320, 768),
265
- 'cfg': _mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25)
266
  },
267
  },
268
- 'timm-regnety_016': {
269
- 'encoder': RegNetEncoder,
270
  "pretrained_settings": pretrained_settings["timm-regnety_016"],
271
- 'params': {
272
- 'out_channels': (3, 32, 48, 120, 336, 888),
273
- 'cfg': _mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25)
274
  },
275
  },
276
- 'timm-regnety_032': {
277
- 'encoder': RegNetEncoder,
278
  "pretrained_settings": pretrained_settings["timm-regnety_032"],
279
- 'params': {
280
- 'out_channels': (3, 32, 72, 216, 576, 1512),
281
- 'cfg': _mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25)
282
  },
283
  },
284
- 'timm-regnety_040': {
285
- 'encoder': RegNetEncoder,
286
  "pretrained_settings": pretrained_settings["timm-regnety_040"],
287
- 'params': {
288
- 'out_channels': (3, 32, 128, 192, 512, 1088),
289
- 'cfg': _mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25)
290
  },
291
  },
292
- 'timm-regnety_064': {
293
- 'encoder': RegNetEncoder,
294
  "pretrained_settings": pretrained_settings["timm-regnety_064"],
295
- 'params': {
296
- 'out_channels': (3, 32, 144, 288, 576, 1296),
297
- 'cfg': _mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25)
 
 
298
  },
299
  },
300
- 'timm-regnety_080': {
301
- 'encoder': RegNetEncoder,
302
  "pretrained_settings": pretrained_settings["timm-regnety_080"],
303
- 'params': {
304
- 'out_channels': (3, 32, 168, 448, 896, 2016),
305
- 'cfg': _mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25)
 
 
306
  },
307
  },
308
- 'timm-regnety_120': {
309
- 'encoder': RegNetEncoder,
310
  "pretrained_settings": pretrained_settings["timm-regnety_120"],
311
- 'params': {
312
- 'out_channels': (3, 32, 224, 448, 896, 2240),
313
- 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25)
 
 
314
  },
315
  },
316
- 'timm-regnety_160': {
317
- 'encoder': RegNetEncoder,
318
  "pretrained_settings": pretrained_settings["timm-regnety_160"],
319
- 'params': {
320
- 'out_channels': (3, 32, 224, 448, 1232, 3024),
321
- 'cfg': _mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25)
 
 
322
  },
323
  },
324
- 'timm-regnety_320': {
325
- 'encoder': RegNetEncoder,
326
  "pretrained_settings": pretrained_settings["timm-regnety_320"],
327
- 'params': {
328
- 'out_channels': (3, 32, 232, 696, 1392, 3712),
329
- 'cfg': _mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25)
 
 
330
  },
331
  },
332
  }
 
 
 
1
  import torch.nn as nn
2
+ from timm.models.regnet import RegNet
3
+
4
+ from ._base import EncoderMixin
5
 
6
 
7
  class RegNetEncoder(RegNet, EncoderMixin):
 
40
 
41
 
42
  regnet_weights = {
43
+ "timm-regnetx_002": {
44
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth",
45
+ },
46
+ "timm-regnetx_004": {
47
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth",
48
  },
49
+ "timm-regnetx_006": {
50
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth",
51
  },
52
+ "timm-regnetx_008": {
53
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth",
54
  },
55
+ "timm-regnetx_016": {
56
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth",
57
  },
58
+ "timm-regnetx_032": {
59
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth",
60
  },
61
+ "timm-regnetx_040": {
62
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth",
63
  },
64
+ "timm-regnetx_064": {
65
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth",
66
  },
67
+ "timm-regnetx_080": {
68
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth",
69
  },
70
+ "timm-regnetx_120": {
71
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth",
72
  },
73
+ "timm-regnetx_160": {
74
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth",
75
  },
76
+ "timm-regnetx_320": {
77
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth",
78
  },
79
+ "timm-regnety_002": {
80
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth",
81
  },
82
+ "timm-regnety_004": {
83
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth",
84
  },
85
+ "timm-regnety_006": {
86
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth",
87
  },
88
+ "timm-regnety_008": {
89
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth",
90
  },
91
+ "timm-regnety_016": {
92
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth",
93
  },
94
+ "timm-regnety_032": {
95
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth"
96
  },
97
+ "timm-regnety_040": {
98
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth"
99
  },
100
+ "timm-regnety_064": {
101
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth"
102
  },
103
+ "timm-regnety_080": {
104
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth",
105
  },
106
+ "timm-regnety_120": {
107
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth",
108
  },
109
+ "timm-regnety_160": {
110
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth",
111
  },
112
+ "timm-regnety_320": {
113
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth"
114
  },
 
 
 
115
  }
116
 
117
  pretrained_settings = {}
 
120
  for source_name, source_url in sources.items():
121
  pretrained_settings[model_name][source_name] = {
122
  "url": source_url,
123
+ "input_size": [3, 224, 224],
124
+ "input_range": [0, 1],
125
+ "mean": [0.485, 0.456, 0.406],
126
+ "std": [0.229, 0.224, 0.225],
127
+ "num_classes": 1000,
128
  }
129
 
130
  # at this point I am too lazy to copy configs, so I just used the same configs from timm's repo
131
 
132
 
133
  def _mcfg(**kwargs):
134
+ cfg = dict(se_ratio=0.0, bottle_ratio=1.0, stem_width=32)
135
  cfg.update(**kwargs)
136
  return cfg
137
 
138
 
139
  timm_regnet_encoders = {
140
+ "timm-regnetx_002": {
141
+ "encoder": RegNetEncoder,
142
  "pretrained_settings": pretrained_settings["timm-regnetx_002"],
143
+ "params": {
144
+ "out_channels": (3, 32, 24, 56, 152, 368),
145
+ "cfg": _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13),
146
  },
147
  },
148
+ "timm-regnetx_004": {
149
+ "encoder": RegNetEncoder,
150
  "pretrained_settings": pretrained_settings["timm-regnetx_004"],
151
+ "params": {
152
+ "out_channels": (3, 32, 32, 64, 160, 384),
153
+ "cfg": _mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22),
154
  },
155
  },
156
+ "timm-regnetx_006": {
157
+ "encoder": RegNetEncoder,
158
  "pretrained_settings": pretrained_settings["timm-regnetx_006"],
159
+ "params": {
160
+ "out_channels": (3, 32, 48, 96, 240, 528),
161
+ "cfg": _mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16),
162
  },
163
  },
164
+ "timm-regnetx_008": {
165
+ "encoder": RegNetEncoder,
166
  "pretrained_settings": pretrained_settings["timm-regnetx_008"],
167
+ "params": {
168
+ "out_channels": (3, 32, 64, 128, 288, 672),
169
+ "cfg": _mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16),
170
  },
171
  },
172
+ "timm-regnetx_016": {
173
+ "encoder": RegNetEncoder,
174
  "pretrained_settings": pretrained_settings["timm-regnetx_016"],
175
+ "params": {
176
+ "out_channels": (3, 32, 72, 168, 408, 912),
177
+ "cfg": _mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18),
178
  },
179
  },
180
+ "timm-regnetx_032": {
181
+ "encoder": RegNetEncoder,
182
  "pretrained_settings": pretrained_settings["timm-regnetx_032"],
183
+ "params": {
184
+ "out_channels": (3, 32, 96, 192, 432, 1008),
185
+ "cfg": _mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25),
186
  },
187
  },
188
+ "timm-regnetx_040": {
189
+ "encoder": RegNetEncoder,
190
  "pretrained_settings": pretrained_settings["timm-regnetx_040"],
191
+ "params": {
192
+ "out_channels": (3, 32, 80, 240, 560, 1360),
193
+ "cfg": _mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23),
194
  },
195
  },
196
+ "timm-regnetx_064": {
197
+ "encoder": RegNetEncoder,
198
  "pretrained_settings": pretrained_settings["timm-regnetx_064"],
199
+ "params": {
200
+ "out_channels": (3, 32, 168, 392, 784, 1624),
201
+ "cfg": _mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17),
202
  },
203
  },
204
+ "timm-regnetx_080": {
205
+ "encoder": RegNetEncoder,
206
  "pretrained_settings": pretrained_settings["timm-regnetx_080"],
207
+ "params": {
208
+ "out_channels": (3, 32, 80, 240, 720, 1920),
209
+ "cfg": _mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23),
210
  },
211
  },
212
+ "timm-regnetx_120": {
213
+ "encoder": RegNetEncoder,
214
  "pretrained_settings": pretrained_settings["timm-regnetx_120"],
215
+ "params": {
216
+ "out_channels": (3, 32, 224, 448, 896, 2240),
217
+ "cfg": _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19),
218
  },
219
  },
220
+ "timm-regnetx_160": {
221
+ "encoder": RegNetEncoder,
222
  "pretrained_settings": pretrained_settings["timm-regnetx_160"],
223
+ "params": {
224
+ "out_channels": (3, 32, 256, 512, 896, 2048),
225
+ "cfg": _mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22),
226
  },
227
  },
228
+ "timm-regnetx_320": {
229
+ "encoder": RegNetEncoder,
230
  "pretrained_settings": pretrained_settings["timm-regnetx_320"],
231
+ "params": {
232
+ "out_channels": (3, 32, 336, 672, 1344, 2520),
233
+ "cfg": _mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23),
234
  },
235
  },
236
+ # regnety
237
+ "timm-regnety_002": {
238
+ "encoder": RegNetEncoder,
239
  "pretrained_settings": pretrained_settings["timm-regnety_002"],
240
+ "params": {
241
+ "out_channels": (3, 32, 24, 56, 152, 368),
242
+ "cfg": _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25),
243
  },
244
  },
245
+ "timm-regnety_004": {
246
+ "encoder": RegNetEncoder,
247
  "pretrained_settings": pretrained_settings["timm-regnety_004"],
248
+ "params": {
249
+ "out_channels": (3, 32, 48, 104, 208, 440),
250
+ "cfg": _mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25),
251
  },
252
  },
253
+ "timm-regnety_006": {
254
+ "encoder": RegNetEncoder,
255
  "pretrained_settings": pretrained_settings["timm-regnety_006"],
256
+ "params": {
257
+ "out_channels": (3, 32, 48, 112, 256, 608),
258
+ "cfg": _mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25),
259
  },
260
  },
261
+ "timm-regnety_008": {
262
+ "encoder": RegNetEncoder,
263
  "pretrained_settings": pretrained_settings["timm-regnety_008"],
264
+ "params": {
265
+ "out_channels": (3, 32, 64, 128, 320, 768),
266
+ "cfg": _mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25),
267
  },
268
  },
269
+ "timm-regnety_016": {
270
+ "encoder": RegNetEncoder,
271
  "pretrained_settings": pretrained_settings["timm-regnety_016"],
272
+ "params": {
273
+ "out_channels": (3, 32, 48, 120, 336, 888),
274
+ "cfg": _mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25),
275
  },
276
  },
277
+ "timm-regnety_032": {
278
+ "encoder": RegNetEncoder,
279
  "pretrained_settings": pretrained_settings["timm-regnety_032"],
280
+ "params": {
281
+ "out_channels": (3, 32, 72, 216, 576, 1512),
282
+ "cfg": _mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25),
283
  },
284
  },
285
+ "timm-regnety_040": {
286
+ "encoder": RegNetEncoder,
287
  "pretrained_settings": pretrained_settings["timm-regnety_040"],
288
+ "params": {
289
+ "out_channels": (3, 32, 128, 192, 512, 1088),
290
+ "cfg": _mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25),
291
  },
292
  },
293
+ "timm-regnety_064": {
294
+ "encoder": RegNetEncoder,
295
  "pretrained_settings": pretrained_settings["timm-regnety_064"],
296
+ "params": {
297
+ "out_channels": (3, 32, 144, 288, 576, 1296),
298
+ "cfg": _mcfg(
299
+ w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25
300
+ ),
301
  },
302
  },
303
+ "timm-regnety_080": {
304
+ "encoder": RegNetEncoder,
305
  "pretrained_settings": pretrained_settings["timm-regnety_080"],
306
+ "params": {
307
+ "out_channels": (3, 32, 168, 448, 896, 2016),
308
+ "cfg": _mcfg(
309
+ w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25
310
+ ),
311
  },
312
  },
313
+ "timm-regnety_120": {
314
+ "encoder": RegNetEncoder,
315
  "pretrained_settings": pretrained_settings["timm-regnety_120"],
316
+ "params": {
317
+ "out_channels": (3, 32, 224, 448, 896, 2240),
318
+ "cfg": _mcfg(
319
+ w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25
320
+ ),
321
  },
322
  },
323
+ "timm-regnety_160": {
324
+ "encoder": RegNetEncoder,
325
  "pretrained_settings": pretrained_settings["timm-regnety_160"],
326
+ "params": {
327
+ "out_channels": (3, 32, 224, 448, 1232, 3024),
328
+ "cfg": _mcfg(
329
+ w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25
330
+ ),
331
  },
332
  },
333
+ "timm-regnety_320": {
334
+ "encoder": RegNetEncoder,
335
  "pretrained_settings": pretrained_settings["timm-regnety_320"],
336
+ "params": {
337
+ "out_channels": (3, 32, 232, 696, 1392, 3712),
338
+ "cfg": _mcfg(
339
+ w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25
340
+ ),
341
  },
342
  },
343
  }
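
Only quoting and line-wrapping changes in timm_regnet.py as well; `_mcfg` still just overlays per-variant overrides on shared RegNet defaults. A standalone check, reusing the function body and the `timm-regnety_002` values shown in the diff:

def _mcfg(**kwargs):
    # shared RegNet defaults, overridden per variant
    cfg = dict(se_ratio=0.0, bottle_ratio=1.0, stem_width=32)
    cfg.update(**kwargs)
    return cfg

cfg = _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25)
print(cfg)
# {'se_ratio': 0.25, 'bottle_ratio': 1.0, 'stem_width': 32, 'w0': 24,
#  'wa': 36.44, 'wm': 2.49, 'group_w': 8, 'depth': 13}
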
segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_res2net.py CHANGED
@@ -1,7 +1,8 @@
1
- from ._base import EncoderMixin
2
- from timm.models.resnet import ResNet
3
- from timm.models.res2net import Bottle2neck
4
  import torch.nn as nn
 
 
 
 
5
 
6
 
7
  class Res2NetEncoder(ResNet, EncoderMixin):
@@ -44,27 +45,27 @@ class Res2NetEncoder(ResNet, EncoderMixin):
44
 
45
 
46
  res2net_weights = {
47
- 'timm-res2net50_26w_4s': {
48
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth'
49
  },
50
- 'timm-res2net50_48w_2s': {
51
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth'
52
  },
53
- 'timm-res2net50_14w_8s': {
54
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth',
55
  },
56
- 'timm-res2net50_26w_6s': {
57
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth',
58
  },
59
- 'timm-res2net50_26w_8s': {
60
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth',
61
  },
62
- 'timm-res2net101_26w_4s': {
63
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth',
 
 
 
64
  },
65
- 'timm-res2next50': {
66
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth',
67
- }
68
  }
69
 
70
  pretrained_settings = {}
@@ -73,91 +74,91 @@ for model_name, sources in res2net_weights.items():
73
  for source_name, source_url in sources.items():
74
  pretrained_settings[model_name][source_name] = {
75
  "url": source_url,
76
- 'input_size': [3, 224, 224],
77
- 'input_range': [0, 1],
78
- 'mean': [0.485, 0.456, 0.406],
79
- 'std': [0.229, 0.224, 0.225],
80
- 'num_classes': 1000
81
  }
82
 
83
 
84
  timm_res2net_encoders = {
85
- 'timm-res2net50_26w_4s': {
86
- 'encoder': Res2NetEncoder,
87
  "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"],
88
- 'params': {
89
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
90
- 'block': Bottle2neck,
91
- 'layers': [3, 4, 6, 3],
92
- 'base_width': 26,
93
- 'block_args': {'scale': 4}
94
  },
95
  },
96
- 'timm-res2net101_26w_4s': {
97
- 'encoder': Res2NetEncoder,
98
  "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"],
99
- 'params': {
100
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
101
- 'block': Bottle2neck,
102
- 'layers': [3, 4, 23, 3],
103
- 'base_width': 26,
104
- 'block_args': {'scale': 4}
105
  },
106
  },
107
- 'timm-res2net50_26w_6s': {
108
- 'encoder': Res2NetEncoder,
109
  "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"],
110
- 'params': {
111
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
112
- 'block': Bottle2neck,
113
- 'layers': [3, 4, 6, 3],
114
- 'base_width': 26,
115
- 'block_args': {'scale': 6}
116
  },
117
  },
118
- 'timm-res2net50_26w_8s': {
119
- 'encoder': Res2NetEncoder,
120
  "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"],
121
- 'params': {
122
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
123
- 'block': Bottle2neck,
124
- 'layers': [3, 4, 6, 3],
125
- 'base_width': 26,
126
- 'block_args': {'scale': 8}
127
  },
128
  },
129
- 'timm-res2net50_48w_2s': {
130
- 'encoder': Res2NetEncoder,
131
  "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"],
132
- 'params': {
133
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
134
- 'block': Bottle2neck,
135
- 'layers': [3, 4, 6, 3],
136
- 'base_width': 48,
137
- 'block_args': {'scale': 2}
138
  },
139
  },
140
- 'timm-res2net50_14w_8s': {
141
- 'encoder': Res2NetEncoder,
142
  "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"],
143
- 'params': {
144
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
145
- 'block': Bottle2neck,
146
- 'layers': [3, 4, 6, 3],
147
- 'base_width': 14,
148
- 'block_args': {'scale': 8}
149
  },
150
  },
151
- 'timm-res2next50': {
152
- 'encoder': Res2NetEncoder,
153
  "pretrained_settings": pretrained_settings["timm-res2next50"],
154
- 'params': {
155
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
156
- 'block': Bottle2neck,
157
- 'layers': [3, 4, 6, 3],
158
- 'base_width': 4,
159
- 'cardinality': 8,
160
- 'block_args': {'scale': 4}
161
  },
162
- }
163
  }
 
 
 
 
1
  import torch.nn as nn
2
+ from timm.models.res2net import Bottle2neck
3
+ from timm.models.resnet import ResNet
4
+
5
+ from ._base import EncoderMixin
6
 
7
 
8
  class Res2NetEncoder(ResNet, EncoderMixin):
 
45
 
46
 
47
  res2net_weights = {
48
+ "timm-res2net50_26w_4s": {
49
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth"
50
  },
51
+ "timm-res2net50_48w_2s": {
52
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth"
53
  },
54
+ "timm-res2net50_14w_8s": {
55
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth",
56
  },
57
+ "timm-res2net50_26w_6s": {
58
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth",
59
  },
60
+ "timm-res2net50_26w_8s": {
61
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth",
62
  },
63
+ "timm-res2net101_26w_4s": {
64
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth",
65
+ },
66
+ "timm-res2next50": {
67
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth",
68
  },
 
 
 
69
  }
70
 
71
  pretrained_settings = {}
 
74
  for source_name, source_url in sources.items():
75
  pretrained_settings[model_name][source_name] = {
76
  "url": source_url,
77
+ "input_size": [3, 224, 224],
78
+ "input_range": [0, 1],
79
+ "mean": [0.485, 0.456, 0.406],
80
+ "std": [0.229, 0.224, 0.225],
81
+ "num_classes": 1000,
82
  }
83
 
84
 
85
  timm_res2net_encoders = {
86
+ "timm-res2net50_26w_4s": {
87
+ "encoder": Res2NetEncoder,
88
  "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"],
89
+ "params": {
90
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
91
+ "block": Bottle2neck,
92
+ "layers": [3, 4, 6, 3],
93
+ "base_width": 26,
94
+ "block_args": {"scale": 4},
95
  },
96
  },
97
+ "timm-res2net101_26w_4s": {
98
+ "encoder": Res2NetEncoder,
99
  "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"],
100
+ "params": {
101
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
102
+ "block": Bottle2neck,
103
+ "layers": [3, 4, 23, 3],
104
+ "base_width": 26,
105
+ "block_args": {"scale": 4},
106
  },
107
  },
108
+ "timm-res2net50_26w_6s": {
109
+ "encoder": Res2NetEncoder,
110
  "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"],
111
+ "params": {
112
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
113
+ "block": Bottle2neck,
114
+ "layers": [3, 4, 6, 3],
115
+ "base_width": 26,
116
+ "block_args": {"scale": 6},
117
  },
118
  },
119
+ "timm-res2net50_26w_8s": {
120
+ "encoder": Res2NetEncoder,
121
  "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"],
122
+ "params": {
123
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
124
+ "block": Bottle2neck,
125
+ "layers": [3, 4, 6, 3],
126
+ "base_width": 26,
127
+ "block_args": {"scale": 8},
128
  },
129
  },
130
+ "timm-res2net50_48w_2s": {
131
+ "encoder": Res2NetEncoder,
132
  "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"],
133
+ "params": {
134
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
135
+ "block": Bottle2neck,
136
+ "layers": [3, 4, 6, 3],
137
+ "base_width": 48,
138
+ "block_args": {"scale": 2},
139
  },
140
  },
141
+ "timm-res2net50_14w_8s": {
142
+ "encoder": Res2NetEncoder,
143
  "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"],
144
+ "params": {
145
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
146
+ "block": Bottle2neck,
147
+ "layers": [3, 4, 6, 3],
148
+ "base_width": 14,
149
+ "block_args": {"scale": 8},
150
  },
151
  },
152
+ "timm-res2next50": {
153
+ "encoder": Res2NetEncoder,
154
  "pretrained_settings": pretrained_settings["timm-res2next50"],
155
+ "params": {
156
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
157
+ "block": Bottle2neck,
158
+ "layers": [3, 4, 6, 3],
159
+ "base_width": 4,
160
+ "cardinality": 8,
161
+ "block_args": {"scale": 4},
162
  },
163
+ },
164
  }
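
The registry dicts reformatted above (here `timm_res2net_encoders`, and likewise for the other timm families) all share the same shape: an encoder class, its pretrained settings, and constructor params. A hypothetical sketch of how such an entry might be consumed; the library's real factory is not part of this diff, so the toy registry and `build_encoder` below are assumptions for illustration only:

import torch.nn as nn


class DummyEncoder(nn.Module):  # stand-in for Res2NetEncoder and friends
    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__()
        self.out_channels = out_channels
        self.depth = depth


toy_registry = {
    "toy-encoder": {
        "encoder": DummyEncoder,
        "pretrained_settings": {"imagenet": {"url": "https://example.invalid/w.pth"}},
        "params": {"out_channels": (3, 64, 256, 512, 1024, 2048)},
    }
}


def build_encoder(name, registry, depth=5):
    entry = registry[name]
    params = dict(entry["params"])  # out_channels, block, layers, ...
    params["depth"] = depth
    return entry["encoder"](**params), entry["pretrained_settings"]


encoder, settings = build_encoder("toy-encoder", toy_registry)
print(encoder.out_channels, settings["imagenet"]["url"])
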
segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_resnest.py CHANGED
@@ -1,7 +1,8 @@
1
- from ._base import EncoderMixin
2
- from timm.models.resnet import ResNet
3
- from timm.models.resnest import ResNestBottleneck
4
  import torch.nn as nn
 
 
 
 
5
 
6
 
7
  class ResNestEncoder(ResNet, EncoderMixin):
@@ -44,30 +45,30 @@ class ResNestEncoder(ResNet, EncoderMixin):
44
 
45
 
46
  resnest_weights = {
47
- 'timm-resnest14d': {
48
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'
 
 
 
49
  },
50
- 'timm-resnest26d': {
51
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'
52
  },
53
- 'timm-resnest50d': {
54
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth',
55
  },
56
- 'timm-resnest101e': {
57
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth',
58
  },
59
- 'timm-resnest200e': {
60
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth',
61
  },
62
- 'timm-resnest269e': {
63
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth',
64
  },
65
- 'timm-resnest50d_4s2x40d': {
66
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth',
67
  },
68
- 'timm-resnest50d_1s4x24d': {
69
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth',
70
- }
71
  }
72
 
73
  pretrained_settings = {}
@@ -76,133 +77,133 @@ for model_name, sources in resnest_weights.items():
76
  for source_name, source_url in sources.items():
77
  pretrained_settings[model_name][source_name] = {
78
  "url": source_url,
79
- 'input_size': [3, 224, 224],
80
- 'input_range': [0, 1],
81
- 'mean': [0.485, 0.456, 0.406],
82
- 'std': [0.229, 0.224, 0.225],
83
- 'num_classes': 1000
84
  }
85
 
86
 
87
  timm_resnest_encoders = {
88
- 'timm-resnest14d': {
89
- 'encoder': ResNestEncoder,
90
  "pretrained_settings": pretrained_settings["timm-resnest14d"],
91
- 'params': {
92
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
93
- 'block': ResNestBottleneck,
94
- 'layers': [1, 1, 1, 1],
95
- 'stem_type': 'deep',
96
- 'stem_width': 32,
97
- 'avg_down': True,
98
- 'base_width': 64,
99
- 'cardinality': 1,
100
- 'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
101
- }
102
  },
103
- 'timm-resnest26d': {
104
- 'encoder': ResNestEncoder,
105
  "pretrained_settings": pretrained_settings["timm-resnest26d"],
106
- 'params': {
107
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
108
- 'block': ResNestBottleneck,
109
- 'layers': [2, 2, 2, 2],
110
- 'stem_type': 'deep',
111
- 'stem_width': 32,
112
- 'avg_down': True,
113
- 'base_width': 64,
114
- 'cardinality': 1,
115
- 'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
116
- }
117
  },
118
- 'timm-resnest50d': {
119
- 'encoder': ResNestEncoder,
120
  "pretrained_settings": pretrained_settings["timm-resnest50d"],
121
- 'params': {
122
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
123
- 'block': ResNestBottleneck,
124
- 'layers': [3, 4, 6, 3],
125
- 'stem_type': 'deep',
126
- 'stem_width': 32,
127
- 'avg_down': True,
128
- 'base_width': 64,
129
- 'cardinality': 1,
130
- 'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
131
- }
132
  },
133
- 'timm-resnest101e': {
134
- 'encoder': ResNestEncoder,
135
  "pretrained_settings": pretrained_settings["timm-resnest101e"],
136
- 'params': {
137
- 'out_channels': (3, 128, 256, 512, 1024, 2048),
138
- 'block': ResNestBottleneck,
139
- 'layers': [3, 4, 23, 3],
140
- 'stem_type': 'deep',
141
- 'stem_width': 64,
142
- 'avg_down': True,
143
- 'base_width': 64,
144
- 'cardinality': 1,
145
- 'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
146
- }
147
  },
148
- 'timm-resnest200e': {
149
- 'encoder': ResNestEncoder,
150
  "pretrained_settings": pretrained_settings["timm-resnest200e"],
151
- 'params': {
152
- 'out_channels': (3, 128, 256, 512, 1024, 2048),
153
- 'block': ResNestBottleneck,
154
- 'layers': [3, 24, 36, 3],
155
- 'stem_type': 'deep',
156
- 'stem_width': 64,
157
- 'avg_down': True,
158
- 'base_width': 64,
159
- 'cardinality': 1,
160
- 'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
161
- }
162
  },
163
- 'timm-resnest269e': {
164
- 'encoder': ResNestEncoder,
165
  "pretrained_settings": pretrained_settings["timm-resnest269e"],
166
- 'params': {
167
- 'out_channels': (3, 128, 256, 512, 1024, 2048),
168
- 'block': ResNestBottleneck,
169
- 'layers': [3, 30, 48, 8],
170
- 'stem_type': 'deep',
171
- 'stem_width': 64,
172
- 'avg_down': True,
173
- 'base_width': 64,
174
- 'cardinality': 1,
175
- 'block_args': {'radix': 2, 'avd': True, 'avd_first': False}
176
  },
177
  },
178
- 'timm-resnest50d_4s2x40d': {
179
- 'encoder': ResNestEncoder,
180
  "pretrained_settings": pretrained_settings["timm-resnest50d_4s2x40d"],
181
- 'params': {
182
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
183
- 'block': ResNestBottleneck,
184
- 'layers': [3, 4, 6, 3],
185
- 'stem_type': 'deep',
186
- 'stem_width': 32,
187
- 'avg_down': True,
188
- 'base_width': 40,
189
- 'cardinality': 2,
190
- 'block_args': {'radix': 4, 'avd': True, 'avd_first': True}
191
- }
192
  },
193
- 'timm-resnest50d_1s4x24d': {
194
- 'encoder': ResNestEncoder,
195
  "pretrained_settings": pretrained_settings["timm-resnest50d_1s4x24d"],
196
- 'params': {
197
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
198
- 'block': ResNestBottleneck,
199
- 'layers': [3, 4, 6, 3],
200
- 'stem_type': 'deep',
201
- 'stem_width': 32,
202
- 'avg_down': True,
203
- 'base_width': 24,
204
- 'cardinality': 4,
205
- 'block_args': {'radix': 1, 'avd': True, 'avd_first': True}
206
- }
207
- }
208
  }
 
 
 
 
1
  import torch.nn as nn
2
+ from timm.models.resnest import ResNestBottleneck
3
+ from timm.models.resnet import ResNet
4
+
5
+ from ._base import EncoderMixin
6
 
7
 
8
  class ResNestEncoder(ResNet, EncoderMixin):
 
45
 
46
 
47
  resnest_weights = {
48
+ "timm-resnest14d": {
49
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth"
50
+ },
51
+ "timm-resnest26d": {
52
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth"
53
  },
54
+ "timm-resnest50d": {
55
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth",
56
  },
57
+ "timm-resnest101e": {
58
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth",
59
  },
60
+ "timm-resnest200e": {
61
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth",
62
  },
63
+ "timm-resnest269e": {
64
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth",
65
  },
66
+ "timm-resnest50d_4s2x40d": {
67
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth",
68
  },
69
+ "timm-resnest50d_1s4x24d": {
70
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth",
71
  },
 
 
 
72
  }
73
 
74
  pretrained_settings = {}
 
77
  for source_name, source_url in sources.items():
78
  pretrained_settings[model_name][source_name] = {
79
  "url": source_url,
80
+ "input_size": [3, 224, 224],
81
+ "input_range": [0, 1],
82
+ "mean": [0.485, 0.456, 0.406],
83
+ "std": [0.229, 0.224, 0.225],
84
+ "num_classes": 1000,
85
  }
86
 
87
 
88
  timm_resnest_encoders = {
89
+ "timm-resnest14d": {
90
+ "encoder": ResNestEncoder,
91
  "pretrained_settings": pretrained_settings["timm-resnest14d"],
92
+ "params": {
93
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
94
+ "block": ResNestBottleneck,
95
+ "layers": [1, 1, 1, 1],
96
+ "stem_type": "deep",
97
+ "stem_width": 32,
98
+ "avg_down": True,
99
+ "base_width": 64,
100
+ "cardinality": 1,
101
+ "block_args": {"radix": 2, "avd": True, "avd_first": False},
102
+ },
103
  },
104
+ "timm-resnest26d": {
105
+ "encoder": ResNestEncoder,
106
  "pretrained_settings": pretrained_settings["timm-resnest26d"],
107
+ "params": {
108
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
109
+ "block": ResNestBottleneck,
110
+ "layers": [2, 2, 2, 2],
111
+ "stem_type": "deep",
112
+ "stem_width": 32,
113
+ "avg_down": True,
114
+ "base_width": 64,
115
+ "cardinality": 1,
116
+ "block_args": {"radix": 2, "avd": True, "avd_first": False},
117
+ },
118
  },
119
+ "timm-resnest50d": {
120
+ "encoder": ResNestEncoder,
121
  "pretrained_settings": pretrained_settings["timm-resnest50d"],
122
+ "params": {
123
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
124
+ "block": ResNestBottleneck,
125
+ "layers": [3, 4, 6, 3],
126
+ "stem_type": "deep",
127
+ "stem_width": 32,
128
+ "avg_down": True,
129
+ "base_width": 64,
130
+ "cardinality": 1,
131
+ "block_args": {"radix": 2, "avd": True, "avd_first": False},
132
+ },
133
  },
134
+ "timm-resnest101e": {
135
+ "encoder": ResNestEncoder,
136
  "pretrained_settings": pretrained_settings["timm-resnest101e"],
137
+ "params": {
138
+ "out_channels": (3, 128, 256, 512, 1024, 2048),
139
+ "block": ResNestBottleneck,
140
+ "layers": [3, 4, 23, 3],
141
+ "stem_type": "deep",
142
+ "stem_width": 64,
143
+ "avg_down": True,
144
+ "base_width": 64,
145
+ "cardinality": 1,
146
+ "block_args": {"radix": 2, "avd": True, "avd_first": False},
147
+ },
148
  },
149
+ "timm-resnest200e": {
150
+ "encoder": ResNestEncoder,
151
  "pretrained_settings": pretrained_settings["timm-resnest200e"],
152
+ "params": {
153
+ "out_channels": (3, 128, 256, 512, 1024, 2048),
154
+ "block": ResNestBottleneck,
155
+ "layers": [3, 24, 36, 3],
156
+ "stem_type": "deep",
157
+ "stem_width": 64,
158
+ "avg_down": True,
159
+ "base_width": 64,
160
+ "cardinality": 1,
161
+ "block_args": {"radix": 2, "avd": True, "avd_first": False},
162
+ },
163
  },
164
+ "timm-resnest269e": {
165
+ "encoder": ResNestEncoder,
166
  "pretrained_settings": pretrained_settings["timm-resnest269e"],
167
+ "params": {
168
+ "out_channels": (3, 128, 256, 512, 1024, 2048),
169
+ "block": ResNestBottleneck,
170
+ "layers": [3, 30, 48, 8],
171
+ "stem_type": "deep",
172
+ "stem_width": 64,
173
+ "avg_down": True,
174
+ "base_width": 64,
175
+ "cardinality": 1,
176
+ "block_args": {"radix": 2, "avd": True, "avd_first": False},
177
  },
178
  },
179
+ "timm-resnest50d_4s2x40d": {
180
+ "encoder": ResNestEncoder,
181
  "pretrained_settings": pretrained_settings["timm-resnest50d_4s2x40d"],
182
+ "params": {
183
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
184
+ "block": ResNestBottleneck,
185
+ "layers": [3, 4, 6, 3],
186
+ "stem_type": "deep",
187
+ "stem_width": 32,
188
+ "avg_down": True,
189
+ "base_width": 40,
190
+ "cardinality": 2,
191
+ "block_args": {"radix": 4, "avd": True, "avd_first": True},
192
+ },
193
  },
194
+ "timm-resnest50d_1s4x24d": {
195
+ "encoder": ResNestEncoder,
196
  "pretrained_settings": pretrained_settings["timm-resnest50d_1s4x24d"],
197
+ "params": {
198
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
199
+ "block": ResNestBottleneck,
200
+ "layers": [3, 4, 6, 3],
201
+ "stem_type": "deep",
202
+ "stem_width": 32,
203
+ "avg_down": True,
204
+ "base_width": 24,
205
+ "cardinality": 4,
206
+ "block_args": {"radix": 1, "avd": True, "avd_first": True},
207
+ },
208
+ },
209
  }
segmentation_models_pytorch/segmentation_models_pytorch/encoders/timm_sknet.py CHANGED
@@ -1,7 +1,8 @@
1
- from ._base import EncoderMixin
2
- from timm.models.resnet import ResNet
3
- from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic
4
  import torch.nn as nn
 
 
 
 
5
 
6
 
7
  class SkNetEncoder(ResNet, EncoderMixin):
@@ -41,15 +42,15 @@ class SkNetEncoder(ResNet, EncoderMixin):
41
 
42
 
43
  sknet_weights = {
44
- 'timm-skresnet18': {
45
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'
46
  },
47
- 'timm-skresnet34': {
48
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'
 
 
 
49
  },
50
- 'timm-skresnext50_32x4d': {
51
- 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth',
52
- }
53
  }
54
 
55
  pretrained_settings = {}
@@ -58,46 +59,58 @@ for model_name, sources in sknet_weights.items():
58
  for source_name, source_url in sources.items():
59
  pretrained_settings[model_name][source_name] = {
60
  "url": source_url,
61
- 'input_size': [3, 224, 224],
62
- 'input_range': [0, 1],
63
- 'mean': [0.485, 0.456, 0.406],
64
- 'std': [0.229, 0.224, 0.225],
65
- 'num_classes': 1000
66
  }
67
 
68
  timm_sknet_encoders = {
69
- 'timm-skresnet18': {
70
- 'encoder': SkNetEncoder,
71
  "pretrained_settings": pretrained_settings["timm-skresnet18"],
72
- 'params': {
73
- 'out_channels': (3, 64, 64, 128, 256, 512),
74
- 'block': SelectiveKernelBasic,
75
- 'layers': [2, 2, 2, 2],
76
- 'zero_init_last_bn': False,
77
- 'block_args': {'sk_kwargs': {'min_attn_channels': 16, 'attn_reduction': 8, 'split_input': True}}
78
- }
 
 
 
 
 
 
79
  },
80
- 'timm-skresnet34': {
81
- 'encoder': SkNetEncoder,
82
  "pretrained_settings": pretrained_settings["timm-skresnet34"],
83
- 'params': {
84
- 'out_channels': (3, 64, 64, 128, 256, 512),
85
- 'block': SelectiveKernelBasic,
86
- 'layers': [3, 4, 6, 3],
87
- 'zero_init_last_bn': False,
88
- 'block_args': {'sk_kwargs': {'min_attn_channels': 16, 'attn_reduction': 8, 'split_input': True}}
89
- }
 
 
 
 
 
 
90
  },
91
- 'timm-skresnext50_32x4d': {
92
- 'encoder': SkNetEncoder,
93
  "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"],
94
- 'params': {
95
- 'out_channels': (3, 64, 256, 512, 1024, 2048),
96
- 'block': SelectiveKernelBottleneck,
97
- 'layers': [3, 4, 6, 3],
98
- 'zero_init_last_bn': False,
99
- 'cardinality': 32,
100
- 'base_width': 4
101
- }
102
- }
103
  }
 
 
 
 
1
  import torch.nn as nn
2
+ from timm.models.resnet import ResNet
3
+ from timm.models.sknet import SelectiveKernelBasic, SelectiveKernelBottleneck
4
+
5
+ from ._base import EncoderMixin
6
 
7
 
8
  class SkNetEncoder(ResNet, EncoderMixin):
 
42
 
43
 
44
  sknet_weights = {
45
+ "timm-skresnet18": {
46
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth"
47
  },
48
+ "timm-skresnet34": {
49
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth"
50
+ },
51
+ "timm-skresnext50_32x4d": {
52
+ "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth",
53
  },
 
 
 
54
  }
55
 
56
  pretrained_settings = {}
 
59
  for source_name, source_url in sources.items():
60
  pretrained_settings[model_name][source_name] = {
61
  "url": source_url,
62
+ "input_size": [3, 224, 224],
63
+ "input_range": [0, 1],
64
+ "mean": [0.485, 0.456, 0.406],
65
+ "std": [0.229, 0.224, 0.225],
66
+ "num_classes": 1000,
67
  }
68
 
69
  timm_sknet_encoders = {
70
+ "timm-skresnet18": {
71
+ "encoder": SkNetEncoder,
72
  "pretrained_settings": pretrained_settings["timm-skresnet18"],
73
+ "params": {
74
+ "out_channels": (3, 64, 64, 128, 256, 512),
75
+ "block": SelectiveKernelBasic,
76
+ "layers": [2, 2, 2, 2],
77
+ "zero_init_last_bn": False,
78
+ "block_args": {
79
+ "sk_kwargs": {
80
+ "min_attn_channels": 16,
81
+ "attn_reduction": 8,
82
+ "split_input": True,
83
+ }
84
+ },
85
+ },
86
  },
87
+ "timm-skresnet34": {
88
+ "encoder": SkNetEncoder,
89
  "pretrained_settings": pretrained_settings["timm-skresnet34"],
90
+ "params": {
91
+ "out_channels": (3, 64, 64, 128, 256, 512),
92
+ "block": SelectiveKernelBasic,
93
+ "layers": [3, 4, 6, 3],
94
+ "zero_init_last_bn": False,
95
+ "block_args": {
96
+ "sk_kwargs": {
97
+ "min_attn_channels": 16,
98
+ "attn_reduction": 8,
99
+ "split_input": True,
100
+ }
101
+ },
102
+ },
103
  },
104
+ "timm-skresnext50_32x4d": {
105
+ "encoder": SkNetEncoder,
106
  "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"],
107
+ "params": {
108
+ "out_channels": (3, 64, 256, 512, 1024, 2048),
109
+ "block": SelectiveKernelBottleneck,
110
+ "layers": [3, 4, 6, 3],
111
+ "zero_init_last_bn": False,
112
+ "cardinality": 32,
113
+ "base_width": 4,
114
+ },
115
+ },
116
  }
segmentation_models_pytorch/segmentation_models_pytorch/encoders/vgg.py CHANGED
@@ -24,9 +24,8 @@ Methods:
  """

  import torch.nn as nn
- from torchvision.models.vgg import VGG
- from torchvision.models.vgg import make_layers
  from pretrainedmodels.models.torchvision_models import pretrained_settings
+ from torchvision.models.vgg import VGG, make_layers

  from ._base import EncoderMixin

@@ -49,8 +48,10 @@ class VGGEncoder(VGG, EncoderMixin):
  del self.classifier

  def make_dilated(self, stage_list, dilation_list):
- raise ValueError("'VGG' models do not support dilated mode due to Max Pooling"
- " operations for downsampling!")
+ raise ValueError(
+ "'VGG' models do not support dilated mode due to Max Pooling"
+ " operations for downsampling!"
+ )

  def get_stages(self):
  stages = []
segmentation_models_pytorch/segmentation_models_pytorch/encoders/xception.py CHANGED
@@ -1,14 +1,12 @@
  import re
- import torch.nn as nn

- from pretrainedmodels.models.xception import pretrained_settings
- from pretrainedmodels.models.xception import Xception
+ import torch.nn as nn
+ from pretrainedmodels.models.xception import Xception, pretrained_settings

  from ._base import EncoderMixin


  class XceptionEncoder(Xception, EncoderMixin):
-
      def __init__(self, out_channels, *args, depth=5, **kwargs):
          super().__init__(*args, **kwargs)

@@ -23,18 +21,33 @@ class XceptionEncoder(Xception, EncoderMixin):
          del self.fc

      def make_dilated(self, stage_list, dilation_list):
-         raise ValueError("Xception encoder does not support dilated mode "
-                          "due to pooling operation for downsampling!")
+         raise ValueError(
+             "Xception encoder does not support dilated mode "
+             "due to pooling operation for downsampling!"
+         )

      def get_stages(self):
          return [
              nn.Identity(),
-             nn.Sequential(self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu),
+             nn.Sequential(
+                 self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu
+             ),
              self.block1,
              self.block2,
-             nn.Sequential(self.block3, self.block4, self.block5, self.block6, self.block7,
-                           self.block8, self.block9, self.block10, self.block11),
-             nn.Sequential(self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4),
+             nn.Sequential(
+                 self.block3,
+                 self.block4,
+                 self.block5,
+                 self.block6,
+                 self.block7,
+                 self.block8,
+                 self.block9,
+                 self.block10,
+                 self.block11,
+             ),
+             nn.Sequential(
+                 self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4
+             ),
          ]

      def forward(self, x):

@@ -49,18 +62,18 @@ class XceptionEncoder(Xception, EncoderMixin):

      def load_state_dict(self, state_dict):
          # remove linear
-         state_dict.pop('fc.bias')
-         state_dict.pop('fc.weight')
+         state_dict.pop("fc.bias")
+         state_dict.pop("fc.weight")

          super().load_state_dict(state_dict)


  xception_encoders = {
-     'xception': {
-         'encoder': XceptionEncoder,
-         'pretrained_settings': pretrained_settings['xception'],
-         'params': {
-             'out_channels': (3, 64, 128, 256, 728, 2048),
-         }
+     "xception": {
+         "encoder": XceptionEncoder,
+         "pretrained_settings": pretrained_settings["xception"],
+         "params": {
+             "out_channels": (3, 64, 128, 256, 728, 2048),
+         },
      },
  }
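
The `xception_encoders` registry maps an encoder name to its class, its pretrained weights, and the per-stage `out_channels`, while `get_stages` splits the backbone into one module per resolution level. A hedged sketch of how such stage outputs are typically collected into multi-scale features (the repository's actual base-encoder `forward` may differ in detail):

# Hedged sketch: iterate over get_stages() and keep one feature map per stage.
import torch


def extract_features(encoder, x: torch.Tensor, depth: int = 5):
    features = []
    for stage in encoder.get_stages()[: depth + 1]:
        x = stage(x)        # each stage after the first roughly halves the resolution
        features.append(x)  # channel counts follow out_channels, e.g. (3, 64, 128, 256, 728, 2048)
    return features

The decoder then consumes this list, usually deepest feature first.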
segmentation_models_pytorch/segmentation_models_pytorch/fpn/__init__.py CHANGED
@@ -1 +1 @@
- from .model import FPN
+ from .model import FPN
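
`FPN` re-exported here is the public entry point for this architecture. Based on the constructor arguments documented in `fpn/model.py` below, a typical usage looks roughly like the following hedged sketch (argument defaults may differ between library versions):

# Hedged usage sketch; argument names follow the FPN docstring in fpn/model.py below.
import torch
import segmentation_models_pytorch as smp

model = smp.FPN(
    encoder_name="resnet34",      # backbone used to extract multi-scale features
    encoder_weights="imagenet",   # or None for random initialization
    classes=1,                    # number of output mask channels
    activation=None,              # return raw logits
    aux_params={"classes": 1, "pooling": "avg"},  # optional classification head
)

# With aux_params set, the model returns a (mask, label) pair.
mask, label = model(torch.randn(1, 3, 256, 256))  # mask: (1, 1, 256, 256)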
segmentation_models_pytorch/segmentation_models_pytorch/fpn/decoder.py CHANGED
@@ -55,51 +55,63 @@ class MergeBlock(nn.Module):
          super().__init__()
          if policy not in ["add", "cat"]:
              raise ValueError(
-                 "`merge_policy` must be one of: ['add', 'cat'], got {}".format(
-                     policy
-                 )
+                 "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
              )
          self.policy = policy

      def forward(self, x):
-         if self.policy == 'add':
+         if self.policy == "add":
              return sum(x)
-         elif self.policy == 'cat':
+         elif self.policy == "cat":
              return torch.cat(x, dim=1)
          else:
              raise ValueError(
-                 "`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy)
+                 "`merge_policy` must be one of: ['add', 'cat'], got {}".format(
+                     self.policy
+                 )
              )


  class FPNDecoder(nn.Module):
      def __init__(
-             self,
-             encoder_channels,
-             encoder_depth=5,
-             pyramid_channels=256,
-             segmentation_channels=128,
-             dropout=0.2,
-             merge_policy="add",
+         self,
+         encoder_channels,
+         encoder_depth=5,
+         pyramid_channels=256,
+         segmentation_channels=128,
+         dropout=0.2,
+         merge_policy="add",
      ):
          super().__init__()

-         self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4
+         self.out_channels = (
+             segmentation_channels
+             if merge_policy == "add"
+             else segmentation_channels * 4
+         )
          if encoder_depth < 3:
-             raise ValueError("Encoder depth for FPN decoder cannot be less than 3, got {}.".format(encoder_depth))
+             raise ValueError(
+                 "Encoder depth for FPN decoder cannot be less than 3, got {}.".format(
+                     encoder_depth
+                 )
+             )

          encoder_channels = encoder_channels[::-1]
-         encoder_channels = encoder_channels[:encoder_depth + 1]
+         encoder_channels = encoder_channels[: encoder_depth + 1]

          self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
          self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
          self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
          self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])

-         self.seg_blocks = nn.ModuleList([
-             SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples)
-             for n_upsamples in [3, 2, 1, 0]
-         ])
+         self.seg_blocks = nn.ModuleList(
+             [
+                 SegmentationBlock(
+                     pyramid_channels, segmentation_channels, n_upsamples=n_upsamples
+                 )
+                 for n_upsamples in [3, 2, 1, 0]
+             ]
+         )

          self.merge = MergeBlock(merge_policy)
          self.dropout = nn.Dropout2d(p=dropout, inplace=True)

@@ -112,7 +124,9 @@ class FPNDecoder(nn.Module):
          p3 = self.p3(p4, c3)
          p2 = self.p2(p3, c2)

-         feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])]
+         feature_pyramid = [
+             seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])
+         ]
          x = self.merge(feature_pyramid)
          x = self.dropout(x)
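
The decoder's `out_channels` above follows directly from the merge policy: "add" sums the four segmentation-block outputs, so the channel count stays at `segmentation_channels`, while "cat" concatenates them along the channel dimension and quadruples it. A small illustration of the two `MergeBlock` policies (shapes chosen only for the demo):

# Demo of the two merge policies; tensor sizes are arbitrary.
import torch

pyramid = [torch.randn(2, 128, 64, 64) for _ in range(4)]  # four pyramid levels, 128 channels each

merged_add = sum(pyramid)               # "add" -> shape (2, 128, 64, 64)
merged_cat = torch.cat(pyramid, dim=1)  # "cat" -> shape (2, 512, 64, 64), i.e. 4 * 128 channels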
 
segmentation_models_pytorch/segmentation_models_pytorch/fpn/model.py CHANGED
@@ -1,7 +1,8 @@
  from typing import Optional, Union
- from .decoder import FPNDecoder
- from ..base import SegmentationModel, SegmentationHead, ClassificationHead
+
+ from ..base import ClassificationHead, SegmentationHead, SegmentationModel
  from ..encoders import get_encoder
+ from .decoder import FPNDecoder


  class FPN(SegmentationModel):

@@ -10,11 +11,11 @@ class FPN(SegmentationModel):
      Args:
          encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
              to extract features of different spatial resolution
-         encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
+         encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
              two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
              with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
              Default is 5
-         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
              other pretrained weights (see table with available weights for each encoder_name)
          decoder_pyramid_channels: A number of convolution filters in Feature Pyramid of FPN_
          decoder_segmentation_channels: A number of convolution filters in segmentation blocks of FPN_

@@ -26,7 +27,7 @@ class FPN(SegmentationModel):
              Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
              Default is **None**
          upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
-         aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
+         aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
              on top of encoder if **aux_params** is not **None** (default). Supported params:
              - classes (int): A number of classes
              - pooling (str): One of "max", "avg". Default is "avg"