《nnU-Net 0基础入门(9):修改 nnU-Net Trainer,从继承 nnUNetTrainer 开始》

503611908 发布于 3 小时前 4 次阅读


本篇学习目标

这是《nnU-Net 0基础入门》系列的第 9 篇。前一篇我们建立了 nnU-Net v2 的内部框架地图。本文开始真正修改 nnU-Net:从自定义 Trainer 入手。

读完本文,你应该能够:

  1. 理解为什么修改训练流程时推荐继承 nnUNetTrainer
  2. 写出一个最小自定义 Trainer。
  3. -tr 参数让 nnUNetv2_train 调用你的 Trainer。
  4. 知道哪些改动适合放在 Trainer,哪些不应该直接硬改核心源码。

1. 为什么从 Trainer 开始改

官方扩展文档明确建议:如果你想修改 training procedure,也就是训练流程,例如 loss、sampling、data augmentation、lr scheduler 等,应实现自己的 trainer class。最佳实践是创建一个继承 nnUNetTrainer 的类,并只覆盖你需要修改的方法。

这样做有三个好处:

  • 保留 nnU-Net v2 已经实现好的数据加载、日志、checkpoint、验证和推理兼容逻辑。
  • 你的实验改动集中在一个新类里,便于回滚和复现。
  • 训练命令中可以通过 -tr 明确记录使用了哪个 Trainer。

2. 不推荐直接改核心源码

很多初学者会直接打开 nnUNetTrainer.py 修改几行代码。这种方式短期看最快,但问题很大:

  • 以后更新 nnU-Net 时容易被覆盖。
  • 无法清楚区分“官方默认行为”和“你自己的实验行为”。
  • 复现实验时很难说明到底改了哪里。
  • 和别人共享 checkpoint 时,对方可能无法推理或继续训练。

更稳妥的方式是:新建一个 Trainer 类,继承官方 nnUNetTrainer,只覆盖你需要的部分。

3. Trainer 修改入口图

flowchart TD
    A[nnUNetv2_train] --> B[-tr 指定 Trainer 名称]
    B --> C[查找 Trainer 类]
    C --> D[实例化自定义 Trainer]
    D --> E[initialize]
    E --> F[build_network_architecture]
    E --> G[_build_loss]
    E --> H[configure_optimizers]
    E --> I[get_training_transforms]
    D --> J[run_training]

这张图说明:Trainer 不是只负责一个函数,而是训练流程的组织者。你可以覆盖 loss、optimizer、augmentation、network architecture 等不同入口,但每次改动都应尽量小。

4. 一个最小自定义 Trainer:把优化器换成 AdamW

下面示例演示如何写一个自定义 Trainer,把默认优化器替换为 AdamW。这个例子只是教学用,不能保证比默认 SGD 更好。医学图像分割实验中,任何优化器改动都必须用验证集结果证明。

import torch

from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


class nnUNetTrainerMyAdamW(nnUNetTrainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, device)
        self.initial_lr = 3e-4
        self.weight_decay = 1e-4

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.network.parameters(),
            lr=self.initial_lr,
            weight_decay=self.weight_decay,
        )
        lr_scheduler = PolyLRScheduler(
            optimizer,
            self.initial_lr,
            self.num_epochs,
        )
        return optimizer, lr_scheduler

这里覆盖的是 configure_optimizers。官方当前 nnUNetTrainer 默认使用 SGD、momentum、Nesterov 和 poly learning rate scheduler。我们没有改数据加载、loss、augmentation 和 checkpoint 行为。

5. 把 Trainer 放在哪里

官方训练入口会根据 -tr 参数查找 Trainer 类。当前源码中,查找顺序包括:

  1. nnunetv2/training/nnUNetTrainer 下的 Python 文件。
  2. 如果设置了 nnUNet_extTrainer 环境变量,也会搜索该外部路径。

如果你使用 editable install,可以把文件放到源码目录中,例如:

nnUNet/
└── nnunetv2/
    └── training/
        └── nnUNetTrainer/
            └── variants/
                └── optimizer/
                    └── nnUNetTrainerMyAdamW.py

也可以放到外部目录,然后设置环境变量:

export nnUNet_extTrainer="/home/you/nnunet_custom_trainers"

外部目录中放入:

/home/you/nnunet_custom_trainers/
└── nnUNetTrainerMyAdamW.py

对于初学者,我建议先使用 editable install 并把 trainer 放在源码树的 variants 目录下。等你要长期维护多个私有 Trainer,再考虑 nnUNet_extTrainer

6. 用 -tr 调用自定义 Trainer

训练时使用 -tr 指定类名:

nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainerMyAdamW --npz

如果你还使用自定义 plans identifier,可以同时用 -p 指定:

nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainerMyAdamW -p nnUNetPlans --npz

训练输出目录会包含 Trainer 名称,例如:

nnUNet_results/
└── Dataset001_Liver/
    └── nnUNetTrainerMyAdamW__nnUNetPlans__3d_fullres/
        └── fold_0/

这非常重要。目录名本身就记录了你使用的 Trainer、plans 和 configuration,有利于实验管理。

7. 调试自定义 Trainer 的推荐顺序

不要一上来就训练完整 1000 epochs。建议先用官方已有短训练 Trainer 或你自己的 debug Trainer 跑最小验证。例如官方仓库里有 nnUNetTrainer_5epochs 这类训练长度变体,适合快速检查流程。

对自己的 Trainer,可以先这样做:

# 先确认类能被找到
nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainerMyAdamW --disable_checkpointing

# 正式训练时再加 --npz
nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainerMyAdamW --npz

--disable_checkpointing 适合测试,避免调试时产生大量 checkpoint。正式实验不要随意禁用 checkpoint,否则中断后很难恢复。

8. 常见错误

错误现象 常见原因 解决方式
找不到 Trainer 类名和 -tr 不一致,或文件不在搜索路径中 检查类名、文件名、editable install 和 nnUNet_extTrainer
Trainer 不是子类 没有继承 nnUNetTrainer 确认 class MyTrainer(nnUNetTrainer)
训练能跑,推理报错 推理环境找不到同一个自定义 Trainer 保证推理机器也能 import 该 Trainer
改了很多地方后不知道哪里坏了 一次改动范围太大 每次只改一个入口,例如先只改 optimizer

9. 哪些内容适合在 Trainer 中改

目标 推荐入口
换 optimizer 或 lr scheduler configure_optimizers
换 loss _build_loss
改 augmentation get_training_transforms 或相关 transform 构建逻辑
改网络结构 build_network_architecture,并处理 deep supervision 兼容
改训练 epoch、学习率等超参数 __init__ 中修改对应属性

10. 官方资料入口

本文主要参考:

本篇总结

修改 nnU-Net v2 训练流程的推荐方式是继承 nnUNetTrainer,创建自己的 Trainer 类,然后在训练命令中用 -tr 指定。这样既能复用官方完整训练框架,又能让实验改动清晰、可复现、可回滚。本文用 AdamW 示例演示了最小修改方式,后续我们会继续用 Trainer 修改 loss 和 augmentation。

下一篇预告

下一篇我们会专门修改 loss:理解 nnU-Net 默认的 Dice + CE / Dice + BCE 组合、deep supervision wrapper,以及怎样在自定义 Trainer 中替换或组合自己的 loss。

此作者没有提供个人介绍。
最后更新于 2026-05-14