Paper Title
Masked Contrastive Representation Learning for Reinforcement Learning
Paper Authors
Paper Abstract
Improving sample efficiency is a key research problem in reinforcement learning (RL), and CURL, which uses contrastive learning to extract high-level features from the raw pixels of individual video frames, is an efficient algorithm~\citep{srinivas2020curl}. We observe that consecutive video frames in a game are highly correlated, but CURL treats them independently. To further improve data efficiency, we propose a new algorithm, masked contrastive representation learning for RL, which takes the correlation among consecutive inputs into account. In addition to the CNN encoder and the policy network in CURL, our method introduces an auxiliary Transformer module to leverage the correlations among video frames. During training, we randomly mask the features of several frames and use the CNN encoder and Transformer to reconstruct them from the context frames. The CNN encoder and Transformer are jointly trained via contrastive learning, where each reconstructed feature should be similar to its ground-truth counterpart and dissimilar to the others. During inference, the CNN encoder and the policy network are used to take actions, and the Transformer module is discarded. Our method achieves consistent improvements over CURL on $14$ out of $16$ environments from the DMControl suite and $21$ out of $26$ environments from Atari 2600 games. The code is available at https://github.com/teslacool/m-curl.
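To make the training objective above concrete, here is a minimal PyTorch sketch of the masked reconstruction plus contrastive (InfoNCE-style) loss; this is not the authors' released implementation, and the toy CNN/Transformer architectures, `FEAT_DIM`, `mask_ratio`, and `temperature` are illustrative assumptions:

```python
# Minimal sketch of masked contrastive representation learning over frame
# features. All sizes and hyperparameters here are placeholders, not the
# paper's actual configuration.
import torch
import torch.nn as nn
import torch.nn.functional as F

FEAT_DIM = 128  # hypothetical per-frame feature dimension

cnn = nn.Sequential(  # stand-in for the CNN encoder
    nn.Conv2d(3, 32, 8, stride=4), nn.ReLU(),
    nn.Conv2d(32, 32, 4, stride=2), nn.ReLU(),
    nn.Flatten(), nn.LazyLinear(FEAT_DIM),
)
transformer = nn.TransformerEncoder(  # auxiliary module, discarded at inference
    nn.TransformerEncoderLayer(d_model=FEAT_DIM, nhead=4, batch_first=True),
    num_layers=2,
)
mask_token = nn.Parameter(torch.zeros(FEAT_DIM))  # learned mask embedding

def masked_contrastive_loss(frames, mask_ratio=0.3, temperature=0.1):
    """frames: (B, T, C, H, W) stack of consecutive observations."""
    B, T = frames.shape[:2]
    feats = cnn(frames.flatten(0, 1)).view(B, T, FEAT_DIM)  # per-frame features
    target = feats.detach()                                  # ground-truth keys

    # Randomly mask a subset of frame features; the Transformer must
    # reconstruct them from the surrounding (unmasked) context frames.
    mask = torch.rand(B, T) < mask_ratio
    inp = torch.where(mask.unsqueeze(-1), mask_token.expand_as(feats), feats)
    recon = transformer(inp)

    # InfoNCE: each reconstructed feature should be similar to its own
    # ground truth and dissimilar to every other (batch, time) feature.
    q = F.normalize(recon[mask], dim=-1)              # reconstructed queries
    k = F.normalize(target.flatten(0, 1), dim=-1)     # all features as keys
    logits = q @ k.t() / temperature                  # (num_masked, B*T)
    labels = mask.flatten().nonzero(as_tuple=True)[0] # index of each positive
    return F.cross_entropy(logits, labels)
```

In this sketch the contrastive loss updates both the CNN encoder and the Transformer jointly, while the policy would act from `cnn` features alone, matching the abstract's note that the Transformer serves only as an auxiliary training module.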