本篇学习目标
这是《nnU-Net 0基础入门》系列的第 9 篇。前一篇我们建立了 nnU-Net v2 的内部框架地图。本文开始真正修改 nnU-Net:从自定义 Trainer 入手。
读完本文,你应该能够:
- 理解为什么修改训练流程时推荐继承
nnUNetTrainer。 - 写出一个最小自定义 Trainer。
- 用
-tr参数让nnUNetv2_train调用你的 Trainer。 - 知道哪些改动适合放在 Trainer,哪些不应该直接硬改核心源码。
1. 为什么从 Trainer 开始改
官方扩展文档明确建议:如果你想修改 training procedure,也就是训练流程,例如 loss、sampling、data augmentation、lr scheduler 等,应实现自己的 trainer class。最佳实践是创建一个继承 nnUNetTrainer 的类,并只覆盖你需要修改的方法。
这样做有三个好处:
- 保留 nnU-Net v2 已经实现好的数据加载、日志、checkpoint、验证和推理兼容逻辑。
- 你的实验改动集中在一个新类里,便于回滚和复现。
- 训练命令中可以通过
-tr明确记录使用了哪个 Trainer。
2. 不推荐直接改核心源码
很多初学者会直接打开 nnUNetTrainer.py 修改几行代码。这种方式短期看最快,但问题很大:
- 以后更新 nnU-Net 时容易被覆盖。
- 无法清楚区分“官方默认行为”和“你自己的实验行为”。
- 复现实验时很难说明到底改了哪里。
- 和别人共享 checkpoint 时,对方可能无法推理或继续训练。
更稳妥的方式是:新建一个 Trainer 类,继承官方 nnUNetTrainer,只覆盖你需要的部分。
3. Trainer 修改入口图
flowchart TD
A[nnUNetv2_train] --> B[-tr 指定 Trainer 名称]
B --> C[查找 Trainer 类]
C --> D[实例化自定义 Trainer]
D --> E[initialize]
E --> F[build_network_architecture]
E --> G[_build_loss]
E --> H[configure_optimizers]
E --> I[get_training_transforms]
D --> J[run_training]
这张图说明:Trainer 不是只负责一个函数,而是训练流程的组织者。你可以覆盖 loss、optimizer、augmentation、network architecture 等不同入口,但每次改动都应尽量小。
4. 一个最小自定义 Trainer:把优化器换成 AdamW
下面示例演示如何写一个自定义 Trainer,把默认优化器替换为 AdamW。这个例子只是教学用,不能保证比默认 SGD 更好。医学图像分割实验中,任何优化器改动都必须用验证集结果证明。
import torch
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class nnUNetTrainerMyAdamW(nnUNetTrainer):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
super().__init__(plans, configuration, fold, dataset_json, device)
self.initial_lr = 3e-4
self.weight_decay = 1e-4
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.network.parameters(),
lr=self.initial_lr,
weight_decay=self.weight_decay,
)
lr_scheduler = PolyLRScheduler(
optimizer,
self.initial_lr,
self.num_epochs,
)
return optimizer, lr_scheduler
这里覆盖的是 configure_optimizers。官方当前 nnUNetTrainer 默认使用 SGD、momentum、Nesterov 和 poly learning rate scheduler。我们没有改数据加载、loss、augmentation 和 checkpoint 行为。
5. 把 Trainer 放在哪里
官方训练入口会根据 -tr 参数查找 Trainer 类。当前源码中,查找顺序包括:
nnunetv2/training/nnUNetTrainer下的 Python 文件。- 如果设置了
nnUNet_extTrainer环境变量,也会搜索该外部路径。
如果你使用 editable install,可以把文件放到源码目录中,例如:
nnUNet/
└── nnunetv2/
└── training/
└── nnUNetTrainer/
└── variants/
└── optimizer/
└── nnUNetTrainerMyAdamW.py
也可以放到外部目录,然后设置环境变量:
export nnUNet_extTrainer="/home/you/nnunet_custom_trainers"
外部目录中放入:
/home/you/nnunet_custom_trainers/
└── nnUNetTrainerMyAdamW.py
对于初学者,我建议先使用 editable install 并把 trainer 放在源码树的 variants 目录下。等你要长期维护多个私有 Trainer,再考虑 nnUNet_extTrainer。
6. 用 -tr 调用自定义 Trainer
训练时使用 -tr 指定类名:
nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainerMyAdamW --npz
如果你还使用自定义 plans identifier,可以同时用 -p 指定:
nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainerMyAdamW -p nnUNetPlans --npz
训练输出目录会包含 Trainer 名称,例如:
nnUNet_results/
└── Dataset001_Liver/
└── nnUNetTrainerMyAdamW__nnUNetPlans__3d_fullres/
└── fold_0/
这非常重要。目录名本身就记录了你使用的 Trainer、plans 和 configuration,有利于实验管理。
7. 调试自定义 Trainer 的推荐顺序
不要一上来就训练完整 1000 epochs。建议先用官方已有短训练 Trainer 或你自己的 debug Trainer 跑最小验证。例如官方仓库里有 nnUNetTrainer_5epochs 这类训练长度变体,适合快速检查流程。
对自己的 Trainer,可以先这样做:
# 先确认类能被找到
nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainerMyAdamW --disable_checkpointing
# 正式训练时再加 --npz
nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainerMyAdamW --npz
--disable_checkpointing 适合测试,避免调试时产生大量 checkpoint。正式实验不要随意禁用 checkpoint,否则中断后很难恢复。
8. 常见错误
| 错误现象 | 常见原因 | 解决方式 |
|---|---|---|
| 找不到 Trainer | 类名和 -tr 不一致,或文件不在搜索路径中 |
检查类名、文件名、editable install 和 nnUNet_extTrainer |
| Trainer 不是子类 | 没有继承 nnUNetTrainer |
确认 class MyTrainer(nnUNetTrainer) |
| 训练能跑,推理报错 | 推理环境找不到同一个自定义 Trainer | 保证推理机器也能 import 该 Trainer |
| 改了很多地方后不知道哪里坏了 | 一次改动范围太大 | 每次只改一个入口,例如先只改 optimizer |
9. 哪些内容适合在 Trainer 中改
| 目标 | 推荐入口 |
|---|---|
| 换 optimizer 或 lr scheduler | configure_optimizers |
| 换 loss | _build_loss |
| 改 augmentation | get_training_transforms 或相关 transform 构建逻辑 |
| 改网络结构 | build_network_architecture,并处理 deep supervision 兼容 |
| 改训练 epoch、学习率等超参数 | __init__ 中修改对应属性 |
10. 官方资料入口
本文主要参考:
- Extending nnU-Net
- Train Models
- run_training.py 官方源码
- find_objects.py 官方源码
- nnUNetTrainer.py 官方源码
- 官方 Trainer 变体示例
本篇总结
修改 nnU-Net v2 训练流程的推荐方式是继承 nnUNetTrainer,创建自己的 Trainer 类,然后在训练命令中用 -tr 指定。这样既能复用官方完整训练框架,又能让实验改动清晰、可复现、可回滚。本文用 AdamW 示例演示了最小修改方式,后续我们会继续用 Trainer 修改 loss 和 augmentation。
下一篇预告
下一篇我们会专门修改 loss:理解 nnU-Net 默认的 Dice + CE / Dice + BCE 组合、deep supervision wrapper,以及怎样在自定义 Trainer 中替换或组合自己的 loss。
Comments NOTHING