Fork me on GitHub

强化学习系列二——应用 AlphaGo Zero 思路优化搜索排序

文章作者:杨镒铭 滴滴出行 高级算法工程师
内容来源:记录广告、推荐等方面的模型积累@知乎专栏

在深度学习大潮之后,搜索推荐等领域模型该如何升级迭代呢?强化学习在游戏等领域大放异彩,那是否可将强化学习应用到搜索推荐领域呢?推荐搜索问题往往也可看作是序列决策的问题,引入强化学习的思想来实现长期回报最大的想法也是很自然的,事实上在工业界已有相关探索。因此后面将会写一个系列来介绍近期强化学习在搜索推荐业务上的应用。 

本次将介绍第二篇,发表在SIGIR2018上,适合做搜索推荐等业务的同学精读,链接为http://www.bigdatalab.ac.cn/~junxu/publications/SIGIR2018-M2Div.pdf。作者中徐君老师在SIGIR2018上做过搜索推荐中Matching相关的分享,很赞。

一、Introduction

这篇文章解决的问题是提高搜索多样性,搜索多样性是指搜索结果覆盖更多的主题,问题可定义成从候选集合中选出一个最小的子集来尽可能覆盖更多的主题。传统的解决方法是贪心选择,每次从候选集合里选出一个marginal relevance的文档。贪心选择的问题在于每次选择局部最优的文档会导致最终生成的文档列表很难是全局最优的,因为每个位置选择的文档会影响后续文档的选择。如果想得到全局最优的文档列表,如果候选文档的数目为N,暴力搜索的时间复杂度为N的阶乘,显然在实际的应用中不现实。

模型的整体框架

文章则借鉴了AlphaGo和AlphaGo Zero的思路,结合强化学习和蒙特卡洛搜索进行多样性排序,整体框架如上图。训练过程中,在每个时间步(对应每个排序的位置),基于用户query和已产生的文档列表,使用RNN得到当前状态,然后基于当前状态得到指导文档选择的策略函数(raw policy)以及评估当前列表的值函数。为了缓解次优解的问题,模型使用MCTS进行搜索,输出一个更优的策略(search policy)。模型的损失函数由两部分构成,一个是预测价值和文档排序指标的均方误差,另一个是raw policy和search policy的交叉熵。后面讲详细介绍整个训练过程和预测过程。

二、模型

  • 基于MDP的多样性排序

文档列表生成过程中每个位置的文档选择可以看做一个时间步,整体是一个序列决策的过程,适合用MDP来建模。下面介绍下相关的状态、动作、转换、值函数和策略函数的定义。

状态:每个时间步的状态可看做三元组,  ,q是用户查询,  是t时间之前已经产生的文档列表,而  是候选文档。

动作:  中的动作对应可选择的候选文档。时间步t时选择的动作  选择了标号为 的文档  。

**状态转换:**状态转换很显然,设 表示将  加入到  中,而  表示将  从  中删除,定义为  。

值函数:根据输入状态估计文档排序列表的质量,通过近似一个预先设定的指标来学习。不同时间步的状态可以一个输入给LSTM模型,值函数以LSTM输出作为输入并经过一个非线性层,即为  ,其中  。

策略函数:策略函数也是以状态作为输入,输出则是动作上的概率分布。形式化如下,其中  是待学习参数。

 , 

  • 基于MCTS的更优策略

通过p(s)贪心地选择文档的方式仅基于历史信息去做决策,并未考虑动作  对后续决策的影响,会导致次优解。因此如下图,文章提出了利用MCTS在排序空间进行lookahead搜索。在每个排序位置,利用当前的策略函数p和值函数V进行MCTS搜索,返回一个更优的搜索策略  。蒙特卡洛树中节点表示状态,边表示状态转换,每条边保存三个属性,分别是动作值函数Q(s,a)、访问次数N(s,a)和先验概率p(s,a)。整体最多迭代K次,每次迭代中会在第四步Back-propagation时更新Q和N两个变量,迭代结束之后输出更优策略。

MCTS搜索过程-四个步骤

Selection

从根节点  开始,在每个时间步t迭代选择使下式最大的动作  。其中, 控制exploitation,倾向于选择价值更高的边;而  控制exploration, 正比于先验概率,但随着访问次数衰减,倾向于选择探索次数少的边。

其中,  。

Evaluation and Expansion

上一步一直迭代到一个叶子节点  ,如果节点是一个episode的结尾则使用预设的评价指标评估,其他情况使用值函数  评估。这和AlphaGo Zero处理是一样的,而AlphaGo中价值评估是通过快速走子模拟实现的。接着叶子节点  有可能会拓展,拓展出的新边  会初始化  、  和  这三个属性。

Back-propagation and Update

利用  沿着Selection经过的路径反向一路更新。针对路径上的每条边e(s,a),  不变化,  加一,  做累积平均的更新:

输出更优搜索策略

上述步骤迭代K次后可得到根节点  处的更优搜索策略。由于边的访问次数说明了每条边的价值高低,新策略由每个边的访问次数决定,形式化为 

  • 基于强化学习的训练方式

模型参数主要是LSTM和策略函数p中的参数。针对一个query,在当前的参数下,每个位置执行一次MCTS搜索,直到产生一个排序列表  ,然后得到评估  ,其中R是任意的多样性评价指标,比如文中使用到的  ,而J是该query下的真值。

损失函数组成

使用每个时间步积累的训练数据  和r作为监督信号来训练网络。如上图,损失函数包含两部分,一个是值函数的预测值  和指标  的均方误差,另一个是raw policy  和search policy 的交叉熵,形式化如下: 

  • 在线预测

针对线上预测需要两种方式:一个是采用上面提到的MCTS,但是这种方式非常耗时;另一个是放弃树搜索直接使用raw policy进行排序。在实验中发现即使采用第二种方法的效果依然超过了基准模型。这是因为训练时由于使用MCTS搜索产生了质量较好的序列来训练参数,使得策略函数p更加准确。这点也使得文章提出的idea在工业界真实上线具备可能。

三、总结

  • 对AlphaGo和AlphaGo Zero了解的同学读这篇文章会很顺畅,思路和AlphaGo Zero类似,借鉴了比如使用同一个网络得到策略和值函数、MCTS搜索、同时优化值函数和策略函数的损失等成功经验。
  • 在我们的业务场景中,action偏少,一般只有2-3个,引入mcts也已经很耗时了。在文章中,由于候选文档也就是动作可能偏多,所以想要使得mcts输出的策略尽可能平滑,mcts需要很多次的迭代,文中的k值设置的是5000,可想而知训练过程的耗时程度,其实就是拿时间换效果的tradeoff。
  • 存在一个疑问的地方是在损失函数里,因为开始时可能只搜出来一篇或者两篇文档,其状态的Q值明显和最终整个列表生成后的状态Q值应该差别很大,但损失函数里要求使所有状态的Q值都要和最终列表的指标r都接近。AlphaGo Zero中这里采取类似做法,但多了一个采样。这个细节的处理,欢迎提出不同看法。


本文地址:https://www.6aiq.com/article/1553609835324
本文版权归作者和AIQ共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出