《nnU-Net 0基础入门(12):修改 network architecture 与 plans,从 ResEnc preset 到自定义网络》

503611908 发布于 5 小时前 10 次阅读


本篇学习目标

这是《nnU-Net 0基础入门》系列的第 12 篇,也是本系列最后一篇。前面我们已经学会安装、数据准备、训练、推理、模型选择、Trainer、loss 和 augmentation。本文进入最容易“改坏”的部分:network architecture 和 plans。

读完本文,你应该能够:

  1. 理解为什么网络结构不能脱离 plans 单独修改。
  2. 知道 quick-and-dirty Trainer 覆盖路线和 proper planner 路线的区别。
  3. 了解官方 ResEnc presets 的用法和显存预算。
  4. 知道替换网络时必须检查 deep supervision、patch size、输入输出通道和推理兼容。

1. 为什么 network architecture 不能随便换

network architecture 指网络结构,例如 U-Net、ResEnc U-Net、Transformer/Mamba 风格网络等。很多人接触 nnU-Net 后第一反应是:“我能不能把里面的网络换成自己的模型?”答案是可以,但要非常谨慎。

原因是 nnU-Net v2 的网络不是孤立模块。它和 plans 中的很多配置强相关:

  • patch_size 决定网络输入空间大小。
  • num_input_channels 来自数据集通道数。
  • num_output_channels 来自 label manager,而不是简单等于类别数。
  • pool_op_kernel_sizesconv_kernel_sizes 影响 U-Net 下采样和上采样结构。
  • enable_deep_supervision 决定训练时是否需要多尺度输出。
flowchart TD
    A[nnUNetPlans.json] --> B[configuration_manager]
    B --> C[patch_size / kernels / architecture kwargs]
    C --> D[build_network_architecture]
    E[label_manager] --> F[num_segmentation_heads]
    F --> D
    G[num_input_channels] --> D
    D --> H[network]
    H --> I[training and inference]

所以,替换网络不是“改一行模型类名”这么简单。你必须保证训练、验证、推理都能用同一个构建逻辑重建网络。

2. 当前 build_network_architecture 签名

当前官方 nnUNetTrainer 中,推荐的 build_network_architecture 签名是:

@staticmethod
def build_network_architecture(
    plans_manager,
    configuration_manager,
    num_input_channels: int,
    num_output_channels: int,
    enable_deep_supervision: bool = True,
):
    ...

官方源码仍兼容旧签名,但会给出弃用警告。写新 Trainer 时应使用新签名。这个函数不仅训练时会调用,推理加载模型时也会调用。因此,如果你通过自定义 Trainer 改了网络结构,推理环境也必须能找到同一个 Trainer 类。

3. quick-and-dirty 路线:在 Trainer 里覆盖网络构建

官方扩展文档提供了一条 quick-and-dirty 路线:继承 Trainer 并覆盖 build_network_architecture。这种方式适合快速验证一个网络想法。

下面是一个只用于教学的最小 3D 网络示例。为了避免多尺度输出和 deep supervision 的复杂性,这个 Trainer 显式关闭 deep supervision。真实研究不要使用这个玩具网络作为强模型,它只是帮助你理解接口。

import torch
from torch import nn

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


class TinyExampleNet(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv3d(in_channels, 32, kernel_size=3, padding=1),
            nn.InstanceNorm3d(32),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(32, out_channels, kernel_size=1),
        )

    def forward(self, x):
        return self.net(x)


class nnUNetTrainerTinyNet(nnUNetTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.enable_deep_supervision = False

    @staticmethod
    def build_network_architecture(
        plans_manager,
        configuration_manager,
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ):
        return TinyExampleNet(
            num_input_channels,
            num_output_channels,
        )

这个示例只是为了说明接口:真实网络不能这么简单。它缺少 U-Net 的多尺度上下文建模,也没有 deep supervision 多尺度输出。因此它适合帮助你理解接线方式,不适合当成强模型使用。

训练命令:

nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainerTinyNet --npz

4. proper 路线:通过 planner 和 plans 集成网络

quick-and-dirty 路线适合快速实验,但如果你希望网络结构长期可维护、可复现、可和 nnU-Net 自动规划结合,应该考虑 proper 路线:通过自定义 planner 或 plans,让 architecture 信息进入 nnUNetPlans.json

proper 路线的核心思想是:

  • 让 plans 明确记录网络类名、初始化参数和需要 import 的参数。
  • 训练和推理都从 plans 中读取架构配置。
  • 不同架构使用不同 plans identifier,避免覆盖默认 nnU-Net 结果。
路线 适合场景 优点 风险
Trainer 覆盖 快速验证网络想法 上手快,改动集中 容易和 plans 脱节,长期维护差
Planner / plans 集成 正式方法、论文实验、长期维护 配置可追踪,训练推理一致 需要理解 plans 和 planner

5. 官方 ResEnc presets:先用官方强基线

在自己造网络之前,建议先了解官方 Residual Encoder UNet presets。官方文档给出三种 preset:

Preset 官方目标显存 适合场景
ResEnc M 约 9-11GB VRAM 显存接近标准 U-Net 配置
ResEnc L 约 24GB VRAM 官方推荐的新默认配置
ResEnc XL 约 40GB VRAM 更大显存和算力预算

官方使用方式是先在 plan/preprocess 阶段指定 planner:

nnUNetv2_plan_and_preprocess -d 1 -pl nnUNetPlannerResEncM
nnUNetv2_plan_and_preprocess -d 1 -pl nnUNetPlannerResEncL
nnUNetv2_plan_and_preprocess -d 1 -pl nnUNetPlannerResEncXL

训练和推理时指定对应 plans:

nnUNetv2_train 1 3d_fullres 0 -p nnUNetResEncUNetMPlans --npz
nnUNetv2_predict -i ./input_images -o ./predictions -d 1 -c 3d_fullres -p nnUNetResEncUNetMPlans

如果你已经有标准 2D 或 3D fullres 预处理数据,官方文档说明可用 nnUNetv2_plan_experiment 避免重复 preprocessing;但初学者第一次使用时,直接按官方 preset 命令完整跑更不容易混淆。

6. scaling VRAM target 时不要覆盖默认 plans

官方 ResEnc preset 文档还说明,可以通过 -gpu_memory_target 调整目标显存预算。但必须用 -overwrite_plans_name 指定新的 plans 名称,避免覆盖 preset plans。

nnUNetv2_plan_experiment \
  -d 3 \
  -pl nnUNetPlannerResEncM \
  -gpu_memory_target 80 \
  -overwrite_plans_name nnUNetResEncUNetPlans_80G

后续训练用:

nnUNetv2_train 3 3d_fullres 0 -p nnUNetResEncUNetPlans_80G --npz

如果多 GPU 训练,不要简单把多张 GPU 的显存相加传给 planner。官方文档提醒,patch size 必须能被单张 GPU 处理。多 GPU 通常应先按单 GPU 显存规划,再通过 batch size 扩展。

7. 替换网络前的硬性检查

检查项 必须满足什么
输入通道 网络接受 num_input_channels
输出通道 网络输出 num_output_channels,不要自己猜类别数
空间维度 能处理当前 patch_size
deep supervision 训练启用时输出格式要与 loss wrapper 兼容
推理兼容 推理环境能 import 同一个 Trainer 或 architecture 类
显存 单张 GPU 能处理 patch 和 batch

8. benchmarking:别只和弱基线比

官方 ResEnc preset 文档最后特别提醒:如果你提出一种新的 segmentation method,应使用合适的 nnU-Net baseline 做公平比较,并鼓励和 residual encoder variants 比较。换句话说,不要只和一个很弱或未调好的 U-Net 比,然后宣称新方法有效。

对研究实验来说,至少要做到:

  • 比较默认 nnU-Net 与 ResEnc preset。
  • 显存和训练时间预算尽量公平。
  • 使用相同数据划分和相同评价指标。
  • 报告多个 fold,而不是只挑一个最好结果。

9. 官方资料入口

本文主要参考:

本篇总结

修改 network architecture 是 nnU-Net v2 二次开发中风险最高的部分。快速实验可以通过自定义 Trainer 覆盖 build_network_architecture;长期维护和正式研究更应通过 planner 与 plans 集成。正式提出新网络前,建议先尝试官方 ResEnc presets,并用公平的显存、训练时间和 5-fold 验证结果做比较。

系列总结

到这里,《nnU-Net 0基础入门》系列完成了从 0 到进阶修改的完整路径:你已经学习了 nnU-Net v2 的定位、安装、数据格式、训练、推理、模型选择、内部框架,以及如何修改 Trainer、loss、augmentation 和 network architecture。下一步不再是继续堆教程,而是选择一个真实数据集,建立默认 nnU-Net 强基线,然后只改一个变量,做可复现的对照实验。

此作者没有提供个人介绍。
最后更新于 2026-05-14