Fork me on GitHub

大模型中的人工反馈强化学习详解

以下文章来源于 https://zhuanlan.zhihu.com/p/651780908



一、前言

我们从instruct gpt论文中了解到chatgpt的训练方式,除了基底模型的预训练之外还有sft(监督学习微调)和rlhf(基于人工反馈的强化学习)过程。现在大多的开源模型绝大多数都只经过了sft过程,缺少rlhf过程,导致模型的输出可能存在输出有毒有害的内容和有明显偏见的内容。所以rlhf过程是很关键的一步,也是很麻烦的一步,涉及到的模型很多,收敛会比较困难。下面我会来讲解下大模型中的人工反馈强化学习。


二、Actor-Critic 强化学习算法

在基于策略的强化学习中,最优策略是通过直接操纵策略来计算的,而基于值的函数通过寻找最优值函数隐式地找到最优策略。基于策略的强化学习在高维和随机连续行动空间以及学习随机策略中非常有效。同时,基于价值的强化学习在样本效率和稳定性方面表现出色。

策略梯度强化学习的主要挑战是高梯度方差。减少梯度估计方差的标准方法是使用基线函数 b(st)。
Actor-Critic结构

简单来说,Actor-Critic 是策略梯度的时间差异(TD)。它有两个网络:Actor 和 Critic。Actor(演员)决定应该采取哪个动作,Critic(评论家)告诉Actor(演员这个动作有多好以及应该如何调整。Actor(演员)的学习基于策略梯度方法。相比之下,Critic(评论家)通过计算价值函数来评估Actor(演员)产生的动作。

这种结构很像生成对抗网络(GAN),其中判别器和生成器都参与训练。生成器生成假图像,鉴别器评估生成的假图像及其对真实图像的表示有多好。随着时间的推移,生成器可以创建鉴别器无法区分的假图像。同样,Actor 和 Critic 都参与训练,但它们都随着时间的推移而不断改进,这与 GAN 不同。

Actor-critic 类似于带有基线的称为 REINFORCE 的策略梯度算法。强化是蒙特卡罗学习,表明总回报是从完整轨迹中采样的。但在Actor-Critic中,我们使用引导程序。所以主要变化在advantage函数上。

REINFORCE算法的策略梯度表达式如下所示:
REINFORCE策略梯度表达式的期望形式

轨迹奖励 R(τ):
轨迹奖励 R(τ)表达式

添加基线函数修改策略梯度表达式如下:
添加基线函数的策略梯度公式

奖励和基线项称为advantage函数:
advantage函数

b(st)变为当前状态的值函数:
b(st)的值函数

actor-critic 的advantage函数可以转化为:
actor-critic 的advantage函数

advantage函数被称为 TD 误差,如 Actor-Critic 框架中所示。如上所述,Actor的学习是基于策略梯度的。Actor的策略梯度表达式如下所示:
Actor的策略梯度


Actor-Critic算法的训练过程如下:

1、使用来自参与者网络的策略 πθ 采样 {s_t, a_t}

2、评估advantage函数 A_t。它可以称为TD误差δt。在Actor-critic算法中,优势函数是由critic-network产生的。
advantage函数

3、 计算Actor的策略梯度:
Actor的策略梯度

4、更新策略参数 θ:
策略参数表达式

5、更新基于价值的强化学习(Q-learning)的权重。δt 相当于优势函数:
基于价值的权重

6、重复1到5,直到找到最优策略πθ。


三、RLHF中的采样方法

PPO算法在训练的时候会进行多次采样,prompt有限但是样本采样是无限的,PPO算法的采样也主要是用的蒙特卡洛采样,对某一种概率分布p(x)进行蒙特卡洛采样的方法主要分为直接采样、拒绝采样与重要性采样三种。

1、直接采样

直接采样的方法是根据概率分布进行采样。对一个已知概率密度函数与累积概率密度函数的概率分布,我们可以直接从累积分布函数(cdf)进行采样。

这种采样方式存在比较大的弊端就是现实中很多分布我们是找不到率密度函数与累积概率密度函数,所以使用直接采样就不太合适。

2、拒绝采样

对于累积分布函数未知的分布,可以使用拒绝采样。p(z)是希望采样的分布,q(z)是提议的分布(proposal distribution),令kq(z)>p(z),我们首先在kq(z)中按照直接采样的方法采样粒子,接下来判断这个粒子落在途中什么区域,对于落在灰色区域的粒子予以拒绝,落在红线下的粒子接受,最终得到符合p(z)的N个粒子。
拒绝采样

3、重要性采样

拒绝采样使用的情况主要是累积分布函数未知的情况下使用的,而且非常依赖提议的分布(proposal distribution),如果没选好的话,拒绝采样的效果就很差。而重要性采样就解决了这一问题

直接采样与接受拒绝采样都是假设每个粒子的权重相等,而重要性采样则是给予每个粒子不同的权重,使用加权平均的方法来计算期望

E_{p(x)}[f(x)]=\int _a^bf(x)\frac{p(x)}{q(x)}q(x)dx=E_{q(x)}[f(x)\frac{p(x)}{q(x)}]

我们从提议分布q(x)中采样大量粒子 x_1,x_2...x_n ,每个粒子的权重是 \frac{p(x_i)}{q(x_i)} ,通过加权平均的方式可以计算出期望

E_{p(x)}[f(x)]=\frac{1}{N}\sum f(x_i)\frac{p(x_i)}{q(x_i)}


四、基于人工反馈的Actor-Critic强化学习

ColossalChat RLFH过程

ColossalChat RLFH过程也是非常接近ChatGPT的RLFH过程,RLFH过程主要涉及四个模型分别是Actor、Critic、RM、STF,损失函数也是由三个损失函数组成分别是策略损失、价值损失和 PTX 损失。

策略损失函数计算:
策略损失计算过程

通过instruction dataset数据训练STF模型,通过计算sft model的logits和actor model(没有经过sft的model)的logits计算kl散度,然后加上reward model的打分变成 reward R奖励值,避免太过偏向reward model加入和sft model的kl散度,同时也避免强化学习将actor模型训歪。

R_{reward} = r(x,y) - \beta log(\pi_{\phi}^{RL}(y | x)/\pi^{SFT}(y | x))
策略损失

\iota_{ppo} = min(rA,clip(r,1-\varepsilon,1+\varepsilon)A)

式中, \epsilon 是一个超参数,一般\epsilon=0.2 。 min中的第一项是 L^{CPI} ;第二项是 clip(r_{t}(\theta),1-\epsilon,1+\epsilon)\bar{A}_{t} ,这一项通过clip概率比例 r_{t}(\theta) 从而使 r_{t}(\theta) 的取值在 [1-\epsilon,1+\epsilon] 之间,即 r_{t}(\theta) 不会远离1。 L^{CLIP} 的意义如下:

  • \bar{A}_{t} 可理解为一个权重,PPO和TRPO要解决的是如何合理地最大化 \frac{\pi_{\theta}(a_{t}|s_{t})}{\pi_{\theta_{old}}(a_{t}|s_{t})}
  • policy在变好的话,即 \bar{A}_{t}>0 (adavantage说明当前时刻的动作比平均动作好多少),那么理应让 \pi_{\theta} 变得更好,但是不能变化太大,因为这样可能使 r_{t}(\theta)\gg1 ,导致两个概率分布差别太大。因此,最终的 r_{t}(\theta)\leq1+\varepsilon
  • policy在变差的话,即 \bar{A}_{t}<0 ,那么采样好坏已经不重要,继续优化πθ,设定一个下限。因此,最终的 r_{t}(\theta)\geq1-\varepsilon

这样做的目的就是避免模型训飞,让模型更新保持在一个小范围内。


价值损失函数计算:
价值损失计算过程

\iota_{value} = MSE(R,V(s))

上式R是reward model和sft model计算出来的反馈分数,V(s)是Critic Model输出的价值分数。主要是衡量reward分数和价值函数分数的均方误差。


ptx的损失计算:
ptx的损失计算过程

计算Actor输出response和输入语料的回答部分的交叉熵损失函数,用来在PPO梯度中加入预训练梯度,以保持语言模型原有性能防止遗忘。这个就是instruct gpt论文中在强化学习中加入预训练梯度以防过度拟合ppo数据带来nlp通用任务能力的下降操作。

\iota_{ptx}=\gamma E[log(\pi_{\phi}^{RL}(x))]


总的强化学习损失计算:

\iota=\iota_{ppo}+\iota_{value}+\iota_{ptx}


五、总结

通过ChatGPT的成功,体现了RLHF的重要性,也来探索下强化学习在大模型训练的原理和细节。

具体的细节代码:
GitHub - binmakeswell/ColossalChat: Fork from https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat

就不过多介绍代码了,有兴趣的可以进入链接看看源码。


本文地址:https://www.6aiq.com/article/1692846951585
本文版权归作者和AIQ共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出