Image Segmentation
lzeeorno666's picture
Upload 24 files
587425f verified
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import autocast as autocast
import torch
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from scipy.ndimage.morphology import binary_fill_holes, binary_opening
from utils import test_single_volume
import time
def calculate_dice_per_class(pred, target, num_classes):
    """Compute the Dice coefficient of every foreground class.

    ``pred`` is reduced to hard labels with an argmax over dim 1, then each
    class index in ``1 .. num_classes - 1`` is compared with ``target``
    (background, index 0, is skipped). The Dice is pooled over all elements
    of the batch.

    Args:
        pred: Score/logit tensor whose class axis is dim 1.
        target: Integer label tensor comparable element-wise with the argmax
            of ``pred``.
        num_classes: Number of classes including background.

    Returns:
        list[float]: One Dice value per foreground class, in class-index order.
        A class absent from both prediction and target scores 1.0.
    """
    labels = pred.argmax(dim=1)
    scores = []
    for class_id in range(1, num_classes):
        p = labels.eq(class_id).float()
        t = target.eq(class_id).float()
        overlap = (p * t).sum()
        denom = p.sum() + t.sum()
        if not (denom > 0):
            # Neither map contains this class: count it as perfect agreement.
            scores.append(1.0 if overlap == 0 else 0.0)
            continue
        # .item() turns the 0-d tensor into a plain Python float.
        scores.append((2.0 * overlap / denom).item())
    return scores
def calculate_miou(pred, target, num_classes):
    """Return the mean IoU over the foreground classes.

    ``pred`` is converted to hard labels by an argmax over dim 1. The IoU is
    computed for every class index ``1 .. num_classes - 1`` (background is
    skipped) and the values are averaged with ``np.mean``.

    Args:
        pred: Score/logit tensor whose class axis is dim 1.
        target: Integer label tensor comparable with the argmax of ``pred``.
        num_classes: Number of classes including background.

    Returns:
        numpy float: mean IoU. A class absent from both maps counts as 1.0.
        With no foreground classes the mean of an empty list (NaN) is returned.
    """
    labels = torch.argmax(pred, dim=1)

    def _class_iou(class_id):
        p = (labels == class_id).float()
        t = (target == class_id).float()
        inter = (p * t).sum()
        union = p.sum() + t.sum() - inter
        if union > 0:
            return (inter / union).item()
        # Class absent from both prediction and target: perfect match.
        return 1.0 if inter == 0 else 0.0

    return np.mean([_class_iou(c) for c in range(1, num_classes)])
def _criterion_accepts_intermediate_preds(criterion):
    """Return True if ``criterion.forward`` declares an ``intermediate_preds`` parameter.

    Inspecting ``forward.__code__.co_varnames`` is unreliable: it also lists
    local variables (a forward with a local named ``intermediate_preds`` would
    be called with an unexpected keyword) and it fails for callables without
    ``__code__``. ``inspect.signature`` reports only real parameters.
    """
    import inspect

    forward = getattr(criterion, 'forward', None)
    if forward is None:
        return False
    try:
        params = inspect.signature(forward).parameters
    except (TypeError, ValueError):
        # Signature cannot be introspected: fall back to plain criterion(out, targets).
        return False
    return 'intermediate_preds' in params


def _forward_train_loss(model_output, criterion, targets, use_intermediate):
    """Extract the main segmentation output from ``model_output`` and compute the loss.

    Handled output formats:
      * tensor                      -> loss on the tensor
      * 2-tuple (out, second)       -> ``second`` is passed as
        ``intermediate_preds`` when ``use_intermediate`` is True, otherwise
        only ``out`` is supervised
      * 3-tuple / other tuples      -> only element 0 is supervised

    Returns:
        tuple: (out, loss)
    """
    if not isinstance(model_output, tuple):
        return model_output, criterion(model_output, targets)
    out = model_output[0]
    if len(model_output) == 2 and use_intermediate:
        return out, criterion(out, targets, intermediate_preds=model_output[1])
    # TODO: 3-tuple outputs (seg_out, boundary_out, ds_outs) could add boundary
    # and deep-supervision losses; currently only seg_out is supervised.
    return out, criterion(out, targets)


def train_one_epoch(train_loader,
                    model,
                    criterion,
                    optimizer,
                    scheduler,
                    epoch,
                    logger,
                    config,
                    scaler=None):
    '''
    Train ``model`` for one epoch and return averaged training metrics.

    Each batch is a dict with 'image' and 'label'; images are moved to CUDA as
    float, labels as long. When ``config.amp`` is set, the forward pass and the
    loss run under ``autocast`` and ``scaler`` drives backward/step.
    ``scheduler.step()`` is called once at the end of the epoch.

    Args:
        train_loader: Iterable of batch dicts with 'image' and 'label'.
        model: Network; may return a tensor or a tuple whose first element
            is the segmentation output.
        criterion: Loss callable ``criterion(out, targets)``; if its ``forward``
            declares ``intermediate_preds``, the second element of a 2-tuple
            model output is passed there.
        optimizer, scheduler: Optimizer and per-epoch LR scheduler.
        epoch: Epoch index (used for logging only).
        logger: Logger receiving progress messages.
        config: Must provide ``amp``, ``num_classes`` and ``print_interval``.
        scaler: GradScaler; required when ``config.amp`` is True.

    Returns:
        dict with 'loss', 'avg_dice', 'dice_per_class' and 'miou'.

    Raises:
        ValueError: if ``config.amp`` is enabled but ``scaler`` is None.
    '''
    if config.amp and scaler is None:
        raise ValueError('config.amp is enabled but no GradScaler was passed as `scaler`')
    stime = time.time()
    model.train()
    # The criterion does not change during the epoch, so introspect it once.
    use_intermediate = _criterion_accepts_intermediate_preds(criterion)
    loss_list = []
    dice_scores_all = []
    miou_scores_all = []
    for step, data in enumerate(train_loader):
        optimizer.zero_grad()
        images = data['image'].cuda(non_blocking=True).float()
        targets = data['label'].cuda(non_blocking=True).long()
        if config.amp:
            with autocast():
                out, loss = _forward_train_loss(model(images), criterion, targets, use_intermediate)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out, loss = _forward_train_loss(model(images), criterion, targets, use_intermediate)
            loss.backward()
            optimizer.step()
        loss_value = loss.item()
        loss_list.append(loss_value)
        # Training-time Dice and mIoU on the current batch (no autograd graph needed).
        with torch.no_grad():
            dice_scores_all.append(calculate_dice_per_class(out, targets, config.num_classes))
            miou_scores_all.append(calculate_miou(out, targets, config.num_classes))
        if step % config.print_interval == 0 and step != 0:
            # Read the LR directly instead of rebuilding optimizer.state_dict() every step.
            now_lr = optimizer.param_groups[0]['lr']
            log_info = f'train: epoch {epoch}, iter:{step}, loss: {loss_value:.4f}, lr: {now_lr}'
            print(log_info)
            logger.info(log_info)
    scheduler.step()
    # Epoch-level averages.
    mean_loss = np.mean(loss_list)
    mean_dice_per_class = np.mean(dice_scores_all, axis=0)
    mean_dice_avg = np.mean(mean_dice_per_class)
    mean_miou = np.mean(miou_scores_all)
    etime = time.time()
    log_info = f'Finish one epoch train: epoch {epoch}, loss: {mean_loss:.4f}, avg_dice: {mean_dice_avg:.4f}, mIoU: {mean_miou:.4f}, time(s): {etime-stime:.2f}'
    print(log_info)
    logger.info(log_info)
    return {
        'loss': mean_loss,
        'avg_dice': mean_dice_avg,
        'dice_per_class': mean_dice_per_class,
        'miou': mean_miou
    }
def val_one_epoch(test_datasets,
                  test_loader,
                  model,
                  epoch,
                  logger,
                  config,
                  test_save_path,
                  val_or_test=False):
    """Evaluate ``model`` case-by-case with ``test_single_volume``.

    For every case from ``test_loader``, the per-class (dice, hd95) pairs
    returned by ``test_single_volume`` are collected and logged. They are then
    averaged per class and over classes.

    Args:
        test_datasets: Not needed for the metrics; kept for call compatibility.
        test_loader: Iterable of dicts with 'image', 'label' and 'case_name'.
            Only ``case_name[0]`` is read, so a batch size of 1 is presumably
            expected (TODO confirm).
        model: Network passed through to ``test_single_volume``.
        epoch: Epoch index (logging only).
        logger: Logger for per-case and summary messages.
        config: Must provide ``num_classes``, ``input_size_h``,
            ``input_size_w`` and ``z_spacing``.
        test_save_path, val_or_test: Forwarded to ``test_single_volume``.

    Returns:
        dict with 'avg_dice', 'dice_per_class', 'avg_hd95', 'hd95_per_class'
        and 'miou'.
    """
    stime = time.time()
    model.eval()
    all_dice_scores = []
    all_hd95_scores = []
    with torch.no_grad():
        for i_batch, data in enumerate(tqdm(test_loader)):
            img, msk, case_name = data['image'], data['label'], data['case_name'][0]
            metric_i = test_single_volume(img, msk, model, classes=config.num_classes,
                                          patch_size=[config.input_size_h, config.input_size_w],
                                          test_save_path=test_save_path, case=case_name,
                                          z_spacing=config.z_spacing, val_or_test=val_or_test)
            # metric_i: one (dice, hd95) pair per foreground class for this case.
            all_dice_scores.append([metric[0] for metric in metric_i])
            all_hd95_scores.append([metric[1] for metric in metric_i])
            case_mean = np.mean(metric_i, axis=0)
            logger.info('idx %d case %s mean_dice %f mean_hd95 %f' % (i_batch, case_name,
                                                                      case_mean[0], case_mean[1]))
    avg_dice_per_class = np.mean(all_dice_scores, axis=0)  # per-class mean Dice
    avg_hd95_per_class = np.mean(all_hd95_scores, axis=0)  # per-class mean HD95
    performance = np.mean(avg_dice_per_class)  # mean Dice over classes
    mean_hd95 = np.mean(avg_hd95_per_class)  # mean HD95 over classes
    # Approximate mIoU from Dice via IoU = Dice / (2 - Dice). The identity is
    # exact for a single comparison but only approximate after averaging.
    # Applying it directly maps Dice 1 -> 1 and lets a NaN Dice propagate,
    # instead of being reported as a perfect 1.0.
    mean_miou = np.mean(avg_dice_per_class / (2.0 - avg_dice_per_class))
    for i, (class_dice, class_hd95) in enumerate(zip(avg_dice_per_class, avg_hd95_per_class)):
        logger.info('Mean class %d mean_dice %f mean_hd95 %f' % (i + 1, class_dice, class_hd95))
    etime = time.time()
    log_info = f'val epoch: {epoch}, mean_dice: {performance:.4f}, mean_hd95: {mean_hd95:.4f}, mIoU: {mean_miou:.4f}, time(s): {etime-stime:.2f}'
    print(log_info)
    logger.info(log_info)
    return {
        'avg_dice': performance,
        'dice_per_class': avg_dice_per_class,
        'avg_hd95': mean_hd95,
        'hd95_per_class': avg_hd95_per_class,
        'miou': mean_miou
    }
def val_one_epoch_slice(val_dataset,
                        val_loader,
                        model,
                        epoch,
                        logger,
                        config,
                        test_save_path,
                        val_or_test=False):
    """
    Validate slice-by-slice, matching the data handling used during training.

    Dice (per class) and mIoU are computed on every batch from ``val_loader``
    and averaged over batches. HD95 is not computed (reported as zeros),
    because a per-slice HD95 is not meaningful.

    Args:
        val_dataset, test_save_path, val_or_test: Unused; kept for interface
            compatibility with the other validation functions.
        val_loader: Iterable of dicts with 'image' and 'label'.
        model: Network; may return a tensor or a tuple whose first element is
            the segmentation output (same convention as ``train_one_epoch``).
        epoch: Epoch index (logging only).
        logger: Logger for per-class and summary messages.
        config: Must provide ``num_classes``.

    Returns:
        dict with 'avg_dice', 'dice_per_class', 'avg_hd95', 'hd95_per_class'
        and 'miou'.
    """
    stime = time.time()
    model.eval()
    all_dice_scores = []
    all_miou_scores = []
    num_slices = 0
    with torch.no_grad():
        for data in tqdm(val_loader, desc=f"Validation Epoch {epoch}"):
            images = data['image'].cuda(non_blocking=True).float()
            targets = data['label'].cuda(non_blocking=True).long()
            model_output = model(images)
            # Any tuple output (e.g. (out, intermediate) or (seg, boundary, ds))
            # carries the segmentation logits in position 0.
            out = model_output[0] if isinstance(model_output, tuple) else model_output
            all_dice_scores.append(calculate_dice_per_class(out, targets, config.num_classes))
            all_miou_scores.append(calculate_miou(out, targets, config.num_classes))
            # Count images, not batches (dim 0 is the batch axis, as the
            # metric helpers take the class argmax over dim 1).
            num_slices += images.shape[0]
    avg_dice_per_class = np.mean(all_dice_scores, axis=0)  # per-class mean Dice
    avg_dice = np.mean(avg_dice_per_class)  # mean Dice over classes
    avg_miou = np.mean(all_miou_scores)  # mean mIoU over batches
    # HD95 is not computed slice-by-slice; report zeros.
    avg_hd95_per_class = np.zeros_like(avg_dice_per_class)
    avg_hd95 = 0.0
    # Synapse organ names for classes 1..8; any further class gets a generic name
    # so that it is still logged instead of being silently skipped.
    organ_names = ['spleen', 'right_kidney', 'left_kidney', 'gallbladder',
                   'esophagus', 'liver', 'stomach', 'aorta']
    for i, class_dice in enumerate(avg_dice_per_class):
        organ = organ_names[i] if i < len(organ_names) else f'class_{i + 1}'
        logger.info(f'Validation class {i+1} ({organ}): dice={class_dice:.4f}')
    etime = time.time()
    log_info = f'val_slice epoch: {epoch}, mean_dice: {avg_dice:.4f}, mIoU: {avg_miou:.4f}, slices: {num_slices}, time(s): {etime-stime:.2f}'
    print(log_info)
    logger.info(log_info)
    return {
        'avg_dice': avg_dice,
        'dice_per_class': avg_dice_per_class,
        'avg_hd95': avg_hd95,
        'hd95_per_class': avg_hd95_per_class,
        'miou': avg_miou
    }
def val_one_epoch_with_visualization(test_datasets,
                                     test_loader,
                                     model,
                                     epoch,
                                     logger,
                                     config,
                                     test_save_path,
                                     prediction_vis_dir=None,
                                     attention_vis_dir=None,
                                     activation_vis_dir=None,
                                     val_or_test=False,
                                     save_vis_every_n=5,
                                     save_prediction_comparison_func=None,
                                     extract_attention_maps_func=None,
                                     save_attention_heatmaps_func=None,
                                     extract_activation_maps_func=None,
                                     save_activation_heatmaps_func=None):
    """
    Validate case-by-case and optionally save prediction comparisons,
    attention heatmaps and activation heatmaps.

    Metrics come from ``test_single_volume`` for every case. For every
    ``save_vis_every_n``-th case, visualizations are produced. Each kind is
    produced only when both its directory and its functions are given.
    Visualization failures are logged as warnings and never abort validation.

    Args:
        test_datasets: Its ``len()`` divides the accumulated per-case metrics.
        test_loader: Iterable of dicts with 'image', 'label' and 'case_name'
            (only ``case_name[0]`` is read).
        model: Network; may return a tensor or a tuple whose first element is
            the segmentation output.
        epoch: Epoch index (logging only).
        logger: Logger for per-case, per-class and summary messages.
        config: Must provide ``num_classes``, ``input_size_h``,
            ``input_size_w`` and ``z_spacing``. The expected input channel
            count comes from ``config.model_config['input_channels']``, then
            ``config.input_channels``, and defaults to 1.
        test_save_path, val_or_test: Forwarded to ``test_single_volume``.
        prediction_vis_dir: Directory for prediction comparison images.
        attention_vis_dir: Directory for attention heatmaps.
        activation_vis_dir: Directory for activation heatmaps.
        save_vis_every_n: Visualize every n-th case; 0/None disables
            visualization.
        save_prediction_comparison_func: ``f(img, gt, pred, case, dir)``.
        extract_attention_maps_func: ``f(model, input_tensor) -> maps``.
        save_attention_heatmaps_func: ``f(img, maps, case, dir)``.
        extract_activation_maps_func: ``f(model, input_tensor) -> maps``.
        save_activation_heatmaps_func: ``f(img, maps, case, dir)``.

    Returns:
        tuple: (mean dice, mean hd95, per-class dice list, per-class hd95 list).
    """
    from scipy.ndimage import zoom

    in_h, in_w = config.input_size_h, config.input_size_w

    # Infer the number of input channels the model expects from the config.
    desired_in_ch = getattr(config, 'input_channels', None)
    if hasattr(config, 'model_config') and isinstance(config.model_config, dict):
        desired_in_ch = config.model_config.get('input_channels', desired_in_ch)
    if desired_in_ch is None:
        desired_in_ch = 1

    def to_input_tensor(slice_img_np):
        # (H, W) numpy slice -> (1, C, H, W) CUDA tensor. The slice is
        # replicated to 3 channels when the model expects 3, otherwise C == 1.
        t = torch.from_numpy(slice_img_np).unsqueeze(0).float().cuda()  # (1, H, W)
        if desired_in_ch == 3:
            t = t.repeat(3, 1, 1)  # (3, H, W)
        return t.unsqueeze(0)

    def resize_to_input(slice_img_np):
        # Cubic-resample a 2D slice to the model input size when it differs.
        x, y = slice_img_np.shape[0], slice_img_np.shape[1]
        if x != in_h or y != in_w:
            return zoom(slice_img_np, (in_h / x, in_w / y), order=3)
        return slice_img_np

    def predict_labels(input_tensor):
        # Hard labels (H, W) for one input. Tuple outputs are unwrapped, as in
        # training; otherwise torch.softmax would fail on the tuple.
        outputs = model(input_tensor)
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        return torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0).cpu().detach().numpy()

    def save_heatmaps(extract_fn, save_fn, vis_dir, make_input, img_np, case_name, fail_msg):
        # Best-effort heatmap export. Input construction happens inside the
        # try so that a failure only skips this heatmap kind.
        if not (vis_dir and extract_fn and save_fn):
            return
        try:
            maps = extract_fn(model, make_input())
            if maps:
                save_fn(img_np, maps, case_name, vis_dir)
        except Exception as e:
            logger.warning(f"{fail_msg} {case_name}: {e}")

    def visualize_case(img, msk, case_name):
        img_np = img.squeeze(0).cpu().detach().numpy()
        msk_np = msk.squeeze(0).cpu().detach().numpy()
        if len(img_np.shape) == 3:  # volume: slices along axis 0
            def mid_slice_input():
                return to_input_tensor(resize_to_input(img_np[img_np.shape[0] // 2]))

            # Heatmaps are extracted from the middle slice but saved with the full volume.
            save_heatmaps(extract_attention_maps_func, save_attention_heatmaps_func,
                          attention_vis_dir, mid_slice_input, img_np, case_name,
                          "生成注意力热图失败")
            save_heatmaps(extract_activation_maps_func, save_activation_heatmaps_func,
                          activation_vis_dir, mid_slice_input, img_np, case_name,
                          "生成激活热图失败")
            prediction_np = np.zeros_like(msk_np)
            for ind in range(img_np.shape[0]):
                slice_img = img_np[ind, :, :]
                x, y = slice_img.shape[0], slice_img.shape[1]
                pred = predict_labels(to_input_tensor(resize_to_input(slice_img)))
                if x != in_h or y != in_w:
                    # Nearest-neighbour back to the original size keeps labels discrete.
                    pred = zoom(pred, (x / in_h, y / in_w), order=0)
                prediction_np[ind] = pred
        else:  # single 2D image
            input_tensor = to_input_tensor(img_np)
            prediction_np = predict_labels(input_tensor)
            save_heatmaps(extract_attention_maps_func, save_attention_heatmaps_func,
                          attention_vis_dir, lambda: input_tensor, img_np, case_name,
                          "生成注意力热图失败")
            # Now gated on activation_vis_dir too (previously called with dir=None).
            save_heatmaps(extract_activation_maps_func, save_activation_heatmaps_func,
                          activation_vis_dir, lambda: input_tensor, img_np, case_name,
                          "生成激活热图失败")
        if prediction_vis_dir and save_prediction_comparison_func:
            try:
                save_prediction_comparison_func(
                    img_np, msk_np, prediction_np, case_name, prediction_vis_dir
                )
                logger.info(f"已保存预测对比图: {case_name}")
            except Exception as e:
                logger.warning(f"保存预测对比图失败 {case_name}: {e}")

    stime = time.time()
    model.eval()
    vis_requested = bool(prediction_vis_dir or attention_vis_dir or activation_vis_dir)
    with torch.no_grad():
        metric_list = 0.0
        for i_batch, data in enumerate(tqdm(test_loader)):
            img, msk, case_name = data['image'], data['label'], data['case_name'][0]
            # Prediction and metrics for the whole case.
            metric_i = test_single_volume(img, msk, model, classes=config.num_classes,
                                          patch_size=[in_h, in_w],
                                          test_save_path=test_save_path, case=case_name,
                                          z_spacing=config.z_spacing, val_or_test=val_or_test)
            metric_list += np.array(metric_i)
            # Guard the modulo: save_vis_every_n == 0 used to raise ZeroDivisionError.
            if vis_requested and save_vis_every_n and i_batch % save_vis_every_n == 0:
                try:
                    visualize_case(img, msk, case_name)
                except Exception as e:
                    logger.warning(f"可视化处理失败 {case_name}: {e}")
            case_mean = np.mean(metric_i, axis=0)
            logger.info('idx %d case %s mean_dice %f mean_hd95 %f' % (i_batch, case_name,
                                                                      case_mean[0], case_mean[1]))
    metric_list = metric_list / len(test_datasets)
    class_mean = np.mean(metric_list, axis=0)
    performance = class_mean[0]
    mean_hd95 = class_mean[1]
    # Per-class (per-organ) Dice and HD95; row i-1 of metric_list is class i.
    dice_per_organ = []
    hd95_per_organ = []
    for i in range(1, config.num_classes):
        dice_score = metric_list[i-1][0]
        hd95_score = metric_list[i-1][1]
        dice_per_organ.append(dice_score)
        hd95_per_organ.append(hd95_score)
        logger.info('Mean class %d mean_dice %f mean_hd95 %f' % (i, dice_score, hd95_score))
    etime = time.time()
    log_info = f'val epoch: {epoch}, mean_dice: {performance}, mean_hd95: {mean_hd95}, time(s): {etime-stime:.2f}'
    print(log_info)
    logger.info(log_info)
    return performance, mean_hd95, dice_per_organ, hd95_per_organ