TorchScript 快速入门之初识一场

共 6010字,需浏览 13分钟

 ·

2022-04-18 13:43




今天,我们又将开启新的 TorchScript 解读系列教程,带领大家玩转 PyTorch 模型部署。感兴趣的小伙伴一起往下看吧~


什么是 TorchScript


PyTorch 无疑是现在最成功的深度学习训练框架之一,是各种顶会顶刊论文实验的大热门。比起其他的框架,PyTorch 最大的卖点是它对动态网络的支持,比其他需要构建静态网络的框架拥有更低的学习成本。PyTorch 源码 Readme 中还专门为此做了一张动态图:



对研究员而言,PyTorch 能极大地提高想 idea、做实验、发论文的效率,是训练框架中的豪杰,但是它不适合部署。动态建图带来的优势对于性能要求更高的应用场景而言更像是缺点,非固定的网络结构给网络结构分析并进行优化带来了困难,多数参数都能以 Tensor 形式传输也让资源分配变成一件闹心的事。另外由于图是由 python 代码来构建的,一方面部署要依赖 python 环境,另一方面模型也毫无保密性可言。


而 TorchScript 就是为了解决这个问题而诞生的工具。包括代码的追踪及解析、中间表示的生成、模型优化、序列化等各种功能,可以说是覆盖了模型部署的方方面面。今天我们先简要地介绍一些 TorchScript 的功能,让大家有一个初步的认识,进阶的解读会陆续推出~


模型转换



作为模型部署的一个范式,通常我们都需要生成一个模型的中间表示(IR),这个 IR 拥有相对固定的图结构,所以更容易优化,让我们看一个例子:


import torchfrom torchvision.models import resnet18
# 使用PyTorch model zoo中的resnet18作为例子model = resnet18()model.eval()
# 通过trace的方法生成IR需要一个输入样例dummy_input = torch.rand(1, 3, 224, 224)
# IR生成with torch.no_grad():    jit_model = torch.jit.trace(model, dummy_input)


到这里就将 PyTorch 的模型转换成了 TorchScript 的 IR。这里我们使用了 trace 模式来生成 IR,所谓 trace 指的是进行一次模型推理,在推理的过程中记录所有经过的计算,将这些记录整合成计算图。关于 trace 的过程我们会在未来的分享中进行解读。


那么这个 IR 中到底都有些什么呢?我们可以可视化一下其中的 layer1 看看:


jit_layer1 = jit_model.layer1print(jit_layer1.graph)
# graph(%self.6 : __torch__.torch.nn.modules.container.Sequential,#       %4 : Float(1, 64, 56, 56, strides=[200704, 3136, 56, 1], requires_grad=0, device=cpu)):#   %1 : __torch__.torchvision.models.resnet.___torch_mangle_10.BasicBlock = prim::GetAttr[name="1"](%self.6)#   %2 : __torch__.torchvision.models.resnet.BasicBlock = prim::GetAttr[name="0"](%self.6)#   %6 : Tensor = prim::CallMethod[name="forward"](%2, %4)#   %7 : Tensor = prim::CallMethod[name="forward"](%1, %6)#   return (%7)


是不是有点摸不着头脑?TorchScript 有它自己对于 Graph 以及其中元素的定义,对于第一次接触的人来说可能比较陌生,但是没关系,我们还有另一种可视化方式:


print(jit_layer1.code)
# def forward(self,#     argument_1: Tensor) -> Tensor:#   _0 = getattr(self, "1")#   _1 = (getattr(self, "0")).forward(argument_1, )#   return (_0).forward(_1, )


没错,就是代码!TorchScript 的 IR 是可以还原成 python 代码的,如果你生成了一个 TorchScript 模型并且想知道它的内容对不对,那么可以通过这样的方式来做一些简单的检查。


刚才的例子中我们使用 trace 的方法生成 IR。除了 trace 之外,PyTorch 还提供了另一种生成 TorchScript 模型的方法:script。这种方式会直接解析网络定义的 python 代码,生成抽象语法树 AST,因此这种方法可以解决一些 trace 无法解决的问题,比如对 branch/loop 等数据流控制语句的建图。script 方式的建图有很多有趣的特性,会在未来的分享中做专题分析,敬请期待。


模型优化



聪明的同学可能发现了,上面的可视化中只有 resnet18 里 forward 的部分,其中的子模块信息是不是丢失了呢?如果没有丢失,那么怎么样才能确定子模块的内容是否正确呢?别担心,还记得我们说过 TorchScript 支持对网络的优化吗,这里我们就可以用一个 pass 解决这个问题:


# 调用inline pass,对graph做变换torch._C._jit_pass_inline(jit_layer1.graph)print(jit_layer1.code)
# def forward(self,#     argument_1: Tensor) -> Tensor:#   _0 = getattr(self, "1")#   _1 = getattr(self, "0")#   _2 = _1.bn2#   _3 = _1.conv2#   _4 = _1.bn1#   input = torch._convolution(argument_1, _1.conv1.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)#   _5 = _4.running_var#   _6 = _4.running_mean#   _7 = _4.bias#   input0 = torch.batch_norm(input, _4.weight, _7, _6, _5, False, 0.10000000000000001, 1.0000000000000001e-05, True)#   input1 = torch.relu_(input0)#   input2 = torch._convolution(input1, _3.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)#   _8 = _2.running_var#   _9 = _2.running_mean#   _10 = _2.bias#   out = torch.batch_norm(input2, _2.weight, _10, _9, _8, False, 0.10000000000000001, 1.0000000000000001e-05, True)#   input3 = torch.add_(out, argument_1, alpha=1)#   input4 = torch.relu_(input3)#   _11 = _0.bn2#   _12 = _0.conv2#   _13 = _0.bn1#   input5 = torch._convolution(input4, _0.conv1.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)#   _14 = _13.running_var#   _15 = _13.running_mean#   _16 = _13.bias#   input6 = torch.batch_norm(input5, _13.weight, _16, _15, _14, False, 0.10000000000000001, 1.0000000000000001e-05, True)#   input7 = torch.relu_(input6)#   input8 = torch._convolution(input7, _12.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)#   _17 = _11.running_var#   _18 = _11.running_mean#   _19 = _11.bias#   out0 = torch.batch_norm(input8, _11.weight, _19, _18, _17, False, 0.10000000000000001, 1.0000000000000001e-05, True)#   input9 = torch.add_(out0, input4, alpha=1)#   return torch.relu_(input9)


这里我们就能看到卷积、batch_norm、relu 等熟悉的算子了。


上面代码中我们使用了一个名为 inline 的 pass,将所有子模块进行内联,这样我们就能看见更完整的推理代码。pass 是一个来源于编译原理的概念,一个 TorchScript 的 pass 会接收一个图,遍历图中所有元素进行某种变换,生成一个新的图。我们这里用到的 inline 起到的作用就是将模块调用展开,尽管这样做并不能直接影响执行效率,但是它其实是很多其他 pass 的基础。PyTorch 中定义了非常多的 pass 来解决各种优化任务,未来我们会做一些更详细的介绍。


序列化



不管是哪种方法创建的 TorchScript 都可以进行序列化,比如:


# 将模型序列化jit_model.save('jit_model.pth')# 加载序列化后的模型jit_model = torch.jit.load('jit_model.pth')


序列化后的模型不再与 python 相关,可以被部署到各种平台上。


PyTorch 提供了可以用于 TorchScript 模型推理的 c++ API,序列化后的模型终于可以不依赖 python 进行推理了:


// 加载生成的torchscript模型auto module = torch::jit::load('jit_model.pth');// 根据任务需求读取数据std::vector inputs = ...;// 计算推理结果auto output = module.forward(inputs).toTensor();


与其他组件的关系


与 torch.onnx 的关系




ONNX 是业界广泛使用的一种神经网络中间表示,PyTorch 自然也对 ONNX 提供了支持。torch.onnx.export 函数可以帮助我们把 PyTorch 模型转换成 ONNX 模型,这个函数会使用 trace 的方式记录 PyTorch 的推理过程。聪明的同学可能已经想到了,没错,ONNX 的导出,使用的正是 TorchScript 的 trace 工具。具体步骤如下:


1. 使用 trace 的方式先生成一个 TorchScipt 模型,如果你转换的本身就是 TorchScript 模型,则可以跳过这一步。

2. 使用许多 pass 对 1 中生成的模型进行变换,其中对 ONNX 导出最重要的一个 pass 就是ToONNX,这个 pass 会进行一个映射,将 TorchScript 中 prim、aten 空间下的算子映射到onnx空间下的算子。

3. 使用 ONNX 的 proto 格式对模型进行序列化,完成 ONNX 的导出。


关于 ONNX 导出的实现以及算子映射的方式将会在未来的分享中详细展开。


与 torch.fx 的关系



PyTorch1.9 开始添加了 torch.fx 工具,根据官方的介绍,它由符号追踪器 (symbolic tracer),中间表示(IR), Python 代码生成 (Python code generation) 等组件组成,实现了 python->python 的翻译。是不是和 TorchScript 看起来有点像?


其实他们之间联系不大,可以算是互相垂直的两个工具,为解决两个不同的任务而诞生。


TorchScript 的主要用途是进行模型部署,需要记录生成一个便于推理优化的 IR,对计算图的编辑通常都是面向性能提升等等,不会给模型本身添加新的功能。


FX 的主要用途是进行 python->python 的翻译,它的 IR 中节点类型更简单,比如函数调用、属性提取等等,这样的 IR 学习成本更低更容易编辑。使用 FX 来编辑图通常是为了实现某种特定功能,比如给模型插入量化节点等,避免手动编辑网络造成的重复劳动。


这两个工具可以同时使用,比如使用 FX 工具编辑模型来让训练更便利、功能更强大;然后用 TorchScript 将模型加速部署到特定平台。



推荐阅读

深入理解生成模型VAE

DropBlock的原理和实现

SOTA模型Swin Transformer是如何炼成的!

有码有颜!你要的生成模型VQ-VAE来了!

集成YYDS!让你的模型更快更准!

辅助模块加速收敛,精度大幅提升!移动端实时的NanoDet-Plus来了!

SimMIM:一种更简单的MIM方法

SSD的torchvision版本实现详解


机器学习算法工程师


                                    一个用心的公众号


浏览 64
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报