arXiv 2410.12557  ·  cs.LG / cs.CV  ·  2024

一步扩散生成:Shortcut Models(捷径模型)

One Step Diffusion via Shortcut Models
Kevin Frans  ·  Danijar Hafner  ·  Sergey Levine  ·  Pieter Abbeel

扩散模型需要数十至数百次神经网络前向传播才能生成样本,推理缓慢且计算代价高昂。 Shortcut Models 通过让网络同时感知噪声水平和目标步长,实现了在单个或少数步骤内生成高质量样本, 且无需单独的蒸馏阶段,仅需约 16% 的额外训练计算开销。

图像生成 · CelebA-HQ · ImageNet-256 机器人控制 · Push-T · Transport 训练额外开销仅 +16% 计算量 📄 arXiv:2410.12557 PDF 全文
关键词one-step diffusionflow matchingshortcut models一步生成自一致性蒸馏扩散策略consistency distillation快速采样图像生成diffusion policy

01 动机

扩散模型和流匹配模型在图像生成中取得了显著成果,但其迭代去噪过程需要数十到数百次神经网络前向传播, 导致推理极为缓慢。现有的加速方案(如蒸馏、一致性模型)往往需要复杂的两阶段训练流程,或在少步生成时质量大幅下降。

"Sampling from these models involves iterative denoising over many neural network passes, making generation slow and expensive." ——论文引言
Shortcut Models 与 Flow Matching 在不同推理步数下的生成对比
图1:在不同推理步数预算下,Flow Matching(上)与 Shortcut Models(下)的生成质量对比。 每列使用相同的初始噪声。数据集:CelebA-HQ(左)和 ImageNet-256(右)。 可以看出,Flow Matching 在 4 步或 1 步时质量急剧下降,而 Shortcut Models 在所有步数预算下都保持了较高质量。
1步仅需单次前向传播即可生成高质量样本
FID 20.5CelebA-HQ 1步生成(vs. Flow Matching 280.5)
FID 3.8DiT-XL 128步最优结果
+16%相比基础扩散模型的额外训练计算开销
朴素扩散模型在少步生成时失败的原因
图2:朴素扩散/流匹配模型在少步生成时失败的根本原因——训练路径在中间时间步交叉重叠, 导致模型输出趋向数据集均值("模糊"效应)。Shortcut Models 通过学习能"跳跃"到未来状态的方向向量来解决这一问题。

现有方法的痛点

  • 流匹配 / 扩散:在 4 步或 1 步推理时,FID 急剧退化(CelebA-HQ 1步:FID 280.5)
  • 一致性蒸馏 (Consistency Distillation):需要两阶段训练,先训基础模型再做蒸馏
  • Reflow:需要构建(噪声, 数据)配对数据集,额外 pipeline 复杂
  • 渐进式蒸馏 (Progressive Distillation):多轮迭代蒸馏,工程复杂度高

Shortcut Models 的核心主张

  • 单一网络、单阶段训练,无需蒸馏
  • 网络在任意步数预算下均可高质量采样
  • "自蒸馏":训练时自动从自身较优的中间结果中学习
  • 无需特殊训练时间表或预热策略(warmup)
  • 支持多步 FID 持续优化,可随模型规模扩展

02 方法

Shortcut Models 在标准流匹配框架基础上,对网络输入增加了"目标步长" d。 网络不仅预测当前时刻的去噪方向(速度场),还学习如何"跨越"多个小步, 直接预测从当前状态跳跃到 t − d 位置所需的归一化方向。 训练结合两个目标:流匹配损失(精细步骤监督)和自一致性目标(粗粒度步骤监督)。

Shortcut Models 训练流程总览
图3:训练流程总览。 左侧:标准流匹配,以小步长 d≈0 作为监督信号(绿色箭头); 右侧:自一致性目标,将一个大步长 d 的预测与两个连续的 d/2 步(通过停止梯度的"教师"网络计算)对齐(蓝色箭头)。 两路损失在同一网络、同一训练阶段中联合优化。

核心思想:Shortcut 函数

定义 shortcut 函数 v(x, t, d):给定当前状态 x、噪声时间步 t 和目标步长 d, 输出一个方向向量,使得沿该方向走一步即可到达 t − d 时刻对应的状态。

x_{t-d} = x_t + d · v(x_t, t, d)

当 d → 0 时,v 退化为标准流匹配的速度场;当 d 较大时,v 学习如何跳过中间曲折路径, 预测考虑了"未来曲率"的综合方向。

自一致性训练目标 (Self-Consistency Target)

训练的核心约束来自自一致性:一个步长为 d 的大步等价于两个连续的步长 d/2 的小步。 即:

v(x_t, t, d) ≈ 平均[v(x_t, t, d/2) 和 v(x_{t−d/2}, t−d/2, d/2)]

训练时,用当前网络的停止梯度(stop-gradient)版本计算两个 d/2 步的目标, 作为对大步长预测的监督信号。这使得模型能够从自身的中间预测中"自举"(bootstrapping), 无需外部教师网络,实现了训练时的自蒸馏。

推理策略

推理时,步数 N 可以任意选择:

  • 1步:d=1,单次前向传播直接生成
  • 4步:d=0.25,四次前向传播
  • 128步:d≈0,接近标准流匹配质量

同一个训练好的模型可以以任意步数运行,无需针对特定步数重新训练。

实现细节

  • 基础架构:DiT-B(图像)/ MLP(机器人控制)
  • 大模型实验:DiT-XL on ImageNet-256
  • Classifier-Free Guidance (CFG):在 d=0 时使用;大步长时不使用(线性近似不适用)
  • 训练额外开销约 16%(自一致性目标的额外前向传播)

03 实验

实验在两类任务上验证:(1)图像生成——CelebA-HQ-256(无条件)和 ImageNet-256(类别条件), 使用 FID-50k 指标(越低越好);(2)机器人控制——Push-T 和 Transport 任务,使用成功率评估。 基线包括:Progressive Distillation、Consistency Distillation、Reflow、Consistency Training、Live Reflow、标准 Flow Matching 等。

主要结果:图像生成 FID(FID-50k,越低越好)

括号内为降质严重、质量不具竞争力的结果。Shortcut Models 各列均标注。

方法 CelebA-HQ-256(无条件) ImageNet-256(类别条件)
128步4步1步 128步4步1步
Progressive Distillation(302.9)(251.3)14.8(201.9)(142.5)35.6
Consistency Distillation59.539.638.2132.898.01136.5
Reflow16.118.423.216.932.844.8
Flow Matching (DiT-B)7.3(63.3)(280.5)17.3(108.2)(324.8)
Consistency Training53.719.033.242.843.069.7
Live Reflow6.327.243.346.395.858.1
Shortcut Models (本文) 6.913.820.5 15.528.340.3

Shortcut Models 在 4 步和 1 步设置下优于所有单阶段端到端方法(Consistency Training、Live Reflow 等), 且在多步设置下与两阶段蒸馏方法相当甚至更优。

大模型扩展:DiT-XL on ImageNet-256

模型128步 FID4步 FID1步 FID
Shortcut Models (DiT-XL)3.87.810.6
不同步数下的 FID 对比
图4:不同推理步数下各方法的 FID 对比。Flow Matching 在步数减少时性能急剧退化, 而 Shortcut Models 在少步至多步全范围内均保持稳定的生成质量分布。
模型规模与生成质量的扩展关系
图5:随着模型参数量增加,Shortcut Models 的 1 步生成质量持续提升, 表明该方法具备良好的模型规模可扩展性(scaling)。这与部分基于 bootstrapping 的强化学习方法不同。

生成样本质量展示

CelebA-HQ 无条件生成样本
图8:CelebA-HQ 256×256 无条件生成样本,从上至下分别为 128 步、4 步、1 步。
ImageNet-256 类别条件生成样本
图9:ImageNet-256 类别条件生成样本,从上至下分别为 128 步、4 步、1 步。

机器人控制任务

将 Shortcut Models 应用于扩散策略(Diffusion Policy)的机器人控制任务, 在 Push-T 和 Transport 两个连续控制基准上与标准 100 步扩散策略对比:

方法步数Push-T 成功率Transport 成功率
Diffusion Policy(基线)100步~0.85~0.80
Diffusion Policy(基线)1步0.120.00
Shortcut Models(本文)1步 0.870.80

Shortcut 策略在仅用 1 步推理时,达到与 100 步基线相当的成功率, 而标准扩散策略在 1 步时几乎完全失效(Transport 成功率为 0.00)。

机器人控制任务结果
图7:Push-T 和 Transport 机器人任务的成功率曲线。Shortcut 策略(橙色)在所有步数下均优于或持平于 Diffusion Policy(蓝色), 在 1 步时差距最为显著。

潜空间插值

噪声空间插值
图6:在两个采样噪声点之间进行插值,生成的图像呈现出"qualitatively smooth transitions"(论文原文), 语义平滑过渡。这表明 Shortcut Models 学到了结构化的生成式潜空间。

消融分析

论文通过消融实验验证了自一致性训练目标的必要性:去掉自一致性损失后,模型仅能在多步模式下正常工作, 在 1 步时质量退化到 Flow Matching 同等水平。 此外,实验表明不需要特别设计的训练时间表(schedule)或预热(warmup),训练过程稳定。

04 局限性

说明:以下局限性中,前两条为论文第 6 节(讨论)中作者明确陈述的;第三条及后续为根据论文设计推断(推断 / inferred)。
噪声-数据映射依赖于数据集期望(作者明确陈述)

"The mapping between noise and data is entirely dependent on an expectation over the dataset."(论文原文) 与 GAN 或 VAE 不同,Shortcut Models 无法对噪声到数据的映射做独立的干预或调整, 生成的多样性上限受限于训练数据的分布。

多步与单步生成质量仍存在差距(作者明确陈述)

"In our shortcut model implementation there remains a gap between many-step generation quality and one-step generation quality."(论文原文) 尽管 Shortcut Models 的差距远小于标准扩散模型,但 1 步的 FID 仍明显高于 128 步, 完全消除该差距仍是开放问题。

Classifier-Free Guidance (CFG) 仅适用于小步长(作者明确陈述)

论文指出,CFG 在大步长时不能直接使用,因为"linear approximation is not appropriate"(论文原文)。 实现中只在 d=0(极小步长)时使用 CFG,大步长时必须放弃 CFG 加成, 这在一定程度上限制了 1 步生成的可控性。 此外,CFG 的比例(scale)需要在训练前指定,不能在推理时灵活调整。

自一致性 bootstrapping 存在累积误差(推断 / inferred)

训练中的自一致性目标使用当前网络的停止梯度版本作为"教师"。若当前网络在某些状态下预测质量差, 则自举目标本身也可能含有噪声,形成累积误差——这是所有 bootstrapping 方法共有的理论风险, 论文中对此未作详细分析。

仅在较小规模数据集上充分验证(推断 / inferred)

主要实验集中在 CelebA-HQ-256 和 ImageNet-256 两个数据集。 在更大规模(如 LAION-级别)或更复杂的条件生成(如文本到图像)场景下的性能表现, 论文未作系统性验证。