ICLR 2020 · Model-Based RL

Model-Based Reinforcement Learning for Atari

SimPLe:用 video prediction world model 以极低样本量学会 Atari 游戏
Łukasz Kaiser · Mohammad Babaeizadeh · Piotr Miłos · Błażej Osiński · Roy H. Campbell · Konrad Czechowski · Dumitru Erhan · Chelsea Finn · Piotr Kozakowski · Sergey Levine · Afroz Mohiuddin · Ryan Sepassi · George Tucker · Henryk Michalewski
Google Brain · deepsense.ai · University of Warsaw · UIUC · Stanford

SimPLe(Simulated Policy Learning)是首个在 Atari Learning Environment (ALE) 上通过 video prediction 模型实现竞争力的 model-based deep RL 系统。仅用 100K 次环境交互(约 2 小时真实游戏时长),在多数游戏上超越 state-of-the-art model-free 算法 Rainbow,部分游戏样本效率提升超过 10 倍

ICLR 2020 26 Atari games 100K 交互预算 📄 arXiv:1903.00374 Project Page / Videos
model-based RL video prediction world model sample efficiency Atari 离散潜变量 SimPLe policy learning

01 动机

人类玩家可以在几分钟内学会 Atari 游戏,而最优秀的 model-free RL 算法需要数千万乃至数亿次交互才能达到相近水平——相当于数周的实时训练。这种巨大的样本效率差距是 model-based RL 研究的核心动机。

"So far, there has been no clear demonstration of successful planning with a learned model in the ALE." — Machado et al. (2018) 对 Atari 基准上 model-based 控制的挑战性评述

论文的核心假设:人类之所以能快速学会游戏,部分原因在于拥有对物理过程的直觉理解,能够预测动作的结果。SimPLe 通过学习视频预测模型来实现类似的能力,从而大幅降低与真实环境的交互次数。

SimPLe 主循环
Figure 1(论文原图):SimPLe 的主循环。① agent 按当前 policy 与真实环境交互,收集观测。② 用收集的数据更新 world model(对观测帧做自监督预测,对奖励做监督训练)。③ agent 在 world model 内部用 RL 更新 policy。新 policy 再次被送回真实环境评估和采样数据(回到①),循环迭代 15 次。
100K真实环境交互上限
>10×Freeway 游戏上的样本效率提升(对比 Rainbow)
26评测的 Atari 游戏数
15SimPLe 迭代训练轮次

02 方法

SimPLe 的核心是将 world model 训练与 policy 训练交替进行(类似 Dyna-Q 框架)。world model 是一个视频预测网络,接收 4 帧 stacked 输入和 action,预测下一帧画面与奖励。policy 则在 world model 内部通过 PPO 训练,避免大量真实环境交互。

世界模型架构
Figure 2(论文原图):提出的 stochastic model with discrete latent 的架构图。输入为 4 帧 stacked frames 及 agent 的 action,输出为下一帧预测及奖励预测。网络主体为 skip-connected 卷积 encoder-decoder;action 通过 embedding 与 decoder 各层输出相乘来条件化生成。推理网络(inference network)在训练时估计后验分布并将 latent 离散化为 bits;推理时由辅助 LSTM 自回归预测这些 bits(取代从先验采样)。整个模型约 74M 参数。

World Model:随机离散 latent 模型

论文提出三种 world model 架构,最优为 stochastic discrete (SD) model

  • Deterministic model:基础 skip-connected conv encoder-decoder,action embedding 通过乘法注入 decoder 各层。
  • Stochastic VAE model:引入 variational autoencoder 建模环境随机性,但 KL 权重调参困难,推理时 latent 分布外问题严重。
  • Stochastic Discrete (SD) model(最优):将 latent 值离散化为 bits(0/1),训练辅助 LSTM 自回归预测这些 bits,规避 VAE 的 OOD latent 问题。为提升鲁棒性,在离散化前加 uniform noise,离散化后加 dropout。

Loss 设计:视觉输出采用 per-pixel softmax(256 色空间)或 L2,并使用 clipped loss max(Loss, C),其中对 L2 取 C=10,对 softmax 取 C=0.03(意味着置信度 >97% 时不产生梯度,避免大面积背景主导优化)。

Scheduled Sampling:为降低 compounding error,训练时按线性增长的概率将输入帧替换为模型自身的预测帧,到第一次迭代中期将混合率提升至 100%。

Policy Training:PPO in World Model

采用 PPO(Proximal Policy Optimization,γ=0.95)在 world model 内训练 policy:

  • N=50 步从真实数据 buffer D 中均匀采样起始状态重启模拟环境,避免长 rollout 带来的 compounding error。
  • rollout 末尾加入 value function 估值以弥补短 rollout 的远见不足。
  • 每次 PPO epoch 使用 16 个并行 agent,收集 25/50/100 步(默认 50)。
  • SimPLe 在模拟环境内共执行约 15.2M 次交互,以 100K 真实步为代价换取大量模拟经验。

Algorithm 1(SimPLe 伪代码)

  1. 初始化 policy π 和 world model 参数 θ,初始化空数据集 D。
  2. 循环(15 轮):①用 π 从真实环境 env 收集数据加入 D;②有监督地训练 world model:θ ← TRAIN_SUPERVISED(env', D);③在 world model 内更新 policy:π ← TRAIN_RL(π, env')。

03 实验

在 Atari Learning Environment (ALE) 的 26 款游戏上评测,训练预算限定为 100K 次真实环境交互(= 400K 帧 = 约 114 分钟 @ 60 FPS)。基线方法为高度调优的 Rainbow(Q-learning SOTA)和 PPO(model-free policy gradient)。评测采用 5 次运行取平均,使用 softmax(logits(π)/T)(T=0.5)的确定性 policy 评测。

与 Rainbow 和 PPO 的比较
Figure 3(论文原图):与 Rainbow 和 PPO 的样本效率比较。每根柱子表示 Rainbow(左)或 PPO(右)需要多少次与环境的交互才能达到 SimPLe 在 100K 步时的得分。红线标注 SimPLe 使用的 100K 阈值——柱高超过红线意味着该 model-free 方法需要更多样本。SimPLe 在几乎所有游戏上都更高效,Freeway 上超过 10 倍。

主要结果(选取部分游戏的 SD 模型均值得分,来自论文 Table 2)

游戏Ours, SD(均值)Ours, det. recurrent(均值)Ours, deterministic(均值)
Freeway20.323.75.9
Pong12.8-11.6-17.4
Boxing9.1-3.1-9.3
Breakout12.710.26.1
CrazyClimber39827.854700.319380.0
KungFuMaster17257.24086.610340.9
RoadRunner5169.41228.85724.4
Seaquest370.9289.6419.5

注:数值来自论文 Table 2,为 5 次实验均值。SD = stochastic discrete(提出的最优模型)。

整体比较(论文摘要引述)

"In most games SimPLe outperforms state-of-the-art model-free algorithms, in some games by over an order of magnitude."

更新至 v5 版本的结论:经 van Hasselt et al. (2019) 和 Kielak (2020) 改进后的 Rainbow 在低数据量下与 SimPLe 持平——两种 model-free 方法各在 13 款游戏上胜出,SimPLe 在另外 13 款游戏上胜出(共 26 款)。

样本量 vs 得分分数图
Figure 4 & 5(论文原图):上方:以公式 (SimPLe_score@100K - random_score)/(baseline_score - random_score) 计算的得分比例,从左到右分别对比 Rainbow@100K、Rainbow@200K、PPO@100K、PPO@200K——SimPLe 在对手获得双倍交互时仍胜出。下方:随样本量增长,PPO 追上 SimPLe 所需帧数的变化曲线(a)及 SimPLe 预热后继续做 PPO 微调的效果(b)。

消融实验

论文对 7 种配置各跑 5 次,汇总于 Table 1:

模型配置最优(26 游戏中取最高的数量)至少达中位数
deterministic07
det. recurrent313
SD(默认)1021
SD γ=0.9114
SD 100 steps014
SD 25 steps419

04 局限性

Note:以下限制均为论文作者在 "Conclusions and Future Work" 及正文中明确陈述(stated),非推断。
渐近性能低于 model-free SOTA(stated)

"The final scores are on the whole lower than the best state-of-the-art model-free methods." 论文作者指出这在 model-based RL 中普遍存在,需要更好的 dynamics model 来弥补。在极高样本量时,PPO 等方法的最终得分仍超过 SimPLe。

训练方差大、稳定性差(stated)

"The performance of our method generally varied substantially between different runs on the same game." 多轮训练的互相影响(policy 训练、world model 训练和数据收集之间的耦合)以及 world model 与真实环境之间的 domain shift 是高方差的主因。作者建议未来使用 Bayesian 参数后验或 ensemble 方法提升鲁棒性。

计算代价高昂(stated)

"The computational and time requirement of training inside world model are substantial." World model(约 74M 参数)的推理约 32ms/帧(batch size=16,NVIDIA Tesla P100),而真实 ALE 模拟器仅约 0.4ms/步——相差约 80 倍。这使得开发更轻量的 world model 成为重要研究方向。

对小物体和全局变化场景预测失败(stated)

在 Atlantis 和 Battle Zone 等游戏中,子弹等极小但重要的物体容易从模型预测帧中消失。在 Private Eye 等需要场景切换(大范围全局变化)的游戏中,模型同样难以捕捉。这类游戏上的最终得分接近随机。

低数据量时优势显著,高数据量时消失(stated)

"This demonstrates that SimPLe excels in a low data regime, but its advantage disappears with a bigger amount of data." 在约 500K 样本时 SimPLe 与 PPO 打平,之后 model-free 方法反超。作者将此归因于 SimPLe policy 熵过低导致的探索不足,制约了后续 PPO 微调的潜力。