Pytorch转ONNX-实战篇2(实战踩坑总结)
各位读者好,这里是BBuf,由于近期公众号遭受恶意举报将会在很长一段时间内失去原创功能并且面临封号,所以我们作者团队商量着重新申请了一个新的公众号,名字是 『PandaCV』 ,长按下方二维码关注和转发,谢谢!这个公众号近期会通过无来源转载的方式将GiantPandaCV公众号的所有高质量文章(也会舍弃一部分不好的)逐渐搬运过去(会耗时半个月到1个月),然后在PandaCV这个公众号上继续发表高质量原创文章。维护和谐健康的知识原创环境是每个人的责任,希望我们能一起努力,打造一个学习和分享的双赢平台。
作者丨立交桥跳水冠军
前两篇文章分别从理论和ONNX的核心机制描述了Pytorch转ONNX需要注意的事情。接下来这篇文章没有什么核心主旨,只是纯粹记录我当时做项目的时候踩的坑以及应对方案
(1)Pytorch2ONNX不支持对slice对象赋值
下面这段代码是不被Pytorch原生的onnx转换接口支持的,即不能对slice对象赋值
preds[:, :, y1:y2, x1:x2] += crop_seg_logit
仔细想想其实也比较合理,因为上面的操作也很难在DAG上被表示,因为并不仅仅是把preds中的那个区域取出来弄个新的变量,然后在上面+1,而是直接把preds的一部分改掉了。当时我负责MMSeg的slide inference转换的时候遇到了这个问题,解决方案如下:
preds += F.pad(crop_seg_logit,
(int(x1), int(preds.shape[3] - x2), int(y1),
int(preds.shape[2] - y2)))
即我对crop_seg_logit做了一个padding,把它变成了和preds一样的大小,这样我就直接变成了矩阵相加,没必要变成slice的操作了
这个方法自然很丑,而且会引出一个新的问题,那就是Pytorch生成的onnx padding的格式,onnx runtime接收的格式以及TensorRT需要的格式都不一样。这个就是之后的问题了(超纲了,不讲了)
这里具体的例子我懒得查了,以二维矩阵的填充为例。只记得一个转出来的是(begin0, begin1, end0, end1),另一个是(begin0, end0, begin1, end1)
这里面begin0代表第0维左边的填充数量,end0代表右边的填充数量
(2)resize
当时做segmentation模型的时候,最重要的就是resize操作。ONNX里面的resize要求output shape必须为常量(即tuple of int),因此不可以用tensor.Size作为输入,因为人家并不是tuple of int
if isinstance(size, torch.Size):
size = tuple(int(x) for x in size)
所以我们必须手动粗暴的把torch.Size变成tuple of int
当时有reviewer吐槽我这个方法丑,要我改成tuple(size),说Pytorch重载了tuple,直接可以把torch.Size变成tuple of int。但是很诡异的是在正常情况下的确可以,但如果一旦进入了ONNX tracining模式,这个方法就失效了。我简单看了看,推测是因为对tuple的重载是在C++层面做的,而ONNX tracing也会涉及到一些C++层面的事情,也就是说ONNX tracing会重载一些C++的部分,可能正好就把tuple给抹掉了
(3) 应对kwargs的约束
pytorch自带的onnx转换api: torch.onnx.export,只支持args参数。一般来说调用这个api只需要提供model(喜闻乐见的nn.Module),调用model的参数args(也就是调用model.forwrd()的参数)以及导出的文件名f。然后这个函数就会内部执行一遍: model(*args),执行的时候做tracining
但是我们知道一般来说除了args,还需要kwargs,比如model(input, getloss=False),其中input就是args,False就是kwargs。OpenMMLab里面几乎所有的model都需要kwargs
为了绕开这个约束,我们需要利用python的partial函数,将model做个封装:
model.forward = partial(model.forward, return_loss=False)
这样我们可以给model提供需要的kwargs,同时又可以原封不动的调用torch.onnx.export
注意,kwargs不能包括网络的输入,比如如果你想把input image放进args,那么得到的onnx就会是一个没有输入的图(它会把kwargs里面的input image当成一个常量)
(4)Pytorch和ONNX Runtime结果对齐
OpenMMLab系列提供了一个很有用的功能,就是自动比对Pytorch和ONNXRuntime的精度。这个功能可以帮助用户确定转出来的ONNX有没有问题。
然而之前也提到过,ONNXRuntime和Pytorch需要的ONNX格式不一样,而且有些计算也不一样,因此就算结果对不上,也不能代表什么
在某些操作上,ONNXRuntime和Pytorch的行为不一致。比如对一个一维tensor:[0,0,0]调用argmax,那么ONNXRuntime返回的是0,而Pytorch是1(举个例子,具体的差异我记不清了)
当时我在做Detection模型的自动比对的时候就遇到了问题,在经历了nms操作之后,bbox会根据score的大小做排序,但score相同的情况下,ONNXRuntime和Pytorch的结果就会有差异。因此我们最后只选择比对score,而不管bbox的dx,dy这些信息了
- The End -
长按二维码关注我们
本公众号专注:
1. 技术分享;
2. 学术交流;
3. 资料共享。
欢迎关注我们,一起成长!