Pytorch转ONNX-实战篇1(tracing机制)

共 3460字,需浏览 7分钟

 ·

2021-01-17 10:57

作者丨立交桥跳水冠军
来源丨https://zhuanlan.zhihu.com/p/273566106
编辑丨GiantPandaCV

昨天的文章简单描述了在Pytorch转ONNX中面临的问题和需要注意的事情,今天的文章会重点结合OpenMMlab系列中用到的Pytorch转ONNX的小技巧来介绍实战部分。

(1)tracing的机制

上文提到过,Pytorch转ONNX的方式是基于tracing(追踪),通俗来说,就是ONNX的相关代码在一旁看着Pytorch跑一遍,运行了什么内容就把什么记录下来。但是在这里并不是所有Python的运行内容都会被记录。举个例子,下面的代码中,

c = torch.matmul(a, b)
print("Blabla")
e = torch.matmul(c, d)

其中只有第1,3行相关的内容会被记录,因为只有他们是和Pytorch相关的,而第二行只是普通的python语句。

具体来说,只有ATen操作会被记录下来。ATen可以被理解为一个Pytorch的基本操作库,一切的Pytorch函数都是基于这些零部件构造出来的(比如ATen就是加减乘除,所有Pytorch的其他操作,比如平方,算sigmoid,都可以根据加减乘除构造出来)

*之前说的ONNX无法记录if语句的问题也是因为if并不是Aten中的操作

虽然ONNX可以记录所有Pytorch的执行(即记录所有ATen操作),但是在输出的时候会做一个剪枝,把没用的操作剪掉

举个例子,下面的程序,显而易见第一句话是没有用的。


t1 = torch.matmul(a, b)
t2 = torch.matmul(c, d)
return t2

ONNX会在得到全部的操作以及他们之间的输入输出关系后(以DAG作为表示),根据DAG的输出往前推,做遍历,所有可以被遍历到的节点被保留,其他节点直接扔掉。

在MMDetection(https://github.com/open-mmlab/mmdetection)中,在NMS(non-Maximumnon maximum suppression)中有如下代码:

if bboxes.numel() == 0:
    bboxes = multibboxes.newzeros((05))
    labels = multibboxes.newzeros((0, ), dtype=torch.long)

    if torch.onnx.isinonnxexport():
        raise RuntimeError('[ONNX Error] Can not record NMS '
                           'as it has not been executed this time')
    return bboxes, labels

dets, keep = batchednms(bboxes, scores, labels, nmscfg)

代码逻辑很简单,如果之前的网络根本没有输出任何合法的bbox(第一行的分支判断),那么显然nms的结果就是一堆0,所以没必要运行nms直接返回0就可以。

如果我们想将这段代码转换到ONNX,之前我们提到过ONNX不能处理分支逻辑,因此只能选择一条路去走,记录那条路转换得到的模型。很显然,正常情况下我们自然期待会有较多的bbox,并且将这些bbox作为参数调用nms。

所以如果我们发现模型执行的路径触发了if分支,我们必须要进行一个判断,看看是不是在转ONNX,如果是的话我们就需要直接报错,因为显然转出来的ONNX不是我们想要的。

假设什么都不做,在这种情况下我们转出来的模型是什么样呢?思考一下不难发现,假设函数的返回值就是网络的最终输出,那么我们只会得到一个2个节点的DAG,即第2,3行的两个操作。之前说过ONNX拿到所有的DAG之后会做剪枝,在这里ONNX拿到返回值(bboxes, labels)做回溯,发现最头上就是第2,3行的两个操作,就直接停掉了。所有其他的操作,比如backbone,rpn,fpn,都会被扔掉。

因此,在进行MMDet模型的转换的时候,必须用真实的数据和训练好的参数来做转换,否则基本不会得到有效的bbox,于是就会触发第6行的error

(2)利用tracing机制做优化

在MMSeg中有一个很巧妙的利用tracing机制做优化的例子。

在slide inference时,我们需要计算一个count mat矩阵,这个矩阵在h, w以及对应的stride都固定的情况下会是一个常量。

不过在训练时,往往这些都是我们要调的参数,所有MMSeg没有选择把这些常数保存下来,而是每次都算一遍

        countmat = img.newzeros((batchsize, 1, himg, wimg))
        for hidx in range(hgrids):
            for widx in range(wgrids):
                y1 = hidx * hstride
                x1 = widx * wstride
                y2 = min(y1 + hcrop, himg)
                x2 = min(x1 + wcrop, wimg)
                y1 = max(y2 - hcrop, 0)
                x1 = max(x2 - wcrop, 0)
                cropimg = img[:, :, y1:y2, x1:x2]
                cropseglogit = self.encodedecode(cropimg, imgmeta)
                preds += F.pad(cropseglogit,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))

                countmat[:, :, y1:y2, x1:x2] += 1
        assert (countmat == 0).sum() == 0
        if torch.onnx.isinonnxexport():
            # cast countmat to constant while exporting to ONNX
            countmat = torch.fromnumpy(
                countmat.cpu().detach().numpy()).to(device=img.device)

不过在部署时,这些参数往往是固定的,因此我们没必要把它算一遍。因此在倒数第4行的if分支里,我们做了一件看似很没用的事

countmat = torch.fromnumpy(countmat.cpu().detach().numpy()).to(device=img.device)

即我们把算出来的countmat从tensor转换成numpy,再转回tensor。

其实我们的目的是切断tracing。

之前提到过,ONNX只能记录ATen相关的操作,但是很显然,tensor和numpy的互转肯定不是ATen操作。因此在回溯的时候,当访问到count mat,ONNX并不能发现它是被谁运算出来的,所以countmat就会被看作一个常数被保存下来,之前计算countmat的部分都会被扔掉


- The End -


GiantPandaCV

长按二维码关注我们

本公众号专注:

1. 技术分享;

2. 学术交流

3. 资料共享

欢迎关注我们,一起成长!

浏览 27
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报