本篇学习目标
这是《nnU-Net 0基础入门》系列的第 10 篇。上一篇我们学会了继承 nnUNetTrainer 写自定义 Trainer。本文继续往里走:修改 loss,并理解 deep supervision。
读完本文,你应该能够:
- 理解 nnU-Net v2 默认 loss 的大致组成。
- 知道普通类别分割和 region-based training 为什么会用不同 loss 组合。
- 理解 deep supervision 为什么会让 loss 接收多尺度输出。
- 在自定义 Trainer 中覆盖
_build_loss,实现一个可维护的 loss 变体。
1. 先理解 loss 在训练流程中的位置
loss 是训练时优化的目标函数。模型输出预测,loss 比较预测和标签之间的差异,再通过反向传播更新网络参数。
在 nnU-Net v2 中,loss 不是孤立存在的。它和标签格式、ignore label、region-based training、deep supervision、DDP 都有关。
flowchart TD
A[network output] --> D[loss]
B[segmentation target] --> D
C[deep supervision outputs] --> D
E
--> D
D --> F[backpropagation]
F --> G[optimizer step]
这也是为什么我们不建议随便把网上某个 loss 函数复制进来就用。你必须确认它能处理 nnU-Net 当前训练输出和标签格式。
2. nnU-Net v2 默认 loss 做了什么
根据当前官方 nnUNetTrainer.py 的 _build_loss 实现,默认逻辑可以概括为:
- 如果当前任务是 region-based training,使用
DC_and_BCE_loss。 - 否则使用
DC_and_CE_loss。 - Dice 部分默认使用
MemoryEfficientSoftDiceLoss。 - 如果启用 deep supervision,则用
DeepSupervisionWrapper包装 loss。
| 场景 | 默认组合 | 直观理解 |
|---|---|---|
| 普通多类别分割 | DC_and_CE_loss |
Dice 关注区域重叠,CE 关注逐像素/体素分类 |
| region-based training | DC_and_BCE_loss |
region 可能不是互斥类别,更适合 BCE 形式 |
CE 是 cross entropy,交叉熵,用于分类任务。BCE 是 binary cross entropy,二元交叉熵。Dice loss 来自 Dice 指标,强调预测区域和标签区域的重叠程度。
3. deep supervision 是什么
deep supervision 可以理解为:网络不只在最终最高分辨率输出上计算 loss,也在中间多个低分辨率输出上计算 loss。这样可以给网络更深层的特征提供训练信号。
在 nnU-Net v2 中,如果启用 deep supervision,网络输出通常不是一个 tensor,而是一组多尺度输出。官方 DeepSupervisionWrapper 会把同一个基础 loss 应用到多个输出上,并按权重求和。
total_loss =
w0 * loss(output_0, target_0)
+ w1 * loss(output_1, target_1)
+ w2 * loss(output_2, target_2)
+ ...
官方 Trainer 中,权重按分辨率降低而逐步减小,最后一个最低分辨率输出通常权重为 0 或非常小。这样做的直觉是:最高分辨率输出最重要,低分辨率输出主要提供辅助监督。
4. 一个最小 loss 修改:调整 Dice 和 CE 权重
下面示例演示如何在自定义 Trainer 中覆盖 _build_loss,把普通多类别分割中的 Dice 和 CE 权重改成 1.5 和 0.5。这个例子只用于说明结构,不能保证提高精度。
import numpy as np
import torch
from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class nnUNetTrainerDiceHeavy(nnUNetTrainer):
def _build_loss(self):
if self.label_manager.has_regions:
loss = DC_and_BCE_loss(
{},
{
"batch_dice": self.configuration_manager.batch_dice,
"do_bg": True,
"smooth": 1e-5,
"ddp": self.is_ddp,
},
weight_ce=0.5,
weight_dice=1.5,
use_ignore_label=self.label_manager.ignore_label is not None,
dice_class=MemoryEfficientSoftDiceLoss,
)
else:
loss = DC_and_CE_loss(
{
"batch_dice": self.configuration_manager.batch_dice,
"smooth": 1e-5,
"do_bg": False,
"ddp": self.is_ddp,
},
{},
weight_ce=0.5,
weight_dice=1.5,
ignore_label=self.label_manager.ignore_label,
dice_class=MemoryEfficientSoftDiceLoss,
)
if self.enable_deep_supervision:
deep_supervision_scales = self._get_deep_supervision_scales()
weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
if self.is_ddp and not self._do_i_compile():
weights[-1] = 1e-6
else:
weights[-1] = 0
weights = weights / weights.sum()
loss = DeepSupervisionWrapper(loss, weights)
return loss
训练时使用:
nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainerDiceHeavy --npz
注意,这个示例保留了官方处理 ignore label、region training、DDP 和 deep supervision 的基本逻辑,只改 Dice 与 CE/BCE 的权重。这是比较稳妥的修改方式。
5. 为什么不要忽略 ignore label 和 region
ignore label 指某些像素或体素在训练中不参与 loss 计算。例如标注不确定区域、裁剪边界或无效区域。官方 loss 实现会把 ignore label 对应区域从梯度中排除。
region-based training 指标签不是简单的互斥类别,而是由多个类别组合成 region。比如一个 region 可以表示“肿瘤整体”,另一个 region 表示“增强肿瘤”。这种情况下输出头和标签组织方式会不同。
如果你的自定义 loss 没有处理 ignore label 或 region,可能会出现:
- 无效区域参与训练,导致模型学到错误信号。
- region 标签和网络输出 shape 对不上。
- 训练能跑但指标异常。
6. 自定义 loss 的检查表
| 检查项 | 为什么重要 |
|---|---|
| 输入是否是 logits | nnU-Net 网络通常不在最后手动加 softmax/sigmoid,loss 内部处理非线性 |
| target shape 是否符合预期 | 普通 CE target 和 one-hot/region target 要求不同 |
| 是否支持 deep supervision | 启用 deep supervision 时输出和 target 是多尺度列表 |
| 是否处理 ignore label | 避免不确定区域影响训练 |
| 是否兼容 DDP | 多 GPU 训练中某些参数未被使用会导致错误 |
7. 实验设计建议
修改 loss 后,不要只看训练 loss 下降。你至少应该比较:
- 默认 Trainer 与自定义 Trainer 的同一 fold 验证 Dice。
- 多个 fold 上是否一致改善。
- 小目标类别是否真的变好,还是只提升了大器官类别。
- loss 曲线是否稳定,是否出现 NaN。
如果一个 loss 只在 fold 0 提升,其他 fold 没有提升,不能轻易说它优于默认 loss。医学图像数据集通常样本少,单次划分结果波动很常见。
8. 常见错误
| 现象 | 可能原因 | 建议 |
|---|---|---|
| loss 直接 NaN | 除零、log 输入不合法、学习率过高 | 先用默认 loss 对照,检查自定义 loss 数值稳定性 |
| shape mismatch | 没有处理 deep supervision 或 target 维度 | 打印 output 和 target 的类型与 shape |
| DDP 报 unused parameters | deep supervision 某些输出权重处理不当 | 参考官方最低分辨率权重处理方式 |
| 训练 loss 下降但 Dice 不升 | loss 优化目标和评价指标不一致,或数据/标签问题 | 检查验证指标和可视化结果,不只看 loss |
9. 官方资料入口
本文主要参考:
- nnUNetTrainer.py 官方源码
- compound_losses.py 官方源码
- deep_supervision.py 官方源码
- Extending nnU-Net
- Region-based Training
- Ignore Label
本篇总结
nnU-Net v2 默认 loss 不是单一 Dice,而是根据任务类型组合 Dice、CE 或 BCE,并在启用 deep supervision 时用 DeepSupervisionWrapper 包装。修改 loss 时,推荐继承 nnUNetTrainer 并覆盖 _build_loss,尽量保留官方对 ignore label、region、DDP 和 deep supervision 的处理,只修改你真正要实验的部分。
下一篇预告
下一篇我们会修改 data augmentation:理解 nnU-Net v2 默认训练增强流程,哪些增强对医学图像有风险,以及如何在自定义 Trainer 中调整 transform。
Comments NOTHING