人人都能看懂的DPO数学原理
共 22203字,需浏览 45分钟
·
2024-10-16 07:00
一、DPO在做一件什么事
在文章的开始,我们来思考一个问题:如果想让你训练一个能听得懂人类问题,并给出人类满意答案的模型,你会怎么设计大致的训练步骤?
一口吃成一个大胖子是困难的,所以不如让我们循序渐进地来设计这个训练过程:
-
首先,你的模型必须要有充足的知识储备,以应对你可能提出的任何问题
-
其次,你的模型需要明白“你在提出问题”或者“你在下达指令”这个动作,它在理解这一点的基础上,尝试按照你的指令给出相应的回答
-
最后,你希望模型不仅能对你的问题给出答案,还需要给出令你满意的回答,也就是你希望模型对齐你的偏好。
我们以chatGPT的训练为例,看看它是如何贴合这个训练步骤的:
-
首先,它使用大量的数据(文本、代码、数学等),先训练出一个base模型。这个训练过程赋予模型对文本上下文的理解能力和通用的知识,你也可以理解成它训练出了一个能做好词语接龙的base模型
-
然后,我们在这个base模型上做微调,使得模型能够听懂人类的指令,这个过程得到的模型我们称为
-
最后,我们采用rlhf的方式(奖励模型 + ppo),使得模型不仅能够听懂人类指令,还能产出符合人类偏好的回答。
(关于chatGPT原理和rlhf原理的部分,可以参考这篇和这篇文章)。
现在,我们把焦点放在第三步上:如何训练模型对齐人类的偏好。在以chatGPT为代表的训练方法中:
-
训练奖励模型(Reward Model, RM):
-
首先,我们需要有一个【标准】,这个标准在告诉待训练的模型,什么回答才是人类喜欢的。这个标准就是奖励模型,它将对各个回答进行打分。
-
然后,在训练奖励模型时,我们可以采用【偏好排序】来标注数据集。即对于一个prompt,我们可以产出若干个回答。然后让人工对这若干个回答进行偏好排序,我们就用这些数据来训练模型
-
训练对齐模型:
-
在这篇文章中,我们称经过偏好对齐训练后的模型为【对齐模型】,这个模型也是我们训练的最终目的
-
设计对齐模型的优化目标:这个优化目标不仅考虑到奖励模型的得分,也尽量让对齐模型参数更新后输出的分布不要偏移 太远,防止模型越训越差。
-
使用强化学习的方法,采用PPO手段来训练这个优化目标,因为用到了强化学习,所以这种方法又被称为rlhf-ppo。在这个过程中,我们让对齐模型根据prompt自生成回答,并采用训练好的奖励模型对回答进行打分,对齐模型会根据评分结果不断调整自己的输出分布。更多细节在上面提到的rlhf相关文章中,这里不赘述。
当你仔细端详【对齐人类偏好】这个训练步骤时,你可能会感觉有些疑惑:
-
看起来,在训练奖励模型的过程中,我们就已经在考虑“什么回答是好的,什么回答是不好的”这个问题了。而对齐模型依然是在考虑这个问题。所以,我们能不能避开奖励模型的训练,直接一步到位训练对齐模型呢?
-
在实际rlhf-ppo的训练中,存在【显存占据大】、【超参多】、【模型训练不稳定】等一系列问题。所以,在考虑“一步到位训练对齐模型”的过程中,我们是不是也能顺手做到绕过强化学习,采用一个更简单的方式(比如类似于sft)来使用偏好数据训练对齐模型呢?
基于这些考虑,DPO(Direct Preference Optimization)应运而生,正如它名字中Direct蕴含的含义一样,比起传统基于强化学习PPO的方式,它改进了以下两点:
-
不再训练奖励模型,直接使用人类标注的偏好数据,一步到位训练对齐模型。
-
不再使用强化学习的方法,通过数学推理,将原始的偏好对齐目标步步简化,最后通过类似于sft的方式,用更简单的步骤训练出对齐模型。
我们借用DPO论文中的配图,来直观比较RLHF-PPO和DPO之间的差异:
这篇文章将从数学原理上详细解释,DPO是如何从最原始的偏好对齐优化目标开始,一步步做简化的(不涉及实操代码,这个后续有时间再单开一篇文章)。本文从更符合大家逻辑思考顺序的角度,重构了DPO的推导过程,并对每一步推导过程都给出了详细的注解,希望能帮大家解决一些数学上的困惑,也更好理解DPO。
二、偏好对齐模型的优化目标
不管你是ppo还是dpo,在偏好对齐这一步中,总的优化目标是不变的,如上式所示,其中:
-
:是我们正在训练的、目的是为了对齐人类偏好的模型 -
:是训练好的奖励模型 -
:参考模型,一般是sft步骤的模型初始化而来
下面开始循序渐进解释dpo loss函数是如何从这个总体优化目标中推导而出的,大家在这个过程中依然牢记两件事:
-
绕过奖励模型 -
最大可能简化优化目标
三、步骤1:从优化目标中求解最优对齐模型
3.1 推导细节
式(11)依然是总优化目标,符号稍作了改写。现在我们要找到能最大化这个优化目标的对齐模型 。
现在我们开始对它进行改进:
第1行~第2行:
所以可以从第1行改写成第2行。
第2行~第3行:
除以 ,并取反(因此max改成min)
第3行~第4行:
-
首先,人为定义一个partition function:
-
表示在给定某个prompt x的前提下,ref模型可能生成的所有y,因此我们有 -
由 的定义我们知道,它是关于x的函数,且它和我们准备优化的模型 没有关系。
我们把 带入第3行,就可以得到第4行的结果。
观察式(12)中括号里的左半部分,我们发现它非常像一个KL散度的形式(即衡量了两个分布之间的相似性),鉴于分子 已经是个显式的分布表示了,我们干脆把分母也写成一个显示的分布表示,我们定义一个分布 :
好,现在再把这个人为定义的分布表达带回到式(12)中,我们得到:
观察式(14),前面我们说过 和我们准备优化的模型 没有关系,所以可以把它忽略掉。那么现在我们只用关心KL散度这一项。我们知道KL散度在两个分布完全相等时达到最小,由此我们可以写出模型 的显式解:
我们对式(15)再做一个简单的改写:因为以上推导都是在假设我们有一个固定的奖励函数 的基础上进行的,所以我们可以加一个下标 来强调这一点,则式(15)可进一步被改写成:
可是,在正常的对齐训练中,这个奖励函数 可不是任意的,它是我们先用数据训练出来的最优奖励模型,然后在这个最优奖励模型的基础上,我们再通过训练去找到最优对齐模型 。最优的奖励模型 和基于它训练出的最优的对齐模型 依然满足式(4)的关系,我们分别设它们为 ,则有:
后面这些推导步骤没有什么难度,无非是做了些公式和符号上的变化。
3.2 步骤1总结
到此为止,经过了长长的推导,你可能已经有点懵了,没关系,我们总结一下步骤1中我们做的事情:
-
首先,我们有一个对齐人类偏好阶段的总优化目标函数,它是在假设我们已经有一个奖励函数 的基础上设计的,我们的目标是找到能使这个目标值最大化的对齐模型 :
-
然后,我们从这个目标函数出发,找到 的显式解(也就是在任意固定的奖励函数r的基础上最优的 ):
其中, 是人为定义的partition function,它形式为
-
最后,由于在实际训练中,我们肯定是在最优的奖励函数上去训练最优的对齐模型 ,所以我们对上式的符号稍加更改,令星号代表最优,则有:
四、步骤2:跳过奖励模型
虽然我们现在得到了对齐模型 的显式解 ,但是我们却很难直接利用起这个显式解形式,原因如下:
-
的值很难估计。根据 的形式可知,想要估计它,需要对一个prompt x采样足够多的回答y。这个代价是十分昂贵的。
-
同时回顾最开始我们的目标:省略训练奖励模型这个步骤,一步到位来训练对齐模型。而目前我们得到的 的显式解仍然需要一个确定的奖励函数 ,没有达到我们的目标。
所以现在我们继续来迭代。基于上述第2个原因,我们可以先从 的显式解中推出奖励函数 的形式:
好,现在既然我们能用最优对齐模型 表示出最优奖励模型 了,那么我们直接把 代入到奖励模型的训练优化目标中去,不就意味着我可以明面上训练奖励模型,实际上却一步到位训练出了对齐模型吗?这不就能实现我们最开始的目标了么?
现在,问题回到“奖励模型的训练上”来了。我们通常使用“偏好排序”这种数据标注方式来对奖励模型进行训练,一般有2种偏好排序方法:
-
只生成2个回答,
<prompt x, chosen y1, reject y2>
,即对于一个prompt,我们只生成2个回答,让人工对这两个回答的偏好做排序,我们希望奖励模型对chosen回答的给分尽量高于对reject回答的给分。 -
生成K个(K > 2)回答,
<prompt x, y1, ..., yK>
,假设人工标注后的偏好排序组合为 (比如人工人为偏好从大到小应该为y2 > y3 > y1 >... > yK,则这个排列就为 ),那么我们希望奖励模型对 这个排序的总得分要大于其余任何可能的偏好排序。
在某些框架(比如chatGPT)的训练中,当生成的回答>2个时,它会把回答拆成两两pair对,这样就可以和只生成2个回答时的目标函数做统一。但在更一般的场景中,对于>2个回答的场景,我们是把每一种可能的回答偏好排序当成一个整体数据进行处理的,然后希望真值排序 的得分最高。DPO的推导是基于后者进行的,所以接下来,我们也对K=2和K>2这两种情况分别下DPO最终的目标函数形式推导进行详细说明。
4.1 BT模型:只生成2个回答
<prompt x, chosen y1, reject y2>
,对于一个prompt,我们只生成两个答案,然后在这两个答案间进行偏好排序。那么在这种偏好标注数据下,我们该怎么设计奖励模型的训练目标呢?
首先,我们需要明确,我们到底希望一个好的奖励模型能做什么事?我们当然是希望“chosen y1打败reject y2的概率尽量大”。基于此,我们可以引入统计模型Bradley-Terry(BT模型)进行建模,该模型在1952年被首次提出,用于分析成对数据间的相对优势或者偏好,被广泛应用于体育比赛分析、市场研究等场景。在该模型下,我们假设有一个成对数据 ,则“y1打败y2的概率”可以被表示成:
其中, 分别表示 的强度参数。
什么是强度参数呢?假设我们现在想预测球队1在本场能打败球队2的概率,那么强度参数就可以是这两只球队过往的胜率。那么同理,如果现在y1和y2分别表示chosen和reject回答,那么强度参数就可以是奖励模型对这两个回答打出的分数,则根据BT模型,我们有:
我们希望y1打败y2的概率尽量大,也就是我们希望对于整个标注数据集,chosen打败reject的期望概率尽量大(其中,w=chosen,l=reject),所以奖励函数的总体优化目标可以设计成:
我们把P的具体形式代入这个函数,则有:
诶,你看,这最后一行公式你眼熟不?这不就是chatGPT中构造的奖励模型优化目标的最终形式么?在chatGPT的论文中,直接给出了这个优化目标,并做了直觉上的解读。而这里我们则更进一步:从最经典的BT模型开始,一步步推导出成对偏好数据下的奖励模型优化目标应该如何设计。
好,现在我们假设,最优奖励模型为 ,则我们将其带入上面的优化目标中,就有:
而同时,根据前文的推导,最优的奖励模型又可以用最优的对齐模型 来显式表示,即:
我们把这个显式表示带入上面的优化目标中,则有:
到这里,你是否惊奇地发现:我们已经把训练奖励模型的目标函数转化成只和对齐模型 相关了!也就是说,我们可以一步到位,绕开训练奖励模型的过程,直接用标注好的【成对】偏好数据,以类似于sft的过程直接训练对齐模型了!因此我们对上述式子再稍加改动,我们设等待训练的对齐模型为 ,则有:
看,这就是【成对】偏好数据下DPO的优化目标。
总结一下,在这一节中:
-
我们通过经典的BT模型,先一步步推导出【成对偏好数据】下训练奖励模型的优化目标
-
然后,我们再使用前问推导出的奖励函数和对齐模型的关系,把奖励模型优化目标中和奖励函数相关的部分替换成对齐模型,构造出了【成对偏好数据】下的DPO优化目标,以此达到绕过奖励模型的偏好对齐训练。
4.2 PT模型:生成K(K>2)个回答
现在,如果我不想使用【成对偏好数据】,而是对于一个prompt,我标出K(K>2)个回答,然后对这些回答进行人工偏好排序,在这种方式下,我要怎么设计奖励模型优化目标呢?
类比于BT,我们同样有一个基于统计的PT模型(Plackett-Luce)可以对多数据的偏好进行排序,假设 为人工标注出的真值排序,则我们当然希望 能够打败其余任何一种可能的偏好排序,我们将“最优排序 打败其余任何一种排序”的概率表示成:
其中, 表示人类标注的偏好序列 中的第k个数据,序列 中的K个回答已经按照偏好从高到低进行排序
上面这个公式从直观上理解的话:
-
对于真值 中的第一个回答 ,它是人工标注的偏好最高的数据,我们当然希望它的得分在 中占大头
-
对于真值 中的第一个回答 ,我们当然希望它的得分在 中占大头
-
对于真值 中的第一个回答 ,我们当然希望它的得分在 中占大头
-
以此类推,则不难理解上述在PT模型下概率P的表达方式。
同样,我们把最优奖励函数 代入上面的P中,则有:
然后我们再用 去表示 ,则有(这里我们可以把 省略掉,因为正如前文所说,它和对齐模型 没有关系):
那么对于整个数据集来说,我们希望最优序列打败其余任何一个可能序列的期望概率尽量大,则多回答下DPO的目标函数可以写成:
五、(必看)DPO优化目标推导过程总结
我们把DPO的推导过程完整总结一次:
-
首先,我们有一个对齐人类偏好阶段的总优化目标函数,它是在假设我们已经有一个奖励函数 的基础上设计的,我们的目标是找到能使这个目标值最大化的对齐模型
-
然后,我们从这个目标函数出发,找到 的显式解(也就是在任意固定的奖励函数r的基础上最优的 ):
其中, 是人为定义的partition function,它形式为
-
接着,由于在实际训练中,我们肯定是在最优的奖励函数上去训练最优的对齐模型 ,所以我们对上式的符号稍加更改,令星号代表最优:
-
然后,由于我们的训练目标是尽量绕过奖励模型,直接使用偏好数据,通过类似sft的方式一步到位训练对齐模型,所以我们将上式先改写成如下形式。接下来,我们只需要关注如何构造奖励函数的训练目标,然后把这个式代入训练目标中,就可以一步到位训练对齐模型。
-
在奖励模型的训练过程中,有2种数据标注的方式,分别是【成对回答偏好标注】和【多回答偏好标注】。
-
成对回答偏好标注:在这类数据标注方式上,我们使用BT模型,先推导出奖励模型的优化目标,然后使用 替换掉优化目标中的 ,得到最终的dpo优化目标: