Transformer Decoder-Only 模型批量生成 Trick
极市导读
本文给出了一个用单Transformer decoder( GPT)模型进行批量生成时的解决方法。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
发现用单 Transformer decoder (Aka GPT)模型进行生成时,因为位置对齐等问题,进行批量生成时十分麻烦。
训练时,context 和 target 可以直接拼一起,然后一个 Batch 内通过裁剪或 Padding 到相同长度来进行批量训练。但生成时,只有 context,每个长度还不同,如果 Padding 到相同长度,直接进行生成的话,会让生成阶段和训练阶段有巨大 gap,导致生成不了好的结果。
解决问题的最好方法就是——不解决问题。直接一条条输出吧。
但如果不批量生成,模型小数据少时还好,站起来喝杯水撒泡尿时间就差不多了。但模型一大且数据量一大,花的时间就太大了。
手动开几个进程同时跑多个模型也不是不行,但太美了。
所以只能想办法解决了。
训练阶段解决
通过 Padding 来解决的最主要问题是,生成和训练阶段的差别太大,那是不是在训练时就给 Padding 直接放在 Context 后,再直接拼 target 就行。
可行,但成本太大了,还得重训模型。
所以还是不行。
利用 Transformer 特性
于是就想,如何通过处理让生成时模拟训练时状况,让模型以为 target 位置是直接在 context 后,且只参考 context。
需要明确一点,Transformer 里因为位置信息主要通过位置编码来表示的,所以只要对应的位置编码不变,即使输入向量顺序再怎么变,对 Transformer 来说还是差不多,这也是一些技巧如 PLM(Permutation Language Model)) 得以实现的原理。
直接这样说太抽象了,举个栗子。
假设一个 batch 长度不一样样本训练时如下
input_ids:
1 3 2 6 2 0 0
1 3 6 2 5 4 2
2 为分割和终止符,可看到训练时,通过给第一句 padding,算 loss 时 padding 位置都不算上来进行训练。
而 inference 时,只有 context,即使 padding 也会是下面这样
1 3 2 0
1 3 6 2
这种情况下如果直接用默认的 pos_ids 和 atten_mask (不了解的看The Annotated Transformer),第一句就会出现问题。
对比一下,训练时用到的三个参数
1 3 2 6 2 0 0 (input_ids)
0 1 2 3 4 0 0 (pos_ids)
1 1 1 1 1 0 0 (atten_mask)
训练时当生成 6 的时候看到的是
1 3 2
0 1 2
1 1 1
再来看看生成时的情况,生成 6 的时候直接看到的是
1 3 2 0
0 1 2 3
1 1 1 0
首先拿的是最后 padding 位置的向量来预测下一个,同时还有个问题就是,当预测完成一个时,之后拿到的位置 id 是不对的,这里假设预测成功为 6
1 3 2 0 6
0 1 2 3 4
1 1 1 0 1
会发现用 6 来预测下一个词时已经和训练时不一样了,因为训练时 6 对应的位置 id 是 3
实际这样用时,我也发现生成结果总是错开几个字,像是给刀直接切开了一样。
于是改进,最简单方法是直接给 padding 的位置向量都设成 padding 前的位置,这样预测时位置向量就对了。
1 3 2 0 6
0 1 2 2 3
1 1 1 0 1
但这只解决了一个问题,即生成过程中的问题,第一个位置拿的还是 padding 位置进行的输出。这里有个解决方法,就是生成时,第一次预测取到 padding 前 token,之后就依次取最后一个进行预测了。
这样基本上就算是解决问题了,但生成时第一次和之后还得区分开,说实话还是有点 ugly.
还可以进一步优化。
Left Padding
解决方法很简单,思维掉转下就好了,因为并行生成时都是从最后一位开始取,那能不能直接给 padding 放到前面去呢。
于是生成时一个 batch 会变成这样
input_ids:
0 0 1 3 2
1 3 6 2 5
那么对于第一条进行预测时,也只需要这样设置一下 pos_id 和 atten_mask 就行
0 1 3 2
0 0 1 2
0 1 1 1
这样子生成 6 时,位置向量就能自然而然衔接上,同时 atten_mask 也给前面的 padding 完美 mask 掉了。
完美解决!速度一下提高了好几倍。
如果觉得有用,就请分享到朋友圈吧!
公众号后台回复“目标检测竞赛”获取目标检测竞赛经验资源~
# CV技术社群邀请函 #
备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)
即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群
每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~