import sys

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .ProbUNet_model import (InjectionConvEncoder2D, InjectionConvEncoder3D,
                             InjectionUNet2D, InjectionUNet3D,
                             ProbabilisticSegmentationNet)
from .PULASkiConfigs import ProbUNetConfig


class ProbUNet(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the probabilistic segmentation U-Net."""

    config_class = ProbUNetConfig

    def __init__(self, config):
        super().__init__(config)
        # Choose 2D or 3D building blocks for the task (segmentation) network and
        # the prior/posterior encoders, depending on the configured dimensionality.
        if config.dim == 2:
            task_op = InjectionUNet2D
            prior_op = InjectionConvEncoder2D
            posterior_op = InjectionConvEncoder2D
        elif config.dim == 3:
            task_op = InjectionUNet3D
            prior_op = InjectionConvEncoder3D
            posterior_op = InjectionConvEncoder3D
        else:
            sys.exit("Invalid dim! Only configured for dim 2 and 3.")
        # Only a Gaussian (Normal) latent distribution is currently implemented.
        if config.latent_distribution == "normal":
            latent_distribution = torch.distributions.Normal
        else:
            sys.exit("Invalid latent_distribution. Only normal has been implemented.")
        self.model = ProbabilisticSegmentationNet(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            num_feature_maps=config.num_feature_maps,
            latent_size=config.latent_size,
            depth=config.depth,
            latent_distribution=latent_distribution,
            task_op=task_op,
            task_kwargs={
                "output_activation_op": nn.Identity if config.no_outact_op else nn.Sigmoid,
                "activation_kwargs": {"inplace": True},
                "injection_at": config.prob_injection_at,
            },
            prior_op=prior_op,
            prior_kwargs={"activation_kwargs": {"inplace": True}, "norm_depth": 2},
            posterior_op=posterior_op,
            posterior_kwargs={"activation_kwargs": {"inplace": True}, "norm_depth": 2},
        )

    def forward(self, x):
        """Run the wrapped probabilistic segmentation network on ``x`` and return its output."""
        return self.model(x)
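
# Minimal usage sketch (illustrative, not part of the module's API). It assumes that
# ProbUNetConfig accepts the fields read above (dim, in_channels, out_channels,
# num_feature_maps, latent_size, depth, latent_distribution, no_outact_op,
# prob_injection_at) as keyword arguments; check PULASkiConfigs.py for the actual
# signature, defaults, and valid values.
#
#     config = ProbUNetConfig(
#         dim=2,
#         in_channels=1,
#         out_channels=1,
#         num_feature_maps=24,          # assumed example value
#         latent_size=3,                # assumed example value
#         depth=5,                      # assumed example value
#         latent_distribution="normal",
#         no_outact_op=False,           # False -> Sigmoid output activation
#         prob_injection_at="end",      # assumed; valid values depend on InjectionUNet2D/3D
#     )
#     model = ProbUNet(config)
#     out = model(torch.randn(1, config.in_channels, 256, 256))  # (N, C, H, W) for dim=2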