强化学习系列二——应用 AlphaGo Zero 思路优化搜索排序



转载请注明 AIQ - 最专业的机器学习大数据社区  http://www.6aiq.com

AIQ 机器学习大数据 知乎专栏 点击关注

文章作者:杨镒铭 滴滴出行 高级算法工程师
内容来源:记录广告、推荐等方面的模型积累 @知乎专栏

在深度学习大潮之后,搜索推荐等领域模型该如何升级迭代呢?强化学习在游戏等领域大放异彩,那是否可将强化学习应用到搜索推荐领域呢?推荐搜索问题往往也可看作是序列决策的问题,引入强化学习的思想来实现长期回报最大的想法也是很自然的,事实上在工业界已有相关探索。因此后面将会写一个系列来介绍近期强化学习在搜索推荐业务上的应用。 

本次将介绍第二篇,发表在 SIGIR2018 上,适合做搜索推荐等业务的同学精读,链接为http://www.bigdatalab.ac.cn/~junxu/publications/SIGIR2018-M2Div.pdf。作者中徐君老师在 SIGIR2018 上做过搜索推荐中 Matching 相关的分享,很赞。

一、Introduction

这篇文章解决的问题是提高搜索多样性,搜索多样性是指搜索结果覆盖更多的主题,问题可定义成从候选集合中选出一个最小的子集来尽可能覆盖更多的主题。传统的解决方法是贪心选择,每次从候选集合里选出一个 marginal relevance 的文档。贪心选择的问题在于每次选择局部最优的文档会导致最终生成的文档列表很难是全局最优的,因为每个位置选择的文档会影响后续文档的选择。如果想得到全局最优的文档列表,如果候选文档的数目为 N,暴力搜索的时间复杂度为 N 的阶乘,显然在实际的应用中不现实。

模型的整体框架

文章则借鉴了 AlphaGo 和 AlphaGo Zero 的思路,结合强化学习和蒙特卡洛搜索进行多样性排序,整体框架如上图。训练过程中,在每个时间步 (对应每个排序的位置),基于用户 query 和已产生的文档列表,使用 RNN 得到当前状态,然后基于当前状态得到指导文档选择的策略函数(raw policy) 以及评估当前列表的值函数。为了缓解次优解的问题,模型使用 MCTS 进行搜索,输出一个更优的策略(search policy)。模型的损失函数由两部分构成,一个是预测价值和文档排序指标的均方误差,另一个是 raw policy 和 search policy 的交叉熵。后面讲详细介绍整个训练过程和预测过程。

二、模型

  • 基于 MDP 的多样性排序

文档列表生成过程中每个位置的文档选择可以看做一个时间步,整体是一个序列决策的过程,适合用 MDP 来建模。下面介绍下相关的状态、动作、转换、值函数和策略函数的定义。

状态:每个时间步的状态可看做三元组,  ,q 是用户查询,  是 t 时间之前已经产生的文档列表,而  是候选文档。

动作:  中的动作对应可选择的候选文档。时间步 t 时选择的动作  选择了标号为 的文档  。

** 状态转换:** 状态转换很显然,设 表示将  加入到  中,而  表示将  从  中删除,定义为  。

值函数:根据输入状态估计文档排序列表的质量,通过近似一个预先设定的指标来学习。不同时间步的状态可以一个输入给 LSTM 模型,值函数以 LSTM 输出作为输入并经过一个非线性层,即为  ,其中  。

策略函数:策略函数也是以状态作为输入,输出则是动作上的概率分布。形式化如下,其中  是待学习参数。

 , 

  • 基于 MCTS 的更优策略

通过 p(s) 贪心地选择文档的方式仅基于历史信息去做决策,并未考虑动作  对后续决策的影响,会导致次优解。因此如下图,文章提出了利用 MCTS 在排序空间进行 lookahead 搜索。在每个排序位置,利用当前的策略函数 p 和值函数 V 进行 MCTS 搜索,返回一个更优的搜索策略  。蒙特卡洛树中节点表示状态,边表示状态转换,每条边保存三个属性,分别是动作值函数 Q(s,a)、访问次数 N(s,a) 和先验概率 p(s,a)。整体最多迭代 K 次,每次迭代中会在第四步 Back-propagation 时更新 Q 和 N 两个变量,迭代结束之后输出更优策略。

MCTS 搜索过程 - 四个步骤

Selection

从根节点  开始,在每个时间步 t 迭代选择使下式最大的动作  。其中, 控制 exploitation,倾向于选择价值更高的边;而  控制 exploration, 正比于先验概率,但随着访问次数衰减,倾向于选择探索次数少的边。

其中,  。

Evaluation and Expansion

上一步一直迭代到一个叶子节点  ,如果节点是一个 episode 的结尾则使用预设的评价指标评估,其他情况使用值函数  评估。这和 AlphaGo Zero 处理是一样的,而 AlphaGo 中价值评估是通过快速走子模拟实现的。接着叶子节点  有可能会拓展,拓展出的新边  会初始化  、  和  这三个属性。

Back-propagation and Update

利用  沿着 Selection 经过的路径反向一路更新。针对路径上的每条边 e(s,a),  不变化,  加一,  做累积平均的更新:

输出更优搜索策略

上述步骤迭代 K 次后可得到根节点  处的更优搜索策略。由于边的访问次数说明了每条边的价值高低,新策略由每个边的访问次数决定,形式化为 

  • 基于强化学习的训练方式

模型参数主要是 LSTM 和策略函数 p 中的参数。针对一个 query,在当前的参数下,每个位置执行一次 MCTS 搜索,直到产生一个排序列表  ,然后得到评估  ,其中 R 是任意的多样性评价指标,比如文中使用到的  ,而 J 是该 query 下的真值。

损失函数组成

使用每个时间步积累的训练数据  和 r 作为监督信号来训练网络。如上图,损失函数包含两部分,一个是值函数的预测值  和指标  的均方误差,另一个是 raw policy  和 search policy 的交叉熵,形式化如下: 

  • 在线预测

针对线上预测需要两种方式:一个是采用上面提到的 MCTS,但是这种方式非常耗时;另一个是放弃树搜索直接使用 raw policy 进行排序。在实验中发现即使采用第二种方法的效果依然超过了基准模型。这是因为训练时由于使用 MCTS 搜索产生了质量较好的序列来训练参数,使得策略函数 p 更加准确。这点也使得文章提出的 idea 在工业界真实上线具备可能。

三、总结

  • 对 AlphaGo 和 AlphaGo Zero 了解的同学读这篇文章会很顺畅,思路和 AlphaGo Zero 类似,借鉴了比如使用同一个网络得到策略和值函数、MCTS 搜索、同时优化值函数和策略函数的损失等成功经验。
  • 在我们的业务场景中,action 偏少,一般只有 2-3 个,引入 mcts 也已经很耗时了。在文章中,由于候选文档也就是动作可能偏多,所以想要使得 mcts 输出的策略尽可能平滑,mcts 需要很多次的迭代,文中的 k 值设置的是 5000,可想而知训练过程的耗时程度,其实就是拿时间换效果的 tradeoff。
  • 存在一个疑问的地方是在损失函数里,因为开始时可能只搜出来一篇或者两篇文档,其状态的 Q 值明显和最终整个列表生成后的状态 Q 值应该差别很大,但损失函数里要求使所有状态的 Q 值都要和最终列表的指标 r 都接近。AlphaGo Zero 中这里采取类似做法,但多了一个采样。这个细节的处理,欢迎提出不同看法。


更多高质资源 尽在AIQ 机器学习大数据 知乎专栏 点击关注

转载请注明 AIQ - 最专业的机器学习大数据社区  http://www.6aiq.com