一种更加高效的卷积计算策略:Im2Col+GEMM的改进方法MEC
极市导读
本文介绍了一种Im2Col+GEMM的改进版本——MEC,并记录了对它进行复现尝试后的测试结果。>>加入极市CV技术交流群,走在计算机视觉的最前沿
1. 前言
2. 背景介绍
所以,MEC改进了Im2Col+GEMM的策略,目的是减少它的内存消耗同时提升一点速度。
3. MEC算法原理
下面的Algorithm1展示了这个算法的流程:
3.2 MEC算法高级版
然后下面的Figure3是它的示例图:
从伪代码里可以看到这里有2种计算方法:
Solution 1:Algorithm2中的第9-19行和Algorithm1中的方法完全一致,然后14-19行是对临时结果对做排列变化,即Figure3中的上半部分。 Solution 2:Algorithm2中的第21-25行。每次循环处理一个样本,不需要做额外的排列变化,即Figure3中的下半部分。
4. 实验对比
5. 复现尝试(暂时只针对X86 CPU)
// 原始的Im2Col
void im2col_cpu(float** src, const int &inHeight, const int &intWidth, const int &kHeight,
const int &kWidth, float* srcIm2col){
const int outHeight = inHeight - kHeight + 1;
const int outWidth = intWidth - kWidth + 1;
int cnt = 0;
for(int i = 0; i < kHeight; i++){
for(int j = 0; j < kWidth; j++){
int id = i * kWidth + j;
int ii = i;
for(int x = 0; x < outHeight; x++){
int jj = j;
for(int y = 0; y < outWidth; y++){
srcIm2col[cnt] = src[ii][jj];
jj += 1;
cnt++;
}
ii += 1;
}
}
}
}
cblas_sgemm
接口,关于OpenBlas的介绍以及计算方式,函数接口可以查看参考中的资料2,这里就不过多介绍了。// 构造输入矩阵
float **src = new float*[inHeight];
for(int i = 0; i < inHeight; i++){
src[i] = new float[inWidth];
for(int j = 0; j < inWidth; j++){
src[i][j] = 0.1;
}
}
// 构造kernel矩阵
float **kernel[kernel_num];
for(int i = 0; i < kernel_num; i++){
kernel[i] = new float*[kernel_h];
for(int j = 0; j < kernel_h; j++){
kernel[i][j] = new float[kernel_w];
for(int k = 0; k < kernel_w; k++){
kernel[i][j][k] = 0.2;
}
}
}
// 开始计时
struct timeval tstart, tend;
gettimeofday(&tstart, NULL);
// 对kernel进行Im2col
float* kernel2col = new float[kernel_num*kernel_h*kernel_w];
int cnt = 0;
for(int i = 0; i < kernel_num; i++){
for(int j = 0; j < kernel_h; j++){
for(int k = 0; k < kernel_w; k++){
kernel2col[cnt++] = kernel[i][j][k];
}
}
}
// 对输入矩阵Im2col
int outHeight = inHeight - kernel_h + 1;
int outWidth = inWidth - kernel_w + 1;
float *srcIm2col = new float[kernel_w * kernel_h * outWidth * outHeight];
im2col_cpu(src, inHeight, inWidth, kernel_h, kernel_w, srcIm2col);
cblas_sgemm
函数接口即可完成卷积层的计算,这个地方加入了计时函数,统计Im2Col+gemm的运行时间:// 使用Blas库实现矩阵乘法
float *output = new float[kernel_num * outHeight * outWidth];
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,kernel_num,
outHeight*outWidth, kernel_w*kernel_h, 1,
kernel2col, kernel_h*kernel_w,
srcIm2col,outHeight * outWidth, 0, output, outHeight * outWidth);
// 结束计时
gettimeofday(&tend, NULL);
cout<<"im2colOrigin Total time cost: "<<(tend.tv_sec-tstart.tv_sec)*1000 + (tend.tv_usec-tstart.tv_usec)/1000<<" ms"<
// MEC
void im2col_mec(float** src, const int &inHeight, const int &intWidth, const int &kHeight,
const int &kWidth, float* srcIm2col){
const int outHeight = inHeight - kHeight + 1;
const int outWidth = intWidth - kWidth + 1;
#pragma omp parallel for num_threads(THREAD_NUM)
for(int i = 0; i < outWidth; i++){
int outrow = 0;
for(int j = 0; j < inHeight; j++){
for(int k = i; k < i + kWidth; k++){
srcIm2col[outrow * outWidth + i] = src[j][k];
outrow++;
}
}
}
}
// 构造输入矩阵
float **src = new float*[inHeight];
for(int i = 0; i < inHeight; i++){
src[i] = new float[inWidth];
for(int j = 0; j < inWidth; j++){
src[i][j] = 0.1;
}
}
// 构造kernel矩阵
float **kernel[kernel_num];
for(int i = 0; i < kernel_num; i++){
kernel[i] = new float*[kernel_h];
for(int j = 0; j < kernel_h; j++){
kernel[i][j] = new float[kernel_w];
for(int k = 0; k < kernel_w; k++){
kernel[i][j][k] = 0.2;
}
}
}
// 开始计时
struct timeval tstart, tend;
gettimeofday(&tstart, NULL);
// 对kernel进行Im2col
float* kernel2col = new float[kernel_num*kernel_h*kernel_w];
int cnt = 0;
for(int i = 0; i < kernel_num; i++){
for(int j = 0; j < kernel_h; j++){
for(int k = 0; k < kernel_w; k++){
kernel2col[cnt++] = kernel[i][j][k];
}
}
}
// 对输入矩阵Im2col
int outHeight = inHeight - kernel_h + 1;
int outWidth = inWidth - kernel_w + 1;
float *srcIm2col = new float[outWidth * inHeight * kernel_w];
im2col_mec(src, inHeight, inWidth, kernel_h, kernel_w, srcIm2col);
// 使用Blas库实现矩阵乘法
float **output = new float*[outHeight];
#pragma omp parallel for num_threads(THREAD_NUM)
for(int i = 0; i < outHeight; i++){
output[i] = new float [kernel_num * outWidth];
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,kernel_num,
outWidth, kernel_w * kernel_h,1,
kernel2col, kernel_h * kernel_w,
srcIm2col + i * outWidth, outWidth, 0, output[i], outWidth);
}
// 结束计时
gettimeofday(&tend, NULL);
cout<<"MEC Total time cost: "<<(tend.tv_sec-tstart.tv_sec)*1000 + (tend.tv_usec-tstart.tv_usec)/1000<<" ms"<
https://github.com/BBuf/Memory-efficient-Convolution-for-Deep-Neural-Network
6. 效果
参考资料
推荐阅读
ACCV 2020国际细粒度网络图像识别竞赛正式开赛!
评论