FatemehT's picture
style: run pre-commit
8e6512c
raw
history blame
4.79 kB
import os
import sys
import mock
import pytest
import torch
# Register a mock under "torchvision._C" in sys.modules before
# segmentation_models_pytorch is imported below.
sys.modules["torchvision._C"] = mock.Mock()
import segmentation_models_pytorch as smp
# Value of the TRAVIS environment variable, or False when it is not set.
# Only its truthiness is used (see get_encoders).
IS_TRAVIS = os.getenv("TRAVIS", False)
def get_encoders():
    """Return the encoder names that the parametrized tests iterate over.

    The names come from ``smp.encoders.get_encoder_names()``. When
    ``IS_TRAVIS`` is truthy, the encoders listed in ``skipped_on_travis``
    are filtered out of the result.
    """
    names = smp.encoders.get_encoder_names()
    if not IS_TRAVIS:
        return names

    skipped_on_travis = (
        "senet154",
        "resnext101_32x16d",
        "resnext101_32x32d",
        "resnext101_32x48d",
    )
    return [name for name in names if name not in skipped_on_travis]
# Encoder used by tests that are not parametrized over every encoder.
DEFAULT_ENCODER = "resnet18"
# Encoder names used for the per-encoder parametrized tests.
ENCODERS = get_encoders()
def get_sample(model_class):
    """Build the all-ones input tensor used to exercise ``model_class``.

    Returns:
        A ``[1, 3, 64, 64]`` tensor for Unet, Linknet, FPN, PSPNet,
        UnetPlusPlus and MAnet; ``[2, 3, 256, 256]`` for PAN;
        ``[2, 3, 128, 128]`` for DeepLabV3.

    Raises:
        ValueError: if ``model_class`` is none of the classes above.
    """
    small_input_models = (
        smp.Unet,
        smp.Linknet,
        smp.FPN,
        smp.PSPNet,
        smp.UnetPlusPlus,
        smp.MAnet,
    )
    if model_class in small_input_models:
        return torch.ones([1, 3, 64, 64])
    if model_class == smp.PAN:
        return torch.ones([2, 3, 256, 256])
    if model_class == smp.DeepLabV3:
        return torch.ones([2, 3, 128, 128])
    raise ValueError("Not supported model class {}".format(model_class))
def _test_forward(model, sample, test_shape=False):
    """Run ``model`` on ``sample`` with gradient tracking disabled.

    When ``test_shape`` is truthy, additionally assert that the output and
    the input agree on every dimension from index 2 onward.
    """
    with torch.no_grad():
        prediction = model(sample)

    if not test_shape:
        return
    assert prediction.shape[2:] == sample.shape[2:]
def _test_forward_backward(model, sample, test_shape=False):
    """Run ``model`` on ``sample`` and backpropagate the mean of the output.

    The backward pass runs before the optional shape check. When
    ``test_shape`` is truthy, assert that the output and the input agree on
    every dimension from index 2 onward.
    """
    prediction = model(sample)
    loss = prediction.mean()
    loss.backward()

    if not test_shape:
        return
    assert prediction.shape[2:] == sample.shape[2:]
@pytest.mark.parametrize("encoder_name", ENCODERS)
@pytest.mark.parametrize("encoder_depth", [3, 5])
@pytest.mark.parametrize(
    "model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus]
)
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
    """Forward pass in eval mode for each model / encoder / depth combination.

    The spatial-shape check is enabled only for ``encoder_depth == 5`` and
    never for PSPNet.
    """
    unet_like = (smp.Unet, smp.UnetPlusPlus, smp.MAnet)
    if any(model_class is candidate for candidate in unet_like):
        # Keep as many decoder channel entries as there are encoder stages.
        kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]

    model = model_class(
        encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
    )
    sample = get_sample(model_class)
    model.eval()

    check_shape = encoder_depth == 5 and model_class != smp.PSPNet
    _test_forward(model, sample, check_shape)
@pytest.mark.parametrize(
    "model_class",
    [
        smp.PAN,
        smp.FPN,
        smp.PSPNet,
        smp.Linknet,
        smp.Unet,
        smp.UnetPlusPlus,
        smp.MAnet,
        smp.DeepLabV3,
    ],
)
def test_forward_backward(model_class):
    """Forward and backward pass for each model built on ``DEFAULT_ENCODER``.

    The model is left in its freshly constructed mode (``eval()`` is not
    called) and no shape check is requested.
    """
    model = model_class(DEFAULT_ENCODER, encoder_weights=None)
    batch = get_sample(model_class)
    _test_forward_backward(model, batch)
@pytest.mark.parametrize(
    "model_class",
    [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet],
)
def test_aux_output(model_class):
    """Check the size of the second output when ``aux_params`` is supplied.

    The model is built with ``aux_params`` requesting 2 classes; calling it
    is expected to yield a pair whose second element has size
    ``(batch_size, 2)``.
    """
    model = model_class(
        DEFAULT_ENCODER, encoder_weights=None, aux_params={"classes": 2}
    )
    sample = get_sample(model_class)

    # Only the second output is checked; the first is discarded.
    _, label = model(sample)
    assert label.size() == (sample.shape[0], 2)
@pytest.mark.parametrize("upsampling", [2, 4, 8])
@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet])
def test_upsample(model_class, upsampling):
    """Check that the ``upsampling`` argument scales the output width.

    The ratio of output width to input width must equal the ratio of the
    requested ``upsampling`` to the reference factor used here (4 for FPN,
    8 otherwise).
    """
    default_upsampling = 4 if model_class is smp.FPN else 8
    model = model_class(DEFAULT_ENCODER, encoder_weights=None, upsampling=upsampling)
    sample = get_sample(model_class)
    mask = model(sample)

    # Use the actual input width rather than a hard-coded 64 so the check
    # stays correct if get_sample() ever returns a different size.
    input_width = sample.shape[-1]
    assert mask.size()[-1] / input_width == upsampling / default_upsampling
@pytest.mark.parametrize("model_class", [smp.FPN])
@pytest.mark.parametrize("encoder_name", ENCODERS)
@pytest.mark.parametrize("in_channels", [1, 2, 4])
def test_in_channels(model_class, encoder_name, in_channels):
    """Check that a model accepts inputs with a non-default channel count.

    For every encoder and each ``in_channels`` value, build the model, run a
    forward pass in eval mode on a matching all-ones input, and assert that
    the encoder recorded the requested channel count.
    """
    sample = torch.ones([1, in_channels, 64, 64])
    # Build with the parametrized encoder. Previously DEFAULT_ENCODER was
    # passed here, so the encoder_name parameter was ignored and every case
    # exercised the same encoder.
    model = model_class(encoder_name, encoder_weights=None, in_channels=in_channels)
    model.eval()
    with torch.no_grad():
        model(sample)

    assert model.encoder._in_channels == in_channels
@pytest.mark.parametrize("encoder_name", ENCODERS)
def test_dilation(encoder_name):
    """Check encoder output widths after dilating stage 5.

    Encoders in the exclusion lists below are skipped. For every other
    encoder, stage 5 is made dilated with rate 2 and the last-dimension
    sizes of the outputs for a 64x64 input are compared to the expected list.
    """
    excluded_names = ("inceptionresnetv2", "xception", "inceptionv4")
    excluded_prefixes = ("vgg", "densenet", "timm-res")
    if encoder_name in excluded_names or encoder_name.startswith(excluded_prefixes):
        # Report these as skipped instead of returning early, which would
        # mark them as passed although nothing was checked.
        pytest.skip("dilation is not tested for encoder {}".format(encoder_name))

    encoder = smp.encoders.get_encoder(encoder_name)
    encoder.make_dilated(
        stage_list=[5],
        dilation_list=[2],
    )
    encoder.eval()

    with torch.no_grad():
        sample = torch.ones([1, 3, 64, 64])
        output = encoder(sample)

    shapes = [out.shape[-1] for out in output]
    assert shapes == [64, 32, 16, 8, 4, 4]  # last downsampling replaced with dilation
# Allow running this file directly: hand it to pytest as the only argument.
if __name__ == "__main__":
    pytest.main(args=[__file__])