从慢速到SIMG: 一个Go优化的故事
SourceGraph 的工程师 Camden Cheek 提供的一个利用 SIMD 进行 Go 性能优化的故事: From slow to SIMD: A Go optimization story [1]。
这是一个关于某函数的故事,这个函数被大量调用,而且这些调用都在关键路径上。让我们来看看如何让它变快。
剧透一下,这个函数是一个点积函数。
“点积(Dot Product),也称为内积或数量积,是一种数学运算,通常用于计算两个向量之间的乘积。点积的结果是一个标量(即一个实数),而不是一个向量。
假设有两个向量:
>那么,这两个向量的点积为:
一些背景
在 Sourcegraph,我们正在开发一个名为 Cody 的 Code AI 工具。为了让 Cody 能够很好地回答问题,我们需要给它足够的上下文。我们做的一种方式是利用嵌入[2](embedding
)。
为了我们的目的,嵌入是文本块的向量表示。它们用某种方式构建,以便语义上相似的文本块具有更相似的向量。当 Cody 需要更多信息来回答查询时,我们在嵌入上运行相似性搜索,以获取一组相关的代码块,并将这些结果提供给 Cody,以提高结果的相关性。
和这篇文章相关的部分是相似度度量
,它是一个函数,用于判断两个向量有多相似。对于相似性搜索,常见的度量是余弦相似度[3]。然而,对于归一化向量(单位幅度的向量),点积产生的排名与余弦相似度是等价的。为了运行一次搜索,我们计算数据集中每个嵌入的点积,并保留前几个结果。由于我们在得到必要的上下文之前无法开始执行 LLM,因此优化这一步至关重要。
你可能会想:为什么不使用索引向量数据库?除了添加我们需要管理的另一个基础设施外,索引的构建会增加延迟并增加资源需求。此外,标准的最近邻索引只提供近似检索,这与更易于解释的穷举搜索相比,增加了另一层模糊性。鉴于这一点,我们决定在我们的手工解决方案中投入一点精力,看看我们能走多远。
目标
下面的代码是一个计算两个向量点积的简单的 Go 函数实现。我的目标是刻画出我为优化这个函数所采取的方法,并分享我在这个过程中学到的一些工具。
func DotNaive(a, b []float32) float32 {
sum := float32(0)
for i := 0; i < len(a) && i < len(b); i++ {
sum += a[i] * b[i]
}
return sum
}
除非另有说明,否则所有基准都在 Intel Xeon Platinum 8481C 2.70GHz CPU
上运行。这是一个 c3-highcpu-44
GCE VM。本博客文章中的代码都可以在这里[4]找到。
循环展开 (Loop unrolling)
现代的 CPU 都有一个叫做指令流水线[5]的东西,它可以同时运行多条指令,如果它们之间没有数据依赖的话。数据依赖只是意味着一个指令的输入取决于另一个指令的输出。
在我们的简单实现中,我们的循环迭代之间有数据依赖。实际上,每个迭代都有一个读/写对,这意味着一个迭代不能开始执行,直到前一个迭代完成。
一个常见的方法是在循环中展开一些迭代,这样我们就可以在没有数据依赖的情况下执行更多的指令。此外,它将固定的循环开销(增量和比较)分摊到多个操作中。
func DotUnroll4(a, b []float32) float32 {
sum := float32(0)
for i := 0; i < len(a); i += 4 {
s0 := a[i] * b[i]
s1 := a[i+1] * b[i+1]
s2 := a[i+2] * b[i+2]
s3 := a[i+3] * b[i+3]
sum += s0 + s1 + s2 + s3
}
return sum
}
在我们的展开代码中,乘法指令的依赖关系被移除了,这使得 CPU 可以更好地利用流水线。这使我们的吞吐量比我们的简单实现提高了 37%。
注意,我们实际上可以通过调整我们展开的迭代次数来进一步提高性能。在基准机器上,8 似乎是最佳的,但在我的笔记本电脑上,4 的性能最好。然而,改进是与平台相关的,而且改进相当微小,所以在本文的其余部分,我将使用 4 个展开深度来提高可读性。
边界检查消除 (Bounds-checking elimination)
为了防止越界的切片访问成为安全漏洞(如著名的 Heartbleed 漏洞[6]),go 编译器在每次读取之前插入检查。你可以在生成的汇编中查看[7]它(查找 runtime.panic)。
编译的代码看起来像我们写了这样的东西:
func DotUnroll4(a, b []float32) float32 {
sum := float32(0)
for i := 0; i < len(a); i += 4 {
if i >= cap(b) {
panic("out of bounds")
}
s0 := a[i] * b[i]
if i+1 >= cap(a) || i+1 >= cap(b) {
panic("out of bounds")
}
s1 := a[i+1] * b[i+1]
if i+2 >= cap(a) || i+2 >= cap(b) {
panic("out of bounds")
}
s2 := a[i+2] * b[i+2]
if i+3 >= cap(a) || i+3 >= cap(b) {
panic("out of bounds")
}
s3 := a[i+3] * b[i+3]
sum += s0 + s1 + s2 + s3
}
return sum
}
在像这样的频繁调用循环(hot loop)中,即使是现代的分支预测,每次迭代的额外分支也会增加相当大的性能损失。这在我们的例子中尤其明显,因为插入的跳转限制了我们可以利用流水线的程度。
如果我们可以告诉编译器这些读取永远不会越界,它就不会插入这些运行时检查。这种技术被称为“边界检查消除”,相同的模式也适用于 Go 之外的语言。
理论上,我们应该能够在循环之外做所有的检查,编译器就能够确定所有的切片索引都是安全的。然而,我找不到正确的检查组合来说服编译器我所做的是安全的。我最终选择了断言长度相等的组合,并将所有的边界检查移到循环的顶部。这足以接近无边界检查版本的速度。
func DotBCE(a, b []float32) float32 {
if len(a) != len(b) {
panic("slices must have equal lengths")
}
if len(a)%4 != 0 {
panic("slice length must be multiple of 4")
}
sum := float32(0)
for i := 0; i < len(a); i += 4 {
aTmp := a[i : i+4 : i+4]
bTmp := b[i : i+4 : i+4]
s0 := aTmp[0] * bTmp[0]
s1 := aTmp[1] * bTmp[1]
s2 := aTmp[2] * bTmp[2]
s3 := aTmp[3] * bTmp[3]
sum += s0 + s1 + s2 + s3
}
return sum
}
这个边界检查的最小化使我们的性能提高了 9%。但是始终未将检查降到零,没有什么值得一提的。
这个技术对于内存安全的编程语言来说是非常有用的,比如 Rust。
一个问题抛给读者:为什么我们要像a[i:i+4:i+4]
这样切片,而不是只是a[i:i+4]
?
量化 (Quantization)
目前我们已经提高了单核的搜索的吞吐率 50%以上,但现在我们遇到了一个新的瓶颈:内存使用。我们的向量是1536维的。用 4 字节的元素,这就是每个向量6KiB,我们每 GiB 代码生成大约一百万个向量。这很快就积累起来了。我们有一些客户带着一些大型的monorepo
来找我们,我们想减少我们的内存使用,这样我们就可以更便宜地支持这些大型代码库。
一个可能的缓解措施是将向量移动到磁盘上,但是在搜索时从磁盘加载它们可能会增加显著的延迟,特别是在慢速磁盘上。相反,我们选择用int8量化我们的向量。
有很多方式可以压缩向量,但我们将讨论整数量化,这是相对简单但有效的。这个想法是通过将4 字节的float32
向量元素转换为1 字节的int8
来减少精度。
我不会深入讨论我们如何在float32
和int8
之间进行转换,因为这是一个相当深奥的话题[8],但可以说我们的函数现在看起来像下面这样:
func DotInt8BCE(a, b []int8) int32 {
if len(a) != len(b) {
panic("slices must have equal lengths")
}
sum := int32(0)
for i := 0; i < len(a); i += 4 {
aTmp := a[i : i+4 : i+4]
bTmp := b[i : i+4 : i+4]
s0 := int32(aTmp[0]) * int32(bTmp[0])
s1 := int32(aTmp[1]) * int32(bTmp[1])
s2 := int32(aTmp[2]) * int32(bTmp[2])
s3 := int32(aTmp[3]) * int32(bTmp[3])
sum += s0 + s1 + s2 + s3
}
return sum
}
这个改变导致内存使用量减少了 4 倍,但牺牲了一些准确性(我们进行了仔细的测量,但这与本博客文章无关)。
不幸的是,这个改变导致我们的性能下降了。查看产生的汇编代码(go tool compile -S
),我们可以看到一些int8
到int32
转换的指令,这可能解释了差异。我没有深入研究,因为我们在下一节中的所有性能改进都变得无关紧要了。
SIMD
到目前为止,速度提升还不错,但对于我们最大的客户来说,还不够。所以我们开始尝试一些更激进的方法。
我总是喜欢找借口来玩 SIMD。而这个问题似乎正好对症下药。
对于还不熟悉 SIMD 的同学来说,SIMD 代表“单指令多数据”(Single Instruction Multiple Data
)。就像它说的那样,它允许你用一条指令在一堆数据上运行一个操作。举个例子,要对两个int32
向量逐元素相加,我们可以用ADD
指令一个一个地加起来,或者我们可以用VPADDD
指令一次加上 64 对,延迟相同(取决于架构)。
但是我们还是有点问题。Go 不像 C 或 Rust 那样暴露 SIMD 内部函数。我们有两个选择:用 C 写,然后用 Cgo,或者用 Go 的汇编器手写。我尽量避免使用 Cgo,因为有很多原因,这些原因都不是根本原因,但其中一个原因是 Cgo 会带来性能损失,而这个片段的性能是至关重要的。此外,用汇编写一些东西听起来很有趣,所以我就这么做了。
我想要这个这个算法可以输出到其他编程语言,所以我限制自己只使用 AVX2 指令,这些指令在大多数 x86_64 服务器 CPU 上都支持。我们可以使用运行时进行检测[9],在纯 Go 中回退到一个更慢的选项。
#include "textflag.h"
TEXT ·DotAVX2(SB), NOSPLIT, $0-52
// Offsets based on slice header offsets.
// To check, use `GOARCH=amd64 go vet`
MOVQ a_base+0(FP), AX
MOVQ b_base+24(FP), BX
MOVQ a_len+8(FP), DX
XORQ R8, R8 // return sum
// Zero Y0, which will store 8 packed 32-bit sums
VPXOR Y0, Y0, Y0
// In blockloop, we calculate the dot product 16 at a time
blockloop:
CMPQ DX, $16
JB reduce
// Sign-extend 16 bytes into 16 int16s
VPMOVSXBW (AX), Y1
VPMOVSXBW (BX), Y2
// Multiply words vertically to form doubleword intermediates,
// then add adjacent doublewords.
VPMADDWD Y1, Y2, Y1
// Add results to the running sum
VPADDD Y0, Y1, Y0
ADDQ $16, AX
ADDQ $16, BX
SUBQ $16, DX
JMP blockloop
reduce:
// X0 is the low bits of Y0.
// Extract the high bits into X1, fold in half, add, repeat.
VEXTRACTI128 $1, Y0, X1
VPADDD X0, X1, X0
VPSRLDQ $8, X0, X1
VPADDD X0, X1, X0
VPSRLDQ $4, X0, X1
VPADDD X0, X1, X0
// Store the reduced sum
VMOVD X0, R8
end:
MOVL R8, ret+48(FP)
VZEROALL
RET
这个实现的核心循环依赖于三条主要指令:
- VPMOVSXBW:将一个
int8
加载到一个int16
向量中 - VPMADDWD:将两个
int16
向量逐个元素相乘,然后将相邻的两对模糊堆叠相加,生成一个int32
向量。 - VPADDD:这将生成的 int32 向量累积到我们的运行总和
VPMADDWD
在这里是真正的主力军。通过将乘法和加法步骤合并为一个步骤,它不仅节省了指令,还帮助我们避免了溢出问题,同时将结果扩展为 int32
。
让我们看看这给我们带来了什么。
哇,这是我们之前最好表现的 530% 的增加!SIMD 胜利了 🚀。
现在,情况并非一帆风顺。在 Go 中手写汇编是有点奇怪的。它使用自定义的汇编器,这意味着它的汇编语言看起来与您通常在网上找到的汇编片段相比,会有略微不同而令人困惑。它有一些奇怪的怪癖,比如改变指令操作数的顺序或者使用不同的指令名称。在 Go 汇编器中,有些指令甚至没有名称,只能通过它们的二进制编码来使用。不得不说一句:我发现 sourcegraph.com 对于查找 Go 汇编示例非常有价值,可以供参考。
话虽如此,与 Cgo 相比,还是有一些不错的好处。调试仍然很好用,汇编可以逐步执行,并且可以使用 delve 检查寄存器。没有额外的构建步骤(不需要设置 C 工具链)。很容易设置一个纯 Go 的备用方案,所以跨编译仍然有效。常见问题被 go vet 捕捉到。
SIMD ... 更大
以前,我们限制自己只使用 AVX2
,但如果不这样呢?AVX-512
的 VNNI
扩展添加了 VPDPBUSD
指令,该指令计算 int8
向量而不是 int16
的点积。这意味着我们可以在单个指令中处理四倍的元素,因为我们不必先转换为 int16,并且我们的向量宽度在 AVX-512 中加倍!
唯一的问题是该指令要求一个向量是有符号字节,另一个向量是无符号字节。而我们的两个向量都是有符号的。我们可以借鉴英特尔开发者指南中的技巧来解决这个问题。给定两个 int8 元素 an
和 bn
,我们进行逐元素计算如下:an * (bn + 128) - an * 128
。an * 128
项是将 128
加到 bn
以将其提升到 u8
范围的超出部分。我们单独跟踪这部分并在最后进行减法。该表达式中的每个操作都可以进行向量化处理。
#include "textflag.h"
// DotVNNI calculates the dot product of two slices using AVX512 VNNI
// instructions The slices must be of equal length and that length must be a
// multiple of 64.
TEXT ·DotVNNI(SB), NOSPLIT, $0-52
// Offsets based on slice header offsets.
// To check, use `GOARCH=amd64 go vet`
MOVQ a_base+0(FP), AX
MOVQ b_base+24(FP), BX
MOVQ a_len+8(FP), DX
ADDQ AX, DX // end pointer
// Zero our accumulators
VPXORQ Z0, Z0, Z0 // positive
VPXORQ Z1, Z1, Z1 // negative
// Fill Z2 with 128
MOVD $0x80808080, R9
VPBROADCASTD R9, Z2
blockloop:
CMPQ AX, DX
JE reduce
VMOVDQU8 (AX), Z3
VMOVDQU8 (BX), Z4
// The VPDPBUSD instruction calculates of the dot product 4 columns at a
// time, accumulating into an i32 vector. The problem is it expects one
// vector to be unsigned bytes and one to be signed bytes. To make this
// work, we make one of our vectors unsigned by adding 128 to each element.
// This causes us to overshoot, so we keep track of the amount we need
// to compensate by so we can subtract it from the sum at the end.
//
// Effectively, we are calculating SUM((Z3 + 128) · Z4) - 128 * SUM(Z4).
VPADDB Z3, Z2, Z3 // add 128 to Z3, making it unsigned
VPDPBUSD Z4, Z3, Z0 // Z0 += Z3 dot Z4
VPDPBUSD Z4, Z2, Z1 // Z1 += broadcast(128) dot Z4
ADDQ $64, AX
ADDQ $64, BX
JMP blockloop
reduce:
// Subtract the overshoot from our calculated dot product
VPSUBD Z1, Z0, Z0 // Z0 -= Z1
// Sum Z0 horizontally. There is no horizontal sum instruction, so instead
// we sum the upper and lower halves of Z0, fold it in half again, and
// repeat until we are down to 1 element that contains the final sum.
VEXTRACTI64X4 $1, Z0, Y1
VPADDD Y0, Y1, Y0
VEXTRACTI128 $1, Y0, X1
VPADDD X0, X1, X0
VPSRLDQ $8, X0, X1
VPADDD X0, X1, X0
VPSRLDQ $4, X0, X1
VPADDD X0, X1, X0
// Store the reduced sum
VMOVD X0, R8
end:
MOVL R8, ret+48(FP)
VZEROALL
RET
这种实现又带来了另外 21% 的改进。真不赖!
下一步
好吧,我对吞吐量增加 9.3 倍和内存使用量减少 4 倍感到非常满意,所以我可能会适可而止了。
现实生活中的答案可能是“使用索引”。有大量优秀的工作致力于使最近邻居搜索更快,并且有许多内置向量 DB 使其部署相当简单。
然而,如果你想要一些有趣的思考,我的一位同事在 GPU 实现的点积[10]。
一些有价值的资料
- 如果你还没有使用过 benchstat[11],你应该使用。太棒了。基准测试结果超级简单统计比较。
- 不要错过compiler explorer[12],这是一个非常有用的挖掘生成的汇编代码工具。
- 还有一次,我被技术上的挑战吸引,实现了ARM NEON 的版本[13],这带来了一些有趣的对比。
- 如果您还没有遇到过它,Agner Fog 说明表[14]会让您大吃一惊,很多底层优化的参考资料。在优化点积函数的工作中,我使用它们来理解指令延迟的差异,以及为什么某些流水线优于其他流水线。
From slow to SIMD: A Go optimization story: https://sourcegraph.com/blog/slow-to-simd
[2]嵌入: https://platform.openai.com/docs/guides/embeddings
[3]余弦相似度: https://en.wikipedia.org/wiki/Cosine_similarity
[4]这里: https://github.com/camdencheek/simd_blog
[5]指令流水线: https://chadaustin.me/2009/02/latency-vs-throughput/
[6]Heartbleed 漏洞: https://en.wikipedia.org/wiki/Heartbleed
[7]查看: https://go.godbolt.org/z/qT3M7nPGf
[8]话题: https://huggingface.co/docs/optimum/concept_guides/quantization
[9]检测: https://sourcegraph.com/github.com/sourcegraph/sourcegraph@3ac2170c6523dd074835919a1804f197cf86e451/-/blob/internal/embeddings/dot_amd64.go?L17-21
[10]GPU 实现的点积: https://github.com/sourcegraph/sourcegraph/compare/simd-post-gpu-embeddings~3...simd-post-gpu-embeddings
[11]benchstat: https://pkg.go.dev/golang.org/x/perf/cmd/benchstat
[12]compiler explorer: https://go.godbolt.org/z/qT3M7nPGf
[13]ARM NEON 的版本: https://github.com/camdencheek/simd_blog/blob/main/dot_arm64.s
[14]Agner Fog 说明表: https://www.agner.org/optimize/
推荐阅读:
想要了解Go更多内容,欢迎扫描下方👇关注公众号, 回复关键词 [实战群] ,就有机会进群和我们进行交流
分享、在看与点赞Go