NeurIPS 2024 · Spotlight

Diffusion for World Modeling: Visual Details Matter in Atari

用扩散模型构建世界模型,视觉细节决定智能体表现
Eloi Alonso, Adam Jelley, Vincent Micheli, Anssi Kanervisto, Amos Storkey, Tim Pearce, François Fleuret

DIAMOND(DIffusion As a Model Of eNvironment Dreams)是首个在扩散世界模型中训练的强化学习智能体。与离散隐变量压缩方案不同,DIAMOND 在像素连续空间中直接建模环境动态,保留了对策略决策至关重要的视觉细节,在 Atari 100k 基准上取得了 1.46 的平均人类归一化分数,创下世界模型内训练智能体的最优记录。

Atari 100k · 26 games Mean HNS: 1.46 (SOTA) NeurIPS 2024 Spotlight 📄 arXiv:2405.12399 Project Page
diffusion world model reinforcement learning Atari 100k EDM visual fidelity 世界模型 扩散模型 在线强化学习

01 动机(Motivation)

现有世界模型普遍依赖离散隐变量进行压缩,这种有损压缩会丢弃对策略决策至关重要的视觉信息。作者认为,小的视觉细节——例如远处的行人或交通灯——可能直接改变智能体的行为,而离散编码无法忠实还原这些细节。

"Discrete latent representations involve lossy compression that may discard visual information crucial for learning. Small details in the visual input, such as a traffic light or a pedestrian in the distance, may change the policy of an agent."
DIAMOND imagination procedure
Figure 1. DIAMOND 的想象展开过程。顶行展示策略 πϕ 在学习到的扩散世界模型 Dθ 的"想象"中执行动作序列。横轴为环境时间 t,纵轴为扩散的去噪时间 τ(从 T 到 0 逆向流动)。每一帧的生成都经历一个完整的去噪过程,保留丰富的像素级细节。
1.46Mean HNS(Atari 100k)
世界模型训练智能体最优
0.64IQM 得分
超越 STORM (0.636)
11/26超越人类水平的游戏数量
3 NFE每帧去噪步数
vs. IRIS 的 16 NFE

扩散模型天然具备在连续空间建模、灵活捕捉多模态分布、无模式崩塌等优势。核心问题在于:能否将扩散模型高效、稳定地用于在线强化学习的世界模型,从而让智能体在高质量的想象轨迹中学习?

02 方法(Method)

DIAMOND 将扩散模型作为环境动态模型,在像素空间直接预测下一帧图像。智能体在世界模型生成的"梦境轨迹"中通过 REINFORCE 算法学习策略,无需访问真实环境(除数据收集阶段外)。

EDM 扩散框架选择

关键设计决策是采用 EDM(Elucidating the Design Space of Diffusion-Based Generative Models)而非 DDPM 框架。两者的核心差异在于预训练目标:

DDPM vs EDM trajectory stability
Figure 3. 基于 DDPM(左)与 EDM(右)的扩散世界模型在 Breakout 游戏上的想象轨迹对比,从同一初始帧出发,每行对应不同的去噪步数 n。DDPM 出现严重的误差累积,世界模型迅速漂移出分布;EDM 即使在单步去噪下也表现出长时序稳定性。

网络架构与条件化

动态模型采用标准 U-Net 2D 架构,将历史 L 帧图像观测与动作拼接为条件输入:

多步去噪与多模态分布

采用 n=3 步 Euler 方法去噪。单步采样在环境状态多模态(如 Boxing 中黑色选手移动方向不可预测)时产生模糊预测;多步采样将生成"驱向特定模式",产生清晰图像。

Single vs multi-step sampling in Boxing
Figure 4. Boxing 游戏中单步(顶行)与多步(底行)采样对比。黑色选手的移动方向不可预测,单步去噪在多种可能结果间插值,导致模糊预测;多步采样则生成清晰图像,"driving the generation towards a particular mode"。

RL 智能体架构

奖励与终止模型采用独立的 CNN-LSTM 网络;Actor-Critic 共用一个 CNN-LSTM 主干,分别接策略头与价值头。策略学习使用带价值基线的 REINFORCE 算法,价值更新使用 Bellman 误差与 λ-returns。

03 实验(Experiments)

在 Atari 100k 基准(26 款游戏,每款 5 个随机种子,每种子 100k 帧真实交互)上与 SimPLe、IRIS、DreamerV3、TWM、STORM 等世界模型基线对比。计算资源:每款游戏约 2.9 天(单张 Nvidia RTX 4090),共约 1.03 GPU 年。

Mean and IQM human normalized scores
Figure 2. 各智能体在 Atari 100k 上的 Mean HNS 与 IQM 得分(含置信区间)。DIAMOND(蓝色)获得 Mean HNS 1.46、IQM 0.64,超越所有世界模型内训练的智能体。
智能体 Mean HNS (↑) IQM (↑) 超人类游戏数 (↑)
SimPLe0.3320.1301
TWM0.9560.4598
IRIS1.0460.50110
DreamerV31.0970.4979
STORM1.2660.63610
DIAMOND(本文)1.4590.64111

IRIS vs. DIAMOND 视觉一致性对比

定性分析展示了 IRIS 在连续帧中出现的视觉不一致现象,而 DIAMOND 无此问题:

IRIS inconsistencies
IRIS 生成的连续帧:白框标注处可见敌人与奖励显示混淆、砖块/分数渲染不一致等错误。
DIAMOND consistency
DIAMOND 生成的连续帧:无上述不一致现象,视觉细节保持忠实连贯。

CS:GO 扩展实验

为验证方法的可扩展性,作者在 Counter-Strike: Global Offensive 的 Dust II 地图上训练了大规模扩散世界模型:87 小时(5M 帧)训练数据,动态模型参数量从 4M 扩展至 381M(含 51M 上采样器),训练耗时 12 天(RTX 4090),推理运行于 10Hz(RTX 3090)。

CS:GO world model gameplay
Figure 6. 玩家用键鼠在 DIAMOND 扩散世界模型内部游玩 CS:GO Dust II 地图的截图。该模型仅基于静态游戏录像训练,可生成"stable trajectories over hundreds of timesteps",展现了扩散世界模型在复杂真实场景中的可扩展性。

消融实验

核心消融结论:

04 局限性(Limitations)

说明:以下局限性均为论文 Section 8 中作者明确陈述(stated),非推断。
仅评估离散动作控制,未扩展至连续控制域

本文实验集中在 Atari 等离散动作游戏,尚未在连续控制任务(如 MuJoCo、机器人操控)上验证扩散世界模型的有效性。

帧堆叠提供的记忆机制十分有限

当前架构通过拼接过去 L 帧(frame stacking)提供历史信息,作者承认这"provides a minimal mechanism for memory"。基于 Transformer 的方法(如 STORM)在长时依赖建模上具有潜在优势,未来引入 Transformer 历史编码器可进一步提升性能。

奖励与终止预测独立于扩散模型,架构未完全统一

奖励和终止信号由独立的 CNN-LSTM 网络预测,而非由扩散动态模型统一建模。作者认为将奖励/终止集成进扩散模型"would make our world model unnecessarily complex",因此保持了模块化设计,但这也意味着整体架构并非端到端统一。

计算成本较高,不适合资源受限场景

每款 Atari 游戏的完整训练需要约 2.9 天(RTX 4090),26 款游戏合计约 1.03 GPU 年。扩散模型的逐步去噪推理比离散编码方案计算密集,尽管 3 NFE/帧已是较大优化,CS:GO 大规模版本(381M 参数)更需 12 天训练。