NIPS 2015 · 强化学习 · 视频预测

Action-Conditional Video Prediction using Deep Networks in Atari Games

首个在高维 Atari 游戏画面上实现超 100 步动作条件长期预测的深度网络框架
Junhyuk Oh · Xiaoxiao Guo · Honglak Lee · Richard Lewis · Satinder Singh  |  University of Michigan, Ann Arbor

本文针对视觉强化学习问题提出两种深度神经网络架构——基于 feedforward(前馈)编码和 recurrent(循环)编码——用于预测依赖动作变量的未来游戏帧。实验表明,模型能在 Atari 游戏中生成视觉逼真、可用于控制的超 100 步预测帧,并可通过"知情探索(informed exploration)"策略改善 DQN 的训练效率。

NIPS 2015 5 个 Atari 游戏 100+ 步预测 📄 arXiv:1507.08750 Project Page
action-conditional prediction video prediction Atari games DQN reinforcement learning LSTM multiplicative interaction informed exploration 动作条件预测 model-based RL

01 动机

在视觉强化学习(RL)中,智能体观察到的画面不仅取决于历史帧,还受当前动作的支配。如何在高维图像空间中建立精确的动作条件转移模型,是 model-based RL 的核心挑战之一。

"To the best of our knowledge, this paper is the first to make and evaluate long-term predictions on high-dimensional video conditioned by control inputs."

Atari 游戏提供了极具挑战性的场景:画面分辨率达 210×160 像素 RGB,场景中可能存在数十个对象,部分对象由动作直接控制,其余受间接影响;对象可随时进入或离开画面,且存在深度局部可观测性。之前的视频预测研究(如 RTRBM、LSTM 等)仅处理弹跳球或小图像块等简单情形,且忽略动作变量。本文要解决的核心问题是:

Freeway 游戏 250 步预测示例
图 1:Freeway 游戏中超过 250 步的动作条件预测示例。"Step"列为预测步数,"Action"列为动作。白框标出智能体控制的对象。预测步 256→257 时,被控对象越过上边界后从底部重现——本文两种架构均能正确预测这一非线性跳变,而 MLP 和 naFf 基线则无法预测。
100+可用于控制的预测步数
5Atari 游戏域(Seaquest / S.Invaders / Freeway / QBert / Ms Pacman)
500K每个游戏的训练帧数
提出架构:Feedforward 编码 + Recurrent 编码

02 方法

两种架构均由三个模块组成:编码层(Encoding)从输入帧提取时空特征,动作条件变换层(Action-Conditional Transformation)将特征依动作变换为下一帧的预测特征,解码层(Decoding)将高层特征反卷积回像素空间。

Feedforward Encoding 架构
图 2a:Feedforward encoding 架构。将最近 4 帧拼接为多通道输入,经 4 层卷积 + 全连接层提取特征向量 henc,再通过乘性变换层和反卷积解码器输出预测帧。
Recurrent Encoding 架构
图 2b:Recurrent encoding 架构。每次只输入一帧,由 CNN + LSTM 联合编码时序特征;LSTM 的记忆细胞保留长历史依赖,再经相同的变换层和解码器输出预测帧。

Feedforward Encoding(前馈编码)

将最近 m 帧(本文取 m=4)按通道拼接后输入 CNN,直接从像素提取时空特征:

henct = CNN(xt−m+1:t)

卷积层使用 64×(8×8)、128×(6×6)、128×(6×6)、128×(4×4) 滤波器,步幅均为 2;后接全连接层输出 2048 维特征向量。该方案类似"early-fusion",可精确建模局部像素级时空关系,但难以捕获超出输入窗口的长期依赖。

Recurrent Encoding(循环编码)

每步输入单帧,先用 CNN 提取空间特征,再送入 LSTM(2048 个隐藏单元)捕获任意长度时序依赖:

[henct, ct] = LSTM(CNN(xt), henct−1, ct−1)

LSTM 的记忆细胞 ct 保留来自深层历史的信息,能处理 9 步以上的长期事件(如 Space Invaders 敌人的周期性移动)。

Multiplicative Action-Conditional Transformation(乘性动作条件变换)

动作变量通过乘性交互注入特征空间,避免简单拼接导致的加法独立性问题。为降低参数量,对三阶张量进行因式分解(因子数 f=2048):

hdect = Wdec(Wenc·henct ⊙ Wa·at) + b

当动作为 one-hot 向量时,该乘性变换等价于为每个动作使用独立的权重矩阵,能够建模不同动作对应的不同变换。

Convolutional Decoding(卷积解码)

变换后的特征向量经全连接层 reshape 为 128×11×8 的特征图,再经 4 层反卷积(滤波器:128×(4×4)、128×(6×6)、128×(6×6)、3×(8×8),步幅均为 2)还原为全分辨率 RGB 图像(210×160)。反卷积比先上采样再卷积的方案更高效。

Curriculum Learning with Multi-Step Prediction(课程式多步预测训练)

为避免 1 步预测误差在长期滚动推理中累积,采用课程学习策略,依次以 K=1、3、5 步为目标函数训练,学习率分别为 10−4、10−5、10−5

LK(θ) = (1/2K) Σi Σt Σk=1..K ‖x̂(i)t+k − x(i)t+k‖²

每阶段迭代次数分别为 1.5×106、106、106;RMSProp(momentum=0.9)优化,batch size feedforward 网络为 32/8/8,recurrent 网络为 4/4/4。

03 实验

在 5 个 Atari 游戏(Seaquest、Space Invaders、Freeway、QBert、Ms Pacman)上进行评估,训练集约 500,000 帧,测试集约 50,000 帧,使用 DQN ε-greedy 策略(ε=0.3)采集数据。图像为全分辨率 210×160 RGB,每 4 帧执行一次动作(60fps → 15fps)。对比基线:MLP(4 隐层,约同等参数量)和 naFf(无动作输入的前馈网络)。

量化评估:100 步均方误差

本文两种架构在所有游戏域的 100 步 MSE 上均优于两个基线。其中与 naFf 的差距在 Seaquest 最为明显(被控对象占图像面积更大);Space Invaders 等游戏中差距较小,原因是被控对象仅占图像的一小部分。

使用预测帧替代真实帧的游戏得分
图 3:将预测模型替代 game emulator 后的游戏得分。x 轴为每次重置前连续预测的步数(1–100),y 轴为 30 局平均得分。"Emulator"为使用真实帧的 DQN 基准,"Rand"为随机策略。本文两种架构(蓝/绿曲线)显著优于 MLP 和 naFf 基线,在 3/5 个游戏中即使预测 100 步后得分仍远高于随机策略。

知情探索(Informed Exploration)改善 DQN 训练

利用预测模型为 ε-greedy 策略中的探索动作提供信息:在随机动作选择时,改为选择预测帧与最近 d 帧轨迹记忆中最少出现相似帧的动作(用 Gaussian kernel 估计访问频率)。训练了专用的灰度/下采样(84×84)前馈网络提升计算效率。

模型SeaquestS. InvadersFreewayQBertMs Pacman
DQN - Random exploration 13119 (538)698 (20)30.9 (0.2)3876 (106)2281 (53)
DQN - Informed exploration 13265 (577)681 (23)32.2 (0.2)8238 (498)2522 (57)

知情探索在 3/5 个游戏中提升了 DQN 性能,QBert 提升最显著(3876→8238)。括号内为标准误差,数据来自 100 局游戏。

表征分析:动作相似性与受控/非受控对象分离

动作因子余弦相似度矩阵
图 4:Seaquest 训练后动作因子(Wa·a)的两两余弦相似度。方向相同的动作(如"up"与"up+fire")呈正相关,方向相反的动作(如"up+right"与"down+left")呈负相关,体现了模型自动学习到的合理动作语义结构。"N"="no-operation","F"="fire"。
受控/非受控对象分离可视化
图 5:分离受控与非受控对象的可视化。给定同一输入帧,仅激活高方差因子(Action 图)时,模型准确预测被动作控制的对象的移动,背景为均值像素;仅激活低方差因子(Non-Action 图)时,模型预测其他对象和背景的运动,被控对象停留在原位。

04 局限性

Note: 以下局限性由作者在论文正文中明确陈述(stated),最后一条为从设计中推断(inferred)。
小对象预测困难(stated)

"Both of our models have difficulty in accurately predicting small objects, such as bullets in Space Invaders. The reason is that the squared error signal is small when the model fails to predict small objects during training."——以 MSE 为损失函数时,小对象的像素误差对总损失贡献极小,导致模型缺乏充足的学习信号。

随机性处理困难(stated)

"In Seaquest, e.g., new objects appear from the left side or right side randomly, and so are hard to predict. Although our models do generate new objects with reasonable shapes and movements…the generated frames do not necessarily match the ground-truth."——随机出现的对象(如敌方潜水艇)无法被确定性模型精确预测,生成结果在形状和运动方向合理,但位置不一定与真实帧吻合。

Feedforward 编码无法捕获长期依赖(stated)

"[Feedforward architecture] is not suitable for long-term dependencies because it requires more memory and parameters as more frames are concatenated into the input."——在 Freeway 等游戏中,当智能体进入新阶段后动作被忽略 9 步,前馈架构因输入窗口仅为最近 4 帧而无法建模这一长期状态,导致在该场景下发生错误预测(diverge)。

尚未延伸至 reward 预测与完整 model-based RL(stated)

论文结论中指出:"In future work we will learn models that predict future reward in addition to predicting future frames and evaluate the performance of our architectures in model-based RL."——当前架构仅预测像素帧,尚未整合奖励信号预测,因此不能直接支撑完整的 model-based RL 规划。

未扩展到更大自然场景视频(inferred)

所有实验均限于 Atari 游戏(非自然场景,分辨率 210×160)。更高分辨率、更复杂动态(如自动驾驶第一视角)中的泛化能力未经验证。