ICLR 2021 · Model-Based RL

Mastering Atari with Discrete World Models

DreamerV2:首个在独立世界模型中实现 Atari 人类水平的 RL 智能体
Danijar Hafner (Google Research)  ·  Timothy Lillicrap (DeepMind)  ·  Mohammad Norouzi (Google Research)  ·  Jimmy Ba (University of Toronto)

DreamerV2 通过在离散世界模型的紧凑潜空间内纯粹预测来学习行为,首次以单 GPU 在 Atari 55 款游戏基准上实现人类水平性能,同时超越顶级单 GPU model-free 算法 IQN 和 Rainbow。其核心创新在于以 categorical 离散隐变量替代 Gaussian 隐变量,并引入 KL balancing 技术稳定世界模型训练。

ICLR 2021 单 NVIDIA V100 · 10 天 55 Atari Games · 200M Frames 📄 arXiv:2010.02193 Project Page / Code
DreamerV2 world model discrete representations RSSM KL balancing latent imagination model-based RL categorical latents Atari benchmark 强化学习

01 动机 Motivation

世界模型理论上能使智能体从想象中的经历学习行为,从而大幅提升样本效率。然而,多年来在 Atari 这一最具竞争力的基准上,没有任何世界模型方法能与主流 model-free 算法竞争。核心问题是:如何让世界模型足够精确,从而完全在其预测中学习出成功的游戏策略?

"We introduce DreamerV2, the first reinforcement learning agent that achieves human-level performance on the Atari benchmark of 55 tasks by learning behaviors inside a separately trained world model."
DreamerV2 Atari benchmark results teaser
Figure 1(论文原图):DreamerV2 在 Atari 55 款游戏基准(sticky actions,200M steps)上的 gamer normalized median score 达到 2.15,超越 IQN(1.29)、Rainbow(1.47)及所有单 GPU model-free 算法,是首个以 model-based 方式达到人类水平的智能体。
2.15×Gamer Median Score(DreamerV2,200M frames)
0.28Clipped Record Mean(推荐的鲁棒评估指标)
10 天单 V100 GPU 完成 200M 环境步
468B世界模型内部想象的 compact states 数量

02 方法 Method

DreamerV2 由三部分组成:世界模型(World Model)从过去经验中学习环境的紧凑表征;Actor-Critic 在世界模型的想象轨迹中学习策略;环境交互不断扩充经验数据集。世界模型与策略分开训练,使策略可以充分利用世界模型的表征而不相互干扰。

DreamerV2 World Model Learning
Figure 2(论文原图):世界模型结构。图像序列经 CNN 编码,RSSM 维护确定性 recurrent state ht;在每个时间步同时计算后验随机状态 zt(融入当前图像 xt)和先验随机状态 ẑt(不访问当前图像,用于后续想象)。DreamerV2 的关键改进是将随机状态由 Gaussian 改为 32 × 32 classes 的 categorical 变量(32 个 categorical,每个 32 类),并用 straight-through gradients 优化。

离散潜变量(Categorical Latents)

DreamerV2 用一组 32 个 categorical 变量(每个含 32 个类别)替代 DreamerV1 中的 Gaussian 潜变量。Flatten 后得到长度 1024、仅 32 位为 1 的稀疏二值向量。作者列出四条为何 categorical 变量优于 Gaussian 变量的假设:

KL Balancing

世界模型的损失函数包含 KL 项,同时训练先验(transition predictor)和正则化后验(representation model)。标准 KL 的问题是:先验在早期训练不充分,若过度正则化后验会损害表征质量。DreamerV2 引入 KL balancing:以不同学习率优化先验和后验,具体地用混合系数 α = 0.8:

kl_loss = alpha * compute_kl(stop_grad(approx_posterior), prior) + (1 - alpha) * compute_kl(approx_posterior, stop_grad(prior))

这使先验被更快地拉向后验(促进精确先验动力学),而不是让后验熵增大以减小 KL(避免表征退化)。KL balancing 与 beta-VAE 方法正交,可配合使用。

DreamerV2 Actor Critic Learning
Figure 3(论文原图):Actor-Critic 行为学习。从世界模型训练中的后验状态出发,利用 transition predictor 向前展开 H = 15 步的想象轨迹(无需生成图像),actor 采样动作,critic 用 λ-return(λ = 0.95)估计价值。Actor 同时使用 Reinforce 梯度和 straight-through 梯度的组合(Atari 上 ρ = 1 倾向 Reinforce),critic 每 100 梯度步更新一次 target network。单 GPU 可并行模拟 2500 条潜空间轨迹。

行为学习(Behavior Learning)

Actor 和 Critic 都是 MLP(各约 1M 参数),在世界模型固定后的想象 MDP 中训练。Critic 用 λ-return 的均方误差损失,Actor 同时最大化 λ-return(通过 Reinforce 梯度和 straight-through 梯度的加权和)并添加熵正则化(Atari 上 η = 10⁻³)。世界模型总参数约 20M。

03 实验 Experiments

在 55 款 Atari 游戏上评估(sticky actions,200M steps,action repeat 4,单 V100 GPU,单环境实例),对比 IQN、Rainbow、C51、DQN 四个 model-free 基准(分数来自 Dopamine 框架)。作者同时提出四种聚合方式并推荐 Clipped Record Mean 作为最鲁棒评估指标。

DreamerV2 Atari Performance Curves
Figure 4(论文原图):200M steps 内四种聚合方式下的性能曲线。DreamerV2 在所有指标上均优于 model-free 基准,在 Record Mean 上领先最为显著(0.44 vs IQN 0.21、Rainbow 0.17)。
算法Gamer MedianGamer MeanRecord MeanClipped Record Mean
DreamerV22.1511.330.440.28
DreamerV2 (schedules)2.6410.450.430.28
IQN1.298.850.210.21
Rainbow1.479.120.170.17
C511.097.700.150.15
DQN0.652.840.120.12

Table 1(论文原表):200M steps 时各算法在 55 款游戏上的汇总分数。DreamerV2 在全部四项指标上超越所有单 GPU 基准。值得注意的是,Rainbow 在 Gamer Median 上优于 IQN(1.47 > 1.29),但在其余三项指标上 IQN 均优于 Rainbow,说明聚合方式的选择对排名有显著影响。

DreamerV2 Ablation Study
Figure 5 & Table 2(论文原图/表):消融实验结果(使用略早版本的 DreamerV2)。移除 Discrete Latents 导致 Clipped Record Mean 从 0.25 下降至 0.19;移除 KL Balancing 降至 0.16;移除 Image Gradients 导致几乎完全失效(0.01)。停止 Reward Gradients 在部分任务上反而略有提升,说明 DreamerV2 的表征来自图像重建而非奖励预测。

消融实验(Ablations)

论文通过逐一移除组件进行消融(Table 2):

04 局限性 Limitations

Note: 论文未设专门的 Limitations 小节。以下标注为来源:作者在正文中明确指出(stated),或从方法设计推断(inferred from the design)。
Video Pinball 等游戏的失败案例(stated)

作者明确指出:"We hypothesize that the reconstruction loss of the world model does not encourage learning a meaningful latent representation because the most important object in the game, the ball, occupies only a single pixel." 即对于关键物体极小的游戏,图像重建损失无法提供足够的学习信号,导致表征质量差。

计算成本高于 Rainbow(stated / inferred)

DreamerV2 在 10 天内完成 200M 步(单 V100),与 Rainbow 相当;但世界模型需要额外维护 20M 参数,并在每步同时更新世界模型和策略,内存与计算开销更高。MuZero 虽更强大,但需 2 个月以上的 GPU 计算(stated:论文明确指出 MuZero "would require over 2 months of computation to train even one agent on a GPU")。

Categorical 潜变量优势机制仍不明确(stated)

作者诚实地承认:"While we do not know the reason why the categorical variables are beneficial, we state several hypotheses that can be investigated in future work." 该设计在实验上有效,但理论解释尚不充分。

连续控制场景下表现有限(inferred)

论文展示了 DreamerV2 在 humanoid stand-up 和 walking(连续动作)上的初步结果,但主要评估集中在 Atari(离散动作)。连续控制下使用 dynamics backpropagation(ρ = 0)而非 Reinforce,超参数需分别调整,跨任务泛化能力未经系统验证。