PyTorch 源码解读之即时编译篇
共 84968字,需浏览 170分钟
·
2021-06-13 21:57
前言
torch 从 1.0 开始支持了 jit 模块,其大概包括以下几个部分:
一种新的计算图中间表示 (Intermediate Representation),之后简称为 IR. 从 Python 代码导出IR的两种方法,即 trace 与 script. IR 优化以及 IR 的解释器(翻译为具体的运算 op).
这篇解读会分为以下几个部分:
jit 的简单介绍以及两种导出方式的使用例子 jit 中 IR 的形式 导出 IR 的两种方式,trace 与 script 的源码解读 IR 优化的简单介绍
1 jit 的简单介绍以及使用例子
JIT 简介
如前言,这篇解读虽然标题是 JIT,但是真正称得上即时编译器的部分是在导出 IR 后,即优化 IR 计算图,并且解释为对应 operation 的过程,即PyTorch jit 相关 code 带来的优化一般是计算图级别优化,比如部分运算的融合,但是对具体算子(如卷积)是没有特定优化的,其依旧调用 torch的基础算子库.
大家也可以在导出 IR 也就是 torchscript 后,使用其他的编译优化或者解释器,如现在也有script to a TensorRT engine,TRTtorch转 tensorRT 的方案。
trace
给大家一个简单例子。
import torchvision.models as models
resnet = torch.jit.trace(models.resnet18(),torch.rand(1,3,224,224))
output=resnet(torch.ones(1,3,224,224))
print(output)
output=resnet(torch.ones(1,3,224,224))
resnet.save('resnet.pt')
output 便是我们导出的中间表示,其可以 save 下来,在其他框架使用
我们可以看下 output 中的 IR,即 torchscript 表征的计算图是什么样子的。
graph(%self.1 : __torch__.torchvision.models.resnet.___torch_mangle_194.ResNet,
%input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)):
%1472 : __torch__.torch.nn.modules.linear.___torch_mangle_193.Linear = prim::GetAttr[name="fc"](%self.1)
%1469 : __torch__.torch.nn.modules.pooling.___torch_mangle_192.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
%1468 : __torch__.torch.nn.modulesjieshao.container.___torch_mangle_191.Sequential = prim::GetAttr[name="layer4"](%self.1)
%1422 : __torch__.torch.nn.modules.container.___torch_mangle_175.Sequential = prim::GetAttr[name="layer3"](%self.1)
....
%1556 : Tensor = prim::CallMethod[name="forward"](%1469, %1555)
%1202 : int = prim::Constant[value=1]()
%1203 : int = prim::Constant[value=-1]()
%input : Float(1:512, 512:1, requires_grad=1, device=cpu) = aten::flatten(%1556, %1202, %1203)
%1557 : Tensor = prim::CallMethod[name="forward"](%1472, %input)
return (%1557)
这便是 trace 方法的使用,其核心实现的入口便是torch.jit.trace
,参数为你需要导出的 model,以及合法输入input,其大概原理恰如其名,便是跟踪模型 inference 过程,将模型对输入进行的操作逐一记录下来,并对应到 IR 的操作,从而得到原本模型forward 的 IR。
ote :但是这种实现方式有很明显的缺陷,PyTorch 作为动态图网络,会有很多的 input dependent的控制流语句,根据输入的不同可能会执行情况会不同(if 或者 变长的 loop),这样就无法 trace 到完整的计算图。如下就是一个 trace
失败的 case:
if x > 2.0:
r = torch.tensor(1.0)
else:
r = torch.tensor(2.0)
return r
ftrace = torch.jit.trace(test, (torch.ones(1)))
y = torch.ones(1) * 5
print(ftrace(y))
# results: tensor(2.)
# 因为输入只走了的分支else
script
@torch.jit.script
def foo(x, y):
if x.max() > y.max():
r = x
else:
r = y
return r
print(foo.graph)
print(foo(torch.Tensor([0]), torch.Tensor([1])))
print(foo(torch.Tensor([1]), torch.Tensor([0])))
graph(%x.1 : Tensor,
%y.1 : Tensor):
%3 : Tensor = aten::max(%x.1)
%5 : Tensor = aten::max(%y.1)
# 可以看到确实捕捉到了控制语句,
%6 : Tensor = aten::gt(%3, %5)
%7 : bool = aten::Bool(%6)
%r : Tensor = prim::If(%7)
block0():
-> (%x.1)
block1():
-> (%y.1)
return (%r)
tensor([1.])
tensor([1.])
script 使用是在你需要的地方 (fuction or nn.Module (默认追踪 forward函数))挂载装饰器torch.jit.script
,其转换方式跟 trace 是完全不同的思路,script 直接解析你的 PyTorch代码,通过语法分析解析你的逻辑为一棵语法树,然后转换为中间表示 IR。
Note: 虽然其可以解决 trace 存在无法追踪动态逻辑的问题,但是 Python 作为灵活度极高的语法, 想完整支持解析各种 Python 操作几乎是不可能的,因此我们需要额外的时间熟悉哪些写法是可以被解析的,让我们写代码的体验大打折扣。
两者结合
两者各有优势,支持灵活集合。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
# torch.jit.trace produces a ScriptModule's conv1 and conv2
self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
def forward(self, input):
input = F.relu(self.conv1(input))
input = F.relu(self.conv2(input))
return input
scripted_module = torch.jit.script(MyModule())
因此实际使用时候,可以有如下准则:
1 大部分情况 model 只有 tensor operation,就直接无脑 tracing
2 带 control-flow (if-else, for-loop) 的,上 scripting
3 碰上 scripting 不能 handle 的语法,要么重写,要么把 tracing 和 scripting 合起来用(比如说只在有 control-
flow 的代码用 scripting,其他用 tracing)
如何扩展
trace 与 script 都不能转换第三方 Python 库中的函数,尽量所有代码都使用 PyTorch 实现, 自定义 op 需要注册成 jit
操作( torch 的 op 其实也注册了),最后转成 torchscript。
TORCH_LIBRARY(my_ops, m) {
m.def("warp_perspective", warp_perspective);
}
更多可以参考官方教程
1 EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS
2 IR (torchscript)的基本表示
PyTorch 中的各种设计(parameter,计算节点等)在 torchscript 中是如何对应的呢?
这便是转换出的 IR 结果,torchscrip 以下结构组合。
名称 | source code | 简介 |
---|---|---|
Modules | module.h | 对标 nn.Module |
Parameters | module.h | 对标 PyTorch 的 parameter |
Method | Method.h | 包括 FunctionSchema 方法描述,Graph 实际计算图,GraphExecutor do the optimization and execution |
FunctionSchema | function_schema.h | 描述参数与返回类型 |
Graph | ir.h | 定义 function 的具体实现,包括 Nodes,Blocks,Values |
Nodes | ir.h | 一个指令,如一次卷积运算,一次矩阵运算 |
Block | ir.h | 控制语句 if,loop + list of nodes |
还有with
,Value
,Type
等
# %x.1 value
graph(%x.1 : Tensor,
%y.1 : Tensor):
# aten::max 就是一个Node
# Tensor: Type-TensorType
%3 : Tensor = aten::max(%x.1)
%5 : Tensor = aten::max(%y.1)
%6 : Tensor = aten::gt(%3, %5)
%7 : bool = aten::Bool(%6)
%r : Tensor = prim::If(%7)
# Blocks
block0():
-> (%x.1)
block1():
-> (%y.1)
return (%r)
3 导出 IR 的两种方式,trace 与 script
因为其具体实现颇为复杂,粘贴的源码也仅仅保留了简单 case 跑过的分支,并且省去了绝大部分细节,读者如有需要更多细节可以自行去源码查阅。
trace 实现
func,
example_inputs,
optimize=None,
check_trace=True,
check_inputs=None,
check_tolerance=1e-5,
strict=True,
_force_outplace=False,
_module_class=None,
_compilation_unit=_python_cu,
):
# 发现是nn.Module instacene forward, 追踪forward
if isinstance(func, torch.nn.Module):
return trace_module(
func,
{"forward": example_inputs},
None,
check_trace,
wrap_check_inputs(check_inputs),
check_tolerance,
strict,
_force_outplace,
_module_class,
)
# 传进来的是某个module instance的forward
if (
hasattr(func, "__self__")
and isinstance(func.__self__, torch.nn.Module)
and func.__name__ == "forward"
):
return trace_module(
func.__self__,
{"forward": example_inputs},
None,
check_trace,
wrap_check_inputs(check_inputs),
check_tolerance,
strict,
_force_outplace,
_module_class,
)
# 一个查找变量名的接口
var_lookup_fn = _create_interpreter_name_lookup_fn(0)
# C++ 入口
traced = torch._C._create_function_from_trace(
name, func, example_inputs, var_lookup_fn, strict,_force_outplace
)
# 检查traced 与 原func是否有差异
if check_trace:
if check_inputs is not None:
_check_trace(
check_inputs,
func,
traced,
check_tolerance,
strict,
_force_outplace,
False,
_module_class,
)
else:
_check_trace(
[example_inputs],
func,
traced,
check_tolerance,
strict,
_force_outplace,
False,
_module_class,
)
return traced
我们发现经过简单的判断,代码便进入了 C++ 相关函数
traced = torch._C._create_function_from_trace(
name, func, example_inputs, var_lookup_fn, strict, _force_outplace
)
我们去 C++ 中看下发生了什么
std::pair<std::shared_ptr<TracingState>, Stack> trace(
Stack inputs,
const std::function<Stack(Stack)>& traced_fn,
std::function<std::string(const Variable&)> var_name_lookup_fn,
bool strict,
bool force_outplace,
Module* self) {
try {
auto state = std::make_shared<TracingState>();
# setTracingState 将state 这个实例set下来,在之后计算节点get出来insert计算过程
setTracingState(state);
#state这个数据结构会在forward过程中存储trace到的计算过程
if (self) {
Value* self_value = state->graph->insertInput(0, "self")->setType(
self->_ivalue()->type());
gatherParametersAndBuffers(state, self_value, *self, {"__module"});
}
for (IValue& input : inputs) {
input = addInput(state, input, input.type(), state->graph->addInput());
}
auto graph = state->graph;
# 将python中的变量名解析函数绑定下来
getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
getTracingState()->strict = strict;
getTracingState()->force_outplace = force_outplace;
# 开始forward,在计算发生时,会把计算记录到state中
auto out_stack = traced_fn(inputs);
// Exit a trace, treating 'out_stack' as the outputs of the trace. These
// are the variables whose values will be computed upon subsequent
// invocations of the trace.
size_t i = 0;
for (auto& output : out_stack) {
// NB: The stack is in "reverse" order, so when we pass the diagnostic
// number we need to flip it based on size.
state->graph->registerOutput(
state->getOutput(output, out_stack.size() - i));
i++;
}
setTracingState(nullptr);
if (getInlineEverythingMode()) {
Inline(*graph);
}
FixupTraceScopeBlocks(graph, self);
NormalizeOps(graph);
return {state, out_stack};
} catch (...) {
tracer::abandon();
throw;
}
}
那么具体记录 operation 的过程发生在哪里呢?
pytorch/torch/csrc/jit/runtime/register_c10_ops.cpp
Operator createOperatorFromC10_withTracingHandledHere(
const c10::OperatorHandle& op) {
return Operator(op, [op](Stack& stack) {
const auto input_size = op.schema().arguments().size();
const auto output_size = op.schema().returns().size();
Node* node = nullptr;
std::shared_ptr<jit::tracer::TracingState> tracer_state;
// trace the input before unwrapping, otherwise we may lose
// the input information
if (jit::tracer::isTracing()) {
# 获取 tracer_state
tracer_state = jit::tracer::getTracingState();
auto symbol = Symbol::fromQualString(op.schema().name());
const auto& graph = tracer::getTracingState()->graph;
node = graph->create(symbol, 0);
tracer::recordSourceLocation(node);
const auto& args = op.schema().arguments();
int i = 0;
# 记录args
for (auto iter = stack.end() - input_size; iter != stack.end();
++iter, ++i) {
// TODO we need to refactor graph APIs (e.g., addInputs)
// appropriately; after that, we can get rid of the giant if-else
// block we will clean this tech debt together in the following PRs
auto type = args[i].type();
if (type->kind() == TypeKind::OptionalType) {
if (iter->isNone()) {
Value* none = graph->insertNode(graph->createNone())->output();
node->addInput(none);
continue;
} else {
type = type->expect<OptionalType>()->getElementType();
}
}
if (type->isSubtypeOf(TensorType::get())) {
AT_ASSERT(iter->isTensor());
tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
} else if (type->kind() == TypeKind::FloatType) {
AT_ASSERT(iter->isDouble());
tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
} else if (type->kind() == TypeKind::IntType) {
AT_ASSERT(iter->isInt());
tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
} else if (type->kind() == TypeKind::BoolType) {
AT_ASSERT(iter->isBool());
tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
} else if (type->kind() == TypeKind::StringType) {
AT_ASSERT(iter->isString());
tracer::addInputs(node, args[i].name().c_str(), iter->toStringRef());
} else if (type->kind() == TypeKind::NumberType) {
tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
} else if (type->kind() == TypeKind::ListType) {
const auto& elem_type = type->expect<ListType>()->getElementType();
if (elem_type->isSubtypeOf(TensorType::get())) {
AT_ASSERT(iter->isTensorList());
auto list = iter->toTensorVector();
tracer::addInputs(node, args[i].name().c_str(), list);
} else if (elem_type->kind() == TypeKind::FloatType) {
AT_ASSERT(iter->isDoubleList());
// NB: now, tracer doesn't support tracing double list. We add
// special handling here, since in our case, we assume that all the
// doubles in the list are constants
auto value = iter->toDoubleVector();
std::vector<Value*> info(value.size());
for (size_t value_index = 0; value_index < value.size();
++value_index) {
info[value_index] = graph->insertConstant(value[value_index]);
tracer::recordSourceLocation(info[value_index]->node());
}
node->addInput(
graph
->insertNode(graph->createList(jit::FloatType::get(), info))
->output());
} else if (elem_type->kind() == TypeKind::IntType) {
AT_ASSERT(iter->isIntList());
tracer::addInputs(
node, args[i].name().c_str(), iter->toIntVector());
} else if (elem_type->kind() == TypeKind::BoolType) {
AT_ASSERT(iter->isBoolList());
tracer::addInputs(
node, args[i].name().c_str(), iter->toBoolList().vec());
} else {
throw std::runtime_error(
"unsupported input list type: " + elem_type->str());
}
} else if (iter->isObject()) {
tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
} else {
throw std::runtime_error("unsupported input type: " + type->str());
}
}
# node嵌入graph
graph->insertNode(node);
jit::tracer::setTracingState(nullptr);
}
可以看到,在具体运算发生时,会使用 getTracingState() 得到 forward 开始去创建的 state,然后看到根据op.schema().name() 得到计算类型(比如相加),根据计算类型通过 createNone 方法创建一个计算节点,然后创建计算输入,最后把计算node insert 到 graph 中,完成一次对计算的记录。
script
因为 script 得到 IR 的方式是解析源码,因此对于不同的代码形式会略有不同(函数,class,nn.Module的instance):1 Python 函数 简化后 code
def script(obj, optimize=None, _frames_up=0, _rcb=None):
# fucntion 分支
if hasattr(obj, "__script_if_tracing_wrapper"):
obj = obj.__original_fn
_rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
# 检查重载
_check_directly_compile_overloaded(obj)
# 是否之前被script过了
maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
if maybe_already_compiled_fn:
return maybe_already_compiled_fn
# 得到ast语法树
ast = get_jit_def(obj, obj.__name__)
if _rcb is None:
_rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
#c++ 入口,根据ast得到ir
fn = torch._C._jit_script_compile(
qualified_name, ast, _rcb, get_default_args(obj)
)
# Forward docstrings
fn.__doc__ = obj.__doc__
# cache起来
_set_jit_function_cache(obj, fn)
return fn
我们看下get_jit_def是如何得到 jit 规定的 ast 语法树的
仅保留逻辑代码,细节删掉
def get_jit_def(fn, def_name, self_name=None):
# 得到源代码的一些信息
sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
sourcelines = normalize_source_lines(sourcelines)
source = dedent_src ''.join(sourcelines)
# dedent_src 为包含了要script函数的字符串
dedent_src = dedent(source)
# 调用python ast包将字符串解析为Python的ast
py_ast = ast.parse(dedent_src)
# 得到python类型注释
type_line = torch.jit.annotations.get_type_line(source)
#ctx中包含了函数所有原信息
ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
fn_def = py_ast.body[0]
# build_def将python 的ast 转化为torchjit 使用的ast格式
return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
用一个简单的例子给大家解释下 py_ast.body[0] 是什么
import ast
... func_def= \
... """def test(a):
... a = a + 2
... return a + 1"""
... results = ast.parse(func_def)
Python 解析出的 AST
可见,ast.body 是一个 list,其长度等于解析的 string 中包含的函数的个数,我们看第一个元素,其中 value 是一个
Binop
具体为一个Add
,left 是Name
类型,id
为``a,right是
Num,也就是2,这个
Binop即解析的
a = a + 2`。
因为我们 get_source_lines_and_file 返回的一定是一个 single top-level function, 因此我们直接取用第 0个元素,即 py_ast.body[0] 就可以了。
接下来看build_def
是如何将 Python 的 ast 转化为自己需要的 ast 的。
进入buid_def
def build_def(ctx, py_def, type_line, def_name, self_name=None):
....
return Def(Ident(r, def_name),
decl,
build_stmts(ctx, body))
因为ctx
包含 source code 所有信息, body 是 Python ast 解析结果,那么build_stmts
中应该包含我们想要的答案。
我们用例子中a+2
为例看会怎么转换,这部分可见frontend.py
关于StmtBuilder
from torch._C._jit_tree_views import (
ClassDef, Ident, Stmt, Decl, Def, Var,
EmptyTypeAnnotation, Param, ExprStmt, Assign,
Delete, Return, Raise, Assert, AugAssign, While,
For, If, Pass, Break, Continue, Apply, Dots, Select,
TrueLiteral, FalseLiteral, NoneLiteral, Starred,
ListLiteral, TupleLiteral, DictLiteral, Const,
StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
DictComp,
)
# jit中定义的ast基本结构
def build_stmts(ctx, stmts):
#发现其调用了`build_stmt`
stmts = [build_stmt(ctx, s) for s in stmts]
return list(filter(None, stmts))
#`build_stmt` 是一个StmtBuilder()的instance
build_stmt = StmtBuilder()
build_expr = ExprBuilder()
class Builder(object):
def __call__(self, ctx, node):
# 可见会根据解析出的ast的类型返回相应的build方法,从截图可以看到`a+2`是一个`Assign`类型
# 因此会调用build_Assign
method = getattr(self, 'build_' + node.__class__.__name__, None)
if method is None:
raise UnsupportedNodeError(ctx, node)
return method(ctx, node)
class StmtBuilder(Builder):
@staticmethod
def build_Assign(ctx, stmt):
# 截图可以看到stmt.value是一个Binop
# build_expr是ExprBuilder的INSTANCE,其会调用`build_BinOp`
rhs = build_expr(ctx, stmt.value)
lhs = [build_expr(ctx, x) for x in stmt.targets]
return Assign(lhs, rhs)
@staticmethod
def build_Expr(ctx, stmt):
# Binop
value = stmt.value
if value.__class__.__name__ == 'Str':
# If a statement is a string literal expression,
# then it is a docstring. Just ignore it.
return None
else:
return ExprStmt(build_expr(ctx, value))
class ExprBuilder(Builder):
binop_map = {
ast.Add: '+',
ast.Sub: '-',
ast.Mult: '*',
ast.Div: '/',
ast.Pow: '**',
ast.Mod: '%',
ast.FloorDiv: '//',
ast.BitAnd: '&',
ast.BitXor: '^',
ast.BitOr: '|',
ast.LShift: '<<',
ast.RShift: '>>',
}
@staticmethod
def build_BinOp(ctx, expr):
#expr.left是个`Name`调用build_Name
lhs = build_expr(ctx, expr.left)
rhs = build_expr(ctx, expr.right)
op = type(expr.op)
# 转化为约定的代表运算类型的string 符号
op_token = ExprBuilder.binop_map.get(op)
return BinOp(op_token, lhs, rhs)
最终转化为的格式,类似于S-expression.
(def
(ident test)
(decl
(list
(param
(ident a)
(option)
(option)
(False)))
(option))
(list
(assign
(list (variable (ident a)))
(option
(+
(variable (ident a))
(const 2)))
(option))
(return
(+
(variable (ident a))
(const 1)))))
好的,我们已经得到得到jit约定的 AST 树了,接下来我们要进入 torch._C._jit_script_compile查看如何将这样的 ast 树转化为 IR.
C++ 入口为 script_compile_function
static StrongFunctionPtr script_compile_function(
const c10::QualifiedName& name,
const Def& def,
const FunctionDefaults& defaults,
const ResolutionCallback& rcb) {
# def 中包含ast,跟着它就能找到答案
auto cu = get_python_cu();
#看来是get_python_cu这个类中的define函数完成的
auto defined_functions = cu->define(
QualifiedName(name.prefix()),
/*properties=*/{},
/*propResolvers=*/{},
{def},
{pythonResolver(rcb)},
nullptr,
true);
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
auto& defined = defined_functions[0];
defined->setSchema(getSchemaWithNameAndDefaults(
def.range(), defined->getSchema(), def.name().name(), defaults));
StrongFunctionPtr ret(std::move(cu), defined);
didFinishEmitFunction(ret);
return ret;
}
# 发现只是wapper了下CompilationUnit
inline std::shared_ptr<CompilationUnit> get_python_cu() {
return py::module::import("torch.jit._state")
.attr("_python_cu")
.cast<std::shared_ptr<CompilationUnit>>();
}
#关于compilation_unit
#/torch/csrc/jit/api/compilation_unit.h
// for historic reasons, these are defined in ir_emitter.cpp
// Returns the list of Functions just defined.
std::vector<Function*> define(
const c10::optional<c10::QualifiedName>& prefix,
const std::vector<Property>& properties,
const std::vector<ResolverPtr>& propResolvers,
const std::vector<Def>& definitions,
const std::vector<ResolverPtr>&
defResolvers, /* determines how we handle free
variables in each definition*/
// if non-null, the first argument to each def, is bound to this value
const Self* self,
// see [name mangling]
bool shouldMangle = false);
#实现在torch/csrc/jit/frontend/ir_emitter.cpp
std::unique_ptr<Function> CompilationUnit::define(
const c10::optional<QualifiedName>& prefix,
const Def& def,
const ResolverPtr& resolver,
const Self* self,
const std::unordered_map<std::string, Function*>& function_table,
bool shouldMangle) const {
auto _resolver = resolver;
.....
auto creator = [def, _resolver, self](Function& method) {
....
##核心代码to_ir
to_ir(def, _resolver, self, method);
};
auto fn = torch::make_unique<GraphFunction>(
std::move(name), std::make_shared<Graph>(), creator);
return fn;
}
我们跟随 def,找到了一个转化为 IR 的关键的struct
to_ir,其输入中有 def,也就是 ast,_resolver 是 Python 中传过来的解析名字的函数,我们可以在内部找到关键部分
to_ir(
const Def& def,
ResolverPtr resolver_,
const Self* self,
Function& method) // method being constructed
: method(method),
graph(method.graph()),
resolver(std::move(resolver_)),
typeParser_(resolver),
environment_stack(nullptr) {
AT_ASSERT(resolver);
pushFrame(graph->block(), /*starts_def=*/true);
#emitDef 中会调用emitStatements
method.setSchema(emitDef(def, self, graph->block()));
ConvertToSSA(graph);
CanonicalizeModifiedLoops(graph);
NormalizeOps(graph);
runCleanupPasses(graph);
}
private:
#在to_ir 的private中我们可以看到Graph Function这些我们之前介绍的IR的组成部分
Function& method;
std::shared_ptr<Graph> graph;
ResolverPtr resolver;
std::unordered_map<int64_t, Value*> integral_constants;
#emitDef 中会调用emitStatements
FunctionSchema emitDef(const Def& def, const Self* self, Block* block) {
......
// body
auto stmts_list = def.statements();
emitStatements(stmts_list.begin(), stmts_list.end());
........
}
void emitStatements(
List<Stmt>::const_iterator begin,
List<Stmt>::const_iterator end) {
for (; begin != end; ++begin) {
auto stmt = *begin;
ErrorReport::CallStack::update_pending_range(stmt.range());
switch (stmt.kind()) {
case TK_IF:
emitIf(If(stmt));
break;
case TK_WHILE:
emitWhile(While(stmt));
break;
case TK_FOR:
emitFor(For(stmt));
break;
case TK_ASSIGN:
emitAssignment(Assign(stmt));
.................
break;
default:
throw ErrorReport(stmt)
<< "Unrecognized statement kind " << kindToString(stmt.kind());
}
// Found an exit statement in this block. The remaining statements aren't
// reachable so we don't emit them.
if (exit_blocks.count(environment_stack->block()))
return;
}
}
我们可以看到根据stmt.kind(),会进入而各种emit里面,其中一定可以找到
graph->insertNode(graph->create(.....));
类似的操作,对应我们建立IR graph
以上是我们以一个 function 为例子,接下来我们以 script 一个 module为例,其有一些独有的挑战,因为有一些变量的指代,是需要初始化后才知道的,同时,我们希望 script 完的 module 对外还能保持一样的接口,即可以正常访问原有 module 的属性,那么应该怎么做呢?
在 module 原有的 init 结束后随即开始完整的 script forward 函数,替换涉及到的所有函数为 script 后的函数 如何正常访问原有的属性
如何在一个类的 init 函数后面绑定行为呢,我们想到 metaclass,torch.jit 实现了 ScriptMeta这个 metaclass。
class MyModule(torch.jit.ScriptModule):
@torch.jit.script_method
def f(self.x):
return x * x
@torch.jit.script_method
def forward(self, x):
return x + self.f(x)
关于script_method
def script_method(fn):
_rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
#暂时没有script,只是返回包含ast的nametuple
return ScriptMethodStub(_rcb, ast, fn)
ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
1. 移除所有script_method属性被(@script_method修饰的方法),确保访问到的是script function
2. 修改module的_init_,确保module的self.param或者self.module初始化后立即编译所有的script_method,从而生成的instance的forward已经被替换
class ScriptMeta(type):
def __init__(cls, name, bases, attrs): # noqa: B902
# cls ScriptMeta的instance,是一个类如ScriptModule
cls._methods: Dict[str, Any] = {}
cls._constants_set = set(getattr(cls, "__constants__", ()))
for base in reversed(bases):
# 还记得吗trace的module也是有一个_methods的属性
for k, v in getattr(base, "_methods", {}).items():
cls._methods[k] = v
base_constants = getattr(base, "_constants_set", set())
cls._constants_set = cls._constants_set.union(base_constants)
# 找到现在所有被@script_method修饰的方法,放到_method,并删除原有attr
# init后之后统一script
for k, v in sorted(attrs.items()):
if isinstance(v, ScriptMethodStub):
delattr(cls, k)
cls._methods[v.original_method.__name__] = v
original_init = getattr(cls, "__init__", lambda self: None)
# 此处实现了init结束后,调用create_script_module进行script
@functools.wraps(original_init)
def init_then_script(self, *args, **kwargs):
# 此处的self为instance
num_methods = len(cls._methods)
original_init(self, *args, **kwargs)
added_methods_in_init = len(cls._methods) > num_methods
if type(self) == cls:
# 选取需要script的method
def make_stubs(module):
cls = type(module)
if hasattr(cls, "_methods"):
return [v for k, v in sorted(cls._methods.items())]
else:
# infer_methods_to_compile 是一个选取要script函数的函数
return infer_methods_to_compile(module)
# 讲所有script_method一块编译为_actual_script_module属性
self.__dict__[
"_actual_script_module"
] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init)
# Delete the Python attributes that now shadow the ScriptModule
# ones, so that __getattr__ and __setattr__ will properly find
# the scripted versions.
concrete_type = self._actual_script_module._concrete_type
for name in concrete_type.get_attributes():
delattr(self, name)
for name, _ in concrete_type.get_modules():
delattr(self, name)
for name in ("_parameters", "_buffers", "_modules"):
delattr(self, name)
cls.__init__ = init_then_script # type: ignore
return super(ScriptMeta, cls).__init__(name, bases, attrs)
class _CachedForward(object):
def __get__(self, obj, cls):
return self.__getattr__("forward") # type: ignore
class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore
def __init__(self):
super(ScriptModule, self).__init__()
forward = _CachedForward()
# 想访问module的attr,返回_actual_script_module的attr
def __getattr__(self, attr):
if "_actual_script_module" not in self.__dict__:
return super(ScriptModule, self).__getattr__(attr)
return getattr(self._actual_script_module, attr)
def __setattr__(self, attr, value):
if "_actual_script_module" not in self.__dict__:
# Unwrap torch.jit.Attribute into a regular setattr + recording
# the provided type in __annotations__.
#
# This ensures that if we use the attr again in `__init__`, it
# will look like the actual value, not an instance of Attribute.
if isinstance(value, Attribute):
if "__annotations__" not in self.__class__.__dict__:
self.__class__.__annotations__ = {}
self.__annotations__[attr] = value.type
value = value.value
return super(ScriptModule, self).__setattr__(attr, value)
setattr(self._actual_script_module, attr, value)
关于 create_script_module 函数会 script method 然后返回一个RecursiveScriptModule,但是其逻辑较为复杂,在此不再展开。
关于 getattribute vs getattr
当访问某个实例属性时,getattribute 会被无条件调用,当这个属性不存在,则会调用 getattr,如未实现自己的 getattr 方法,会抛出AttributeError 提示找不到这个属性,如果自定义了自己 getattr 方法的话方法会在这种找不到属性的情况下被调用。
4 IR优化的简单介绍
jit 一般涉及如下优化: loop unrolling peephole optimization constant propagation DCE fusion inlining... 我们看如下例子:
def test(x):
# Dead code Elimination
for i in range(1000):
y = x + 1
for i in range(100):
#peephole optimization
x = x.t()
x = x.t()
return x.sum()
opt_test = torch.jit.script(test)
s = time()
inputs = torch.ones(4,4).cuda()
s = time()
for i in range(10000):
test(inputs)
print(time()-s)
# 95s
s = time()
for i in range(10000):
opt_test(inputs)
print(time()-s)
# 0.13s
print(opt_test.graph)
print(opt_test.graph_for(inputs))
95.13823795318604
0.13010907173156738
graph(%x.1 : Tensor):
%22 : None = prim::Constant()
%13 : bool = prim::Constant[value=1]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
%10 : int = prim::Constant[value=100]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:19
%x : Tensor = prim::Loop(%10, %13, %x.1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
block0(%i : int, %x.10 : Tensor):
%x.4 : Tensor = aten::t(%x.10) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:11:12
%x.7 : Tensor = aten::t(%x.4) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:12:12
-> (%13, %x.7)
%23 : Tensor = aten::sum(%x, %22) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
return (%23)
graph(%x.1 : Tensor):
%1 : None = prim::Constant()
%2 : Tensor = aten::sum(%x.1, %1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
return (%2)
关于 IR 计算图优化
IR 的 Method 中内置 GraphExecutor object,创建于第一次执行的时候,负责优化。
文件 pytorch-master/torch/csrc/jit/api/method.h scritp_method 的 C++ 原型里
GraphExecutor& get_executor() {
return function_->get_executor();
}
GraphExecutor 的定义在/torch/csrc/jit/runtime/graph_executor.cpp,可见其由 graph 产生,定义了 run 方法执行
GraphExecutor::GraphExecutor(
const std::shared_ptr<Graph>& graph,
std::string function_name)
: pImpl(
IsNewExecutorEnabled()
? dynamic_cast<GraphExecutorImplBase*>(
new ProfilingGraphExecutorImpl(
graph,
std::move(function_name)))
: dynamic_cast<GraphExecutorImplBase*>(
new GraphExecutorImpl(graph, std::move(function_name)))) {}
std::shared_ptr<Graph> GraphExecutor::graph() const {
return pImpl->graph;
}
const ExecutionPlan& GraphExecutor::getPlanFor(
Stack& inputs,
size_t remaining_bailout_depth) {
return pImpl->getPlanFor(inputs, remaining_bailout_depth);
}
std::shared_ptr<GraphExecutorImplBase> pImpl;
.....
关于 GraphExecutorImplBase,/torch/csrc/jit/runtime/graph_executor.cpp
const ExecutionPlan& getOrCompile(const Stack& stack) {
.....
auto plan = compileSpec(spec);
}
}
# compileSpec 会返回一个plan
ExecutionPlan compileSpec(const ArgumentSpec& spec) {
auto opt_graph = graph->copy();
GRAPH_DUMP("Optimizing the following function:", opt_graph);
arg_spec_creator_.specializeTypes(*opt_graph, spec);
// Phase 0. Inline functions, then clean up any artifacts that the inliner
// left in that may inhibit optimization
.....
runRequiredPasses(opt_graph);
GRAPH_DEBUG(
"After runRequiredPasses, before ConstantPropagation\n", *opt_graph);
// Phase 2. Propagate detailed information about the spec through the
// graph (enabled more specializations in later passes).
// Shape propagation sometimes depends on certain arguments being
// constants, and constant propagation doesn't need shape
// information anyway, so it's better to run it first.
ConstantPropagation(opt_graph);
GRAPH_DEBUG(
"After ConstantPropagation, before PropagateInputShapes\n", *opt_graph);
PropagateInputShapes(opt_graph);
GRAPH_DEBUG(
"After PropagateInputShapes, before PropagateRequiresGrad\n",
*opt_graph);
PropagateRequiresGrad(opt_graph);
GRAPH_DEBUG(
"After PropagateRequiresGrad, before runOptimization\n", *opt_graph);
// Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
// that we can still execute using autograd).
runOptimization(opt_graph);
.....各种优化
return ExecutionPlan(opt_graph, function_name_);
}
这些优化在 torch/csrc/jit/passes/ 文件夹
torch/csrc/jit/passes/dead_code_elimination.cpp
/torch/csrc/jit/passes/fuse_linear.cpp
torch/csrc/jit/passes/remove_dropout.cpp
torch/csrc/jit/passes/fold_conv_bn.cpp
参考
1. INTRODUCTION TO TORCHSCRIPT
2. PyTorch 部署_TorchScript
3.pytorch_wiki
4. PyTorch-JIT-Source-Code-Read-Note
5. Abstract_syntax_tree
- The End -
长按二维码关注我们
本公众号专注:
1. 技术分享;
2. 学术交流;
3. 资料共享。
欢迎关注我们,一起成长!