Pytorch转ONNX-实战篇1(tracing机制)
昨天的文章简单描述了在Pytorch转ONNX中面临的问题和需要注意的事情,今天的文章会重点结合OpenMMlab系列中用到的Pytorch转ONNX的小技巧来介绍实战部分。
(1)tracing的机制
上文提到过,Pytorch转ONNX的方式是基于tracing(追踪),通俗来说,就是ONNX的相关代码在一旁看着Pytorch跑一遍,运行了什么内容就把什么记录下来。但是在这里并不是所有Python的运行内容都会被记录。举个例子,下面的代码中,
c = torch.matmul(a, b)
print("Blabla")
e = torch.matmul(c, d)
其中只有第1,3行相关的内容会被记录,因为只有他们是和Pytorch相关的,而第二行只是普通的python语句。
具体来说,只有ATen操作会被记录下来。ATen可以被理解为一个Pytorch的基本操作库,一切的Pytorch函数都是基于这些零部件构造出来的(比如ATen就是加减乘除,所有Pytorch的其他操作,比如平方,算sigmoid,都可以根据加减乘除构造出来)
*之前说的ONNX无法记录if语句的问题也是因为if并不是Aten中的操作
虽然ONNX可以记录所有Pytorch的执行(即记录所有ATen操作),但是在输出的时候会做一个剪枝,把没用的操作剪掉
举个例子,下面的程序,显而易见第一句话是没有用的。
t1 = torch.matmul(a, b)
t2 = torch.matmul(c, d)
return t2
ONNX会在得到全部的操作以及他们之间的输入输出关系后(以DAG作为表示),根据DAG的输出往前推,做遍历,所有可以被遍历到的节点被保留,其他节点直接扔掉。
在MMDetection(https://github.com/open-mmlab/mmdetection)中,在NMS(non-Maximumnon maximum suppression)中有如下代码:
if bboxes.numel() == 0:
bboxes = multibboxes.newzeros((0, 5))
labels = multibboxes.newzeros((0, ), dtype=torch.long)
if torch.onnx.isinonnxexport():
raise RuntimeError('[ONNX Error] Can not record NMS '
'as it has not been executed this time')
return bboxes, labels
dets, keep = batchednms(bboxes, scores, labels, nmscfg)
代码逻辑很简单,如果之前的网络根本没有输出任何合法的bbox(第一行的分支判断),那么显然nms的结果就是一堆0,所以没必要运行nms直接返回0就可以。
如果我们想将这段代码转换到ONNX,之前我们提到过ONNX不能处理分支逻辑,因此只能选择一条路去走,记录那条路转换得到的模型。很显然,正常情况下我们自然期待会有较多的bbox,并且将这些bbox作为参数调用nms。
所以如果我们发现模型执行的路径触发了if分支,我们必须要进行一个判断,看看是不是在转ONNX,如果是的话我们就需要直接报错,因为显然转出来的ONNX不是我们想要的。
假设什么都不做,在这种情况下我们转出来的模型是什么样呢?思考一下不难发现,假设函数的返回值就是网络的最终输出,那么我们只会得到一个2个节点的DAG,即第2,3行的两个操作。之前说过ONNX拿到所有的DAG之后会做剪枝,在这里ONNX拿到返回值(bboxes, labels)做回溯,发现最头上就是第2,3行的两个操作,就直接停掉了。所有其他的操作,比如backbone,rpn,fpn,都会被扔掉。
因此,在进行MMDet模型的转换的时候,必须用真实的数据和训练好的参数来做转换,否则基本不会得到有效的bbox,于是就会触发第6行的error
(2)利用tracing机制做优化
在MMSeg中有一个很巧妙的利用tracing机制做优化的例子。
在slide inference时,我们需要计算一个count mat矩阵,这个矩阵在h, w以及对应的stride都固定的情况下会是一个常量。
不过在训练时,往往这些都是我们要调的参数,所有MMSeg没有选择把这些常数保存下来,而是每次都算一遍
countmat = img.newzeros((batchsize, 1, himg, wimg))
for hidx in range(hgrids):
for widx in range(wgrids):
y1 = hidx * hstride
x1 = widx * wstride
y2 = min(y1 + hcrop, himg)
x2 = min(x1 + wcrop, wimg)
y1 = max(y2 - hcrop, 0)
x1 = max(x2 - wcrop, 0)
cropimg = img[:, :, y1:y2, x1:x2]
cropseglogit = self.encodedecode(cropimg, imgmeta)
preds += F.pad(cropseglogit,
(int(x1), int(preds.shape[3] - x2), int(y1),
int(preds.shape[2] - y2)))
countmat[:, :, y1:y2, x1:x2] += 1
assert (countmat == 0).sum() == 0
if torch.onnx.isinonnxexport():
# cast countmat to constant while exporting to ONNX
countmat = torch.fromnumpy(
countmat.cpu().detach().numpy()).to(device=img.device)
不过在部署时,这些参数往往是固定的,因此我们没必要把它算一遍。因此在倒数第4行的if分支里,我们做了一件看似很没用的事
countmat = torch.fromnumpy(countmat.cpu().detach().numpy()).to(device=img.device)
即我们把算出来的countmat从tensor转换成numpy,再转回tensor。
其实我们的目的是切断tracing。
之前提到过,ONNX只能记录ATen相关的操作,但是很显然,tensor和numpy的互转肯定不是ATen操作。因此在回溯的时候,当访问到count mat,ONNX并不能发现它是被谁运算出来的,所以countmat就会被看作一个常数被保存下来,之前计算countmat的部分都会被扔掉
- The End -
长按二维码关注我们
本公众号专注:
1. 技术分享;
2. 学术交流;
3. 资料共享。
欢迎关注我们,一起成长!