《nnU-Net 0基础入门(10):修改 loss 与 deep supervision,从 Dice+CE 到自定义损失》

503611908 发布于 3 小时前 2 次阅读


本篇学习目标

这是《nnU-Net 0基础入门》系列的第 10 篇。上一篇我们学会了继承 nnUNetTrainer 写自定义 Trainer。本文继续往里走:修改 loss,并理解 deep supervision。

读完本文,你应该能够:

  1. 理解 nnU-Net v2 默认 loss 的大致组成。
  2. 知道普通类别分割和 region-based training 为什么会用不同 loss 组合。
  3. 理解 deep supervision 为什么会让 loss 接收多尺度输出。
  4. 在自定义 Trainer 中覆盖 _build_loss,实现一个可维护的 loss 变体。

1. 先理解 loss 在训练流程中的位置

loss 是训练时优化的目标函数。模型输出预测,loss 比较预测和标签之间的差异,再通过反向传播更新网络参数。

在 nnU-Net v2 中,loss 不是孤立存在的。它和标签格式、ignore label、region-based training、deep supervision、DDP 都有关。

flowchart TD
    A[network output] --> D[loss]
    B[segmentation target] --> D
    C[deep supervision outputs] --> D
    E
          
         --> D
    D --> F[backpropagation]
    F --> G[optimizer step]

这也是为什么我们不建议随便把网上某个 loss 函数复制进来就用。你必须确认它能处理 nnU-Net 当前训练输出和标签格式。

2. nnU-Net v2 默认 loss 做了什么

根据当前官方 nnUNetTrainer.py_build_loss 实现,默认逻辑可以概括为:

  • 如果当前任务是 region-based training,使用 DC_and_BCE_loss
  • 否则使用 DC_and_CE_loss
  • Dice 部分默认使用 MemoryEfficientSoftDiceLoss
  • 如果启用 deep supervision,则用 DeepSupervisionWrapper 包装 loss。
场景 默认组合 直观理解
普通多类别分割 DC_and_CE_loss Dice 关注区域重叠,CE 关注逐像素/体素分类
region-based training DC_and_BCE_loss region 可能不是互斥类别,更适合 BCE 形式

CE 是 cross entropy,交叉熵,用于分类任务。BCE 是 binary cross entropy,二元交叉熵。Dice loss 来自 Dice 指标,强调预测区域和标签区域的重叠程度。

3. deep supervision 是什么

deep supervision 可以理解为:网络不只在最终最高分辨率输出上计算 loss,也在中间多个低分辨率输出上计算 loss。这样可以给网络更深层的特征提供训练信号。

在 nnU-Net v2 中,如果启用 deep supervision,网络输出通常不是一个 tensor,而是一组多尺度输出。官方 DeepSupervisionWrapper 会把同一个基础 loss 应用到多个输出上,并按权重求和。

total_loss =
  w0 * loss(output_0, target_0)
+ w1 * loss(output_1, target_1)
+ w2 * loss(output_2, target_2)
+ ...

官方 Trainer 中,权重按分辨率降低而逐步减小,最后一个最低分辨率输出通常权重为 0 或非常小。这样做的直觉是:最高分辨率输出最重要,低分辨率输出主要提供辅助监督。

4. 一个最小 loss 修改:调整 Dice 和 CE 权重

下面示例演示如何在自定义 Trainer 中覆盖 _build_loss,把普通多类别分割中的 Dice 和 CE 权重改成 1.5 和 0.5。这个例子只用于说明结构,不能保证提高精度。

import numpy as np
import torch

from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


class nnUNetTrainerDiceHeavy(nnUNetTrainer):
    def _build_loss(self):
        if self.label_manager.has_regions:
            loss = DC_and_BCE_loss(
                {},
                {
                    "batch_dice": self.configuration_manager.batch_dice,
                    "do_bg": True,
                    "smooth": 1e-5,
                    "ddp": self.is_ddp,
                },
                weight_ce=0.5,
                weight_dice=1.5,
                use_ignore_label=self.label_manager.ignore_label is not None,
                dice_class=MemoryEfficientSoftDiceLoss,
            )
        else:
            loss = DC_and_CE_loss(
                {
                    "batch_dice": self.configuration_manager.batch_dice,
                    "smooth": 1e-5,
                    "do_bg": False,
                    "ddp": self.is_ddp,
                },
                {},
                weight_ce=0.5,
                weight_dice=1.5,
                ignore_label=self.label_manager.ignore_label,
                dice_class=MemoryEfficientSoftDiceLoss,
            )

        if self.enable_deep_supervision:
            deep_supervision_scales = self._get_deep_supervision_scales()
            weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
            if self.is_ddp and not self._do_i_compile():
                weights[-1] = 1e-6
            else:
                weights[-1] = 0
            weights = weights / weights.sum()
            loss = DeepSupervisionWrapper(loss, weights)

        return loss

训练时使用:

nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainerDiceHeavy --npz

注意,这个示例保留了官方处理 ignore label、region training、DDP 和 deep supervision 的基本逻辑,只改 Dice 与 CE/BCE 的权重。这是比较稳妥的修改方式。

5. 为什么不要忽略 ignore label 和 region

ignore label 指某些像素或体素在训练中不参与 loss 计算。例如标注不确定区域、裁剪边界或无效区域。官方 loss 实现会把 ignore label 对应区域从梯度中排除。

region-based training 指标签不是简单的互斥类别,而是由多个类别组合成 region。比如一个 region 可以表示“肿瘤整体”,另一个 region 表示“增强肿瘤”。这种情况下输出头和标签组织方式会不同。

如果你的自定义 loss 没有处理 ignore label 或 region,可能会出现:

  • 无效区域参与训练,导致模型学到错误信号。
  • region 标签和网络输出 shape 对不上。
  • 训练能跑但指标异常。

6. 自定义 loss 的检查表

检查项 为什么重要
输入是否是 logits nnU-Net 网络通常不在最后手动加 softmax/sigmoid,loss 内部处理非线性
target shape 是否符合预期 普通 CE target 和 one-hot/region target 要求不同
是否支持 deep supervision 启用 deep supervision 时输出和 target 是多尺度列表
是否处理 ignore label 避免不确定区域影响训练
是否兼容 DDP 多 GPU 训练中某些参数未被使用会导致错误

7. 实验设计建议

修改 loss 后,不要只看训练 loss 下降。你至少应该比较:

  • 默认 Trainer 与自定义 Trainer 的同一 fold 验证 Dice。
  • 多个 fold 上是否一致改善。
  • 小目标类别是否真的变好,还是只提升了大器官类别。
  • loss 曲线是否稳定,是否出现 NaN。

如果一个 loss 只在 fold 0 提升,其他 fold 没有提升,不能轻易说它优于默认 loss。医学图像数据集通常样本少,单次划分结果波动很常见。

8. 常见错误

现象 可能原因 建议
loss 直接 NaN 除零、log 输入不合法、学习率过高 先用默认 loss 对照,检查自定义 loss 数值稳定性
shape mismatch 没有处理 deep supervision 或 target 维度 打印 output 和 target 的类型与 shape
DDP 报 unused parameters deep supervision 某些输出权重处理不当 参考官方最低分辨率权重处理方式
训练 loss 下降但 Dice 不升 loss 优化目标和评价指标不一致,或数据/标签问题 检查验证指标和可视化结果,不只看 loss

9. 官方资料入口

本文主要参考:

本篇总结

nnU-Net v2 默认 loss 不是单一 Dice,而是根据任务类型组合 Dice、CE 或 BCE,并在启用 deep supervision 时用 DeepSupervisionWrapper 包装。修改 loss 时,推荐继承 nnUNetTrainer 并覆盖 _build_loss,尽量保留官方对 ignore label、region、DDP 和 deep supervision 的处理,只修改你真正要实验的部分。

下一篇预告

下一篇我们会修改 data augmentation:理解 nnU-Net v2 默认训练增强流程,哪些增强对医学图像有风险,以及如何在自定义 Trainer 中调整 transform。

此作者没有提供个人介绍。
最后更新于 2026-05-14