Hello everyone, the TorchScript deep-dive series has been updated. In the last article, we gave an introduction to TorchScript.
TorchScript is PyTorch's solution for model serialization and deployment. It makes up for PyTorch's weakness in deployment and makes graph optimization and backend integration easy. TorchScript can generate the data-flow graph by tracing, and it also supports script mode, which parses the AST to generate the graph directly.
Today we'll show how TorchScript generates the data-flow graph through tracing, and we'll walk through the ONNX export process built on this mechanism. Let's get started.
Basic concepts
Let's start by looking at three different representations of the same model. To make it easier to show the various JIT components, we build the graph with script mode:
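```python
import torch

class TestModel(torch.nn.Module):
    def forward(self, x):
        x = x * 2
        x.add_(0)
        x = x.view(-1)
        if x[0] > 1:
            return x[0]
        else:
            return x[-1]

# Script mode parses the source code and builds the graph shown below
scripted = torch.jit.script(TestModel())
print(scripted.graph)
```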
TorchScript Graph:
```
graph(%self : __torch__.TestModel,
      %x.1 : Tensor):
  %12 : int = prim::Constant[value=-1]() # graph_example.py:12:19
  %3 : int = prim::Constant[value=2]() # graph_example.py:10:16
  %6 : int = prim::Constant[value=0]() # graph_example.py:11:15
  %10 : int = prim::Constant[value=1]() # graph_example.py:12:20
  %x.3 : Tensor = aten::mul(%x.1, %3) # graph_example.py:10:12
  %8 : Tensor = aten::add_(%x.3, %6, %10) # graph_example.py:11:8
  %13 : int[] = prim::ListConstruct(%12)
  %x.6 : Tensor = aten::view(%x.3, %13) # graph_example.py:12:12
  %17 : Tensor = aten::select(%x.6, %6, %6) # graph_example.py:13:11
  %18 : Tensor = aten::gt(%17, %10) # graph_example.py:13:11
  %20 : bool = aten::Bool(%18) # graph_example.py:13:11
  %41 : Tensor = prim::If(%20) # graph_example.py:13:8
    block0():
      %23 : Tensor = aten::select(%x.6, %6, %6) # graph_example.py:14:19
      -> (%23)
    block1():
      %32 : Tensor = aten::select(%x.6, %6, %12) # graph_example.py:16:19
      -> (%32)
  return (%41)
```
The middle section of the figure above is a visualization of the TorchScript model. It contains the following elements:
Graph
The Graph column of the table as a whole represents a Graph, which has the following properties:
- A Graph represents a "function"; different functions in a Module (such as forward) are converted into different Graphs.
- A Graph contains a number of Nodes managed by a Block. All Nodes are organized into a doubly linked list for easy insertion and deletion, with the Return Node serving as the "sentinel" of the list. The list is kept in topological order to ensure correct execution.
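To make the "sentinel" idea concrete, here is a toy Python sketch of a circular doubly linked list in which the return node doubles as the sentinel. The class names ToyNode and ToyBlock are hypothetical, and this is not PyTorch's C++ implementation; it only shows why insertion and deletion are O(1) and why no separate head pointer is needed.

```python
class ToyNode:
    """A node in a circular doubly linked list (toy, not torch::jit::Node)."""

    def __init__(self, kind):
        self.kind = kind
        self.prev = self
        self.next = self

    def insert_before(self, other):
        """Splice self into the list right before `other` in O(1)."""
        self.prev = other.prev
        self.next = other
        other.prev.next = self
        other.prev = self
        return self

    def destroy(self):
        """Unlink self from the list in O(1)."""
        self.prev.next = self.next
        self.next.prev = self.prev
        self.prev = self.next = self


class ToyBlock:
    def __init__(self):
        # The return node is the sentinel: an empty list is one where it points to itself.
        self.return_node = ToyNode("prim::Return")

    def append(self, kind):
        # Appending a node means inserting it just before the sentinel.
        return ToyNode(kind).insert_before(self.return_node)

    def nodes(self):
        cur = self.return_node.next
        while cur is not self.return_node:
            yield cur
            cur = cur.next


block = ToyBlock()
block.append("aten::mul")
add = block.append("aten::add_")
block.append("aten::view")
add.destroy()
print([n.kind for n in block.nodes()])  # ['aten::mul', 'aten::view']
```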
Node
In the Graph column of the table, rows 3 to 14, 16 and 19 each represent a Node, and each Node corresponds to one operation. The inputs of the operation are Values and, in rare cases, static attributes. A Node carries a lot of information, including:
- kind() gives the operation type of the Node; for example, aten::mul and prim::ListConstruct above are both Node kinds. A kind is essentially an interned string (c10::Symbol), so changing the string also changes the operation.
- FunctionSchema describes the interface of the function. It looks like the declaration of the op, with extra annotations indicating, for example, whether one Tensor is an alias of another. It can serve as a lookup key for peephole optimization. Take Tensor.add_ as an example:
```cpp
// add_ is an in-place operation, so its output shares memory with self.
// The FunctionSchema annotates this alias relationship to guarantee correct results.
// Netron's visualization does not seem to perform alias analysis, so the add_ part
// of the visualization on the right above is wrong.
"add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)"
```
- The schemas of commonly used functions can be found in aten/src/ATen/native/native_functions.yaml.
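As a hedged illustration of how alias annotations can be read from a schema string, the toy parser below (the function name mutated_args is hypothetical) reports which arguments carry a "!" annotation, i.e. are written in place. It only handles simple flat schemas such as the ones in this article and is not PyTorch's schema parser.

```python
import re
from typing import List


def mutated_args(schema: str) -> List[str]:
    """Names of arguments whose alias annotation contains '!', i.e. that are written in place."""
    open_paren = schema.index("(")
    close_paren = schema.rindex(")", 0, schema.index(" -> "))
    names = []
    for arg in schema[open_paren + 1 : close_paren].split(","):
        arg = arg.strip()
        if arg == "*":  # separator before keyword-only arguments
            continue
        type_str, name = arg.split()[:2]
        if re.search(r"\([a-z]+!\)", type_str):
            names.append(name.split("=")[0])  # drop a default value such as alpha=1
    return names


print(mutated_args("add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)"))  # ['self']
```

Because add_ reports self as mutated, an optimization pass consulting the schema knows it may not reorder or deduplicate that node as if it were a pure function.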
Block
A Block is an ordered list of Nodes. Its inputs are represented by a Node of kind prim::Param and its outputs by a Node of kind prim::Return.
A Graph actually contains a root Block object that manages all of its Nodes, and some Nodes may have sub-Blocks of their own. For example, the Graph in the table has three Blocks: the root Block owned by the Graph, and the two sub-Blocks of the prim::If Node.
The concept of a Block probably stems from the basic block in compiler theory. A basic block is a sequence of instructions with no jumps into or out of its middle. Because the instructions in a basic block are guaranteed to execute in order, many optimizations take basic blocks as their premise. In fact, much of the optimization on PyTorch's intermediate representation (IR) works at the Block level.
Value
A Value is an input or output of a Node. It may be a Tensor, a container, or something else; check type() to find out.
Each Value maintains a use list: whenever the Value becomes an input of a Node, that Node (together with the input position) is added to the Value's use list. The use list makes it easy to resolve the producer/consumer relationships between a newly added Node and the existing Nodes.
Note: a Value describes the structure of the Graph, not runtime data. What actually flows during inference is the IValue object, which holds real data at run time.
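The use list can be sketched in a few lines of Python. This toy (hypothetical Value and Node classes, not the torch::jit types) records (node, input index) pairs on each Value and uses them for a replace-all-uses operation, similar in spirit to the replaceAllUsesWith method that torch::jit::Value provides in C++.

```python
class Value:
    """Toy stand-in for torch::jit::Value: only a name and a use list."""

    def __init__(self, name):
        self.name = name
        self.uses = []  # (user_node, input_index) pairs


class Node:
    """Toy stand-in for torch::jit::Node."""

    def __init__(self, kind, inputs):
        self.kind = kind
        self.inputs = []
        for value in inputs:
            self.add_input(value)

    def add_input(self, value):
        # Record the use on the Value so its consumers can be found without scanning the graph.
        value.uses.append((self, len(self.inputs)))
        self.inputs.append(value)


def replace_all_uses_with(old, new):
    """Redirect every consumer of `old` to read `new` instead."""
    for user, index in old.uses:
        user.inputs[index] = new
        new.uses.append((user, index))
    old.uses.clear()


x, y = Value("%x"), Value("%y")
mul = Node("aten::mul", [x, x])
relu = Node("aten::relu", [x])
replace_all_uses_with(x, y)
print([v.name for v in mul.inputs], [v.name for v in relu.inputs])  # ['%y', '%y'] ['%y']
```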
Pass
Strictly speaking, a pass is not part of the Graph. The concept comes from compiler theory: a pass takes an intermediate representation (IR), traverses it, and applies some transformation to produce a new IR that satisfies certain conditions.
TorchScript defines many passes to optimize the Graph. Some are common in conventional compilers, such as dead code elimination (DCE) and common subexpression elimination (CSE). Some are fusion optimizations specific to deep learning, such as fusing Conv and BN. Others serve special tasks, such as ONNX export.
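As a toy example of what a pass does, the sketch below runs dead code elimination over a straight-line IR written as plain Python objects (the Instr class is hypothetical). It keeps an instruction only if its result is live or it has side effects. PyTorch's real DCE pass works on Graph/Block objects and relies on alias analysis, which the side_effect flag here merely stands in for.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Instr:
    out: str
    op: str
    inputs: List[str] = field(default_factory=list)
    side_effect: bool = False  # simplification standing in for alias/mutation analysis


def dead_code_elimination(body: List[Instr], graph_outputs: List[str]) -> List[Instr]:
    live = set(graph_outputs)
    kept = []
    # Visit instructions in reverse order so every consumer is seen before its producer.
    for ins in reversed(body):
        if ins.out in live or ins.side_effect:
            kept.append(ins)
            live.update(ins.inputs)
    kept.reverse()
    return kept


body = [
    Instr("%a", "aten::mul", ["%x", "%c2"]),
    Instr("%b", "aten::add", ["%a", "%c1"]),  # never used: dead
    Instr("%d", "aten::view", ["%a", "%shape"]),
]
print([i.op for i in dead_code_elimination(body, ["%d"])])  # ['aten::mul', 'aten::view']
```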
JIT Trace
On the Python side, JIT tracing is exposed as torch.jit.trace. Its arguments are passed down layer by layer until they reach the trace function in torch/csrc/jit/frontend/tracer.cpp. This function is the heart of JIT tracing and roughly performs the following steps:
- Create a new TracingState object, which maintains the traced Graph and the necessary environment parameters.
- Generate the Graph's input nodes from the model inputs used for tracing.
- Run model inference while generating the elements of the Graph.
- Generate the Graph's output nodes.
- Apply some simple optimizations.
These steps are described in detail below:
1. Create the TracingState object
The TracingState object holds a pointer to the Graph, the function-name mapping, stack-frame information, and so on. Tracing is the process of continually updating the TracingState.
```cpp
struct TORCH_API TracingState
    : public std::enable_shared_from_this<TracingState> {
  // A subset of the interface, to help understand the structure
  std::shared_ptr<Graph> graph;

  void enterFrame();
  void leaveFrame();

  void setValue(const IValue& v, Value* value);
  void delValue(const IValue& var);
  Value* getValue(const IValue& var);
  Value* getOutput(const IValue& var, size_t i);
  bool hasValue(const IValue& var) const;

  Node* createNode(c10::Symbol op_name, size_t num_outputs);
  void insertNode(Node* node);
};
```
2. Generate the Graph inputs
This step inserts a new input Value into the Graph according to the type of each input IValue. Remember the difference between IValue and Value from the Basic concepts section?
```cpp
for (IValue& input : inputs) {
  // addInput unpacks IValues of container types and creates the corresponding Nodes
  input = addInput(state, input, input.type(), state->graph->addInput());
}
```
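The following toy sketch (all names are hypothetical) mimics two ideas from this step: a map from runtime objects to graph Values, similar in spirit to TracingState::setValue/getValue, and the unpacking of a tuple input through a prim::TupleUnpack node. Keying the map by id() and storing the object alongside it is an assumption of the sketch, made so that ids cannot be reused while the trace is alive.

```python
class ToyTracingState:
    """Toy stand-in: graph inputs, recorded nodes and a runtime-object -> Value map."""

    def __init__(self):
        self.graph_inputs = []
        self.nodes = []
        self.env = {}  # id(obj) -> (obj, value); holding obj keeps its id from being reused
        self.counter = 0

    def fresh(self):
        name = f"%{self.counter}"
        self.counter += 1
        return name

    def set_value(self, obj, value):
        self.env[id(obj)] = (obj, value)

    def get_value(self, obj):
        return self.env[id(obj)][1]


def add_input(state, obj, value):
    """Bind a runtime input to a graph Value, unpacking tuples element by element."""
    if isinstance(obj, tuple):
        outs = [state.fresh() for _ in obj]
        state.nodes.append(("prim::TupleUnpack", [value], outs))
        for item, out in zip(obj, outs):
            add_input(state, item, out)
    else:
        state.set_value(obj, value)


def add_graph_input(state, obj):
    """One new graph input per model input, like state->graph->addInput() in the loop."""
    value = state.fresh()
    state.graph_inputs.append(value)
    add_input(state, obj, value)
    return value


state = ToyTracingState()
a, b, c = object(), object(), object()
print(add_graph_input(state, a), add_graph_input(state, (b, c)))  # %0 %1
print(state.nodes)  # [('prim::TupleUnpack', ['%1'], ['%2', '%3'])]
```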
3. Tracing
Tracing is the process of running inference on sample data. However, if you search the source on GitHub, you won't find any code that updates the TracingState during inference.
So how does PyTorch update the TracingState during inference? Let's start with a few details about how PyTorch is built from source.
PyTorch has to support a wide variety of hardware and environments. Hand-writing code for every one of these cases would be an enormous amount of work and would make later maintenance inconvenient. So much of PyTorch's code is generated at build time from build parameters, and the code that updates TracingState is one example. The script that generates the tracing code is:
```shell
python -m tools.autograd.gen_autograd \
    aten/src/ATen/native/native_functions.yaml \
    ${OUTPUT_DIR} \
    tools/autograd
# The YAML files (such as native_functions.yaml) contain many FunctionSchemas
# and the other information needed to generate the code
```
You can run it and look at what is generated. The generated TraceTypeEverything.cpp contains a lot of code that updates TracingState. Let's take the scatter_add operator as an example:
```yaml
- func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
  structured_delegate: scatter_add.out
  variants: function, method

- func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
  structured_delegate: scatter_add.out
  variants: method

- func: scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  # ... (the rest of the scatter_add.out entry is omitted)

# func is a FunctionSchema that defines the function's inputs and outputs, alias information, and so on.
```
```cpp
at::Tensor scatter_add(c10::DispatchKeySet ks, const at::Tensor & self, int64_t dim,
                       const at::Tensor & index, const at::Tensor & src) {
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    // Step 1: when tracing, use TracingState to create the Node for this op and insert it into the Graph
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = c10::Symbol::fromQualString("aten::scatter_add");
    node = tracer_state->createNode(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "dim", dim);
    jit::tracer::addInputs(node, "index", index);
    jit::tracer::addInputs(node, "src", src);
    tracer_state->insertNode(node);
    // Temporarily disable tracing so the ops called inside are not recorded again
    jit::tracer::setTracingState(nullptr);
  }
  // Step 2: run the op itself, whether or not tracing is active
  auto result = at::_ops::scatter_add::redispatch(
      ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer),
      self, dim, index, src);
  if (tracer_state) {
    // Step 3: restore the TracingState and record the op's output
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
```
The first block above is the FunctionSchema and the second is the generated code. The generated code records Graph structure information only when tracing is active (jit::tracer::isTracing()).
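To see how a schema can drive this kind of codegen, here is a toy generator (the name gen_tracing_lines is hypothetical) that emits the createNode/addInputs/insertNode lines of the tracing branch from a schema string. PyTorch's real generator under tools/autograd handles far more cases; this sketch only covers flat schemas like scatter_add.

```python
from typing import List


def gen_tracing_lines(schema: str) -> List[str]:
    """Emit the node-recording lines of the tracing branch for one flat schema (toy)."""
    head, _, _ = schema.partition(" -> ")
    qualified_name, _, arg_text = head.partition("(")
    arg_text = arg_text[:-1]  # drop the closing ')'
    base_name = qualified_name.split(".")[0]  # strip the overload name, e.g. ".out"
    lines = [
        f'op_name = c10::Symbol::fromQualString("aten::{base_name}");',
        "node = tracer_state->createNode(op_name, /*num_outputs=*/0);",
    ]
    for arg in arg_text.split(","):
        arg = arg.strip()
        if arg == "*":  # keyword-only separator, not an argument
            continue
        name = arg.split()[1].split("=")[0]
        lines.append(f'jit::tracer::addInputs(node, "{name}", {name});')
    lines.append("tracer_state->insertNode(node);")
    return lines


print("\n".join(gen_tracing_lines(
    "scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor")))
```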
In fact, during tracing, every op that runs goes through a function like the generated one above, which performs the following steps:
- Before the computation, create the Node and each of its input Values according to the parsed FunctionSchema;
- Then perform the op's normal computation;
- Finally, create the Node's output Value from the op's output.
4. Register the Graph outputs
There's not much to say about this step: the inference outputs are registered one by one as the Graph's output Values. Since the outputs are on a stack, their diagnostic numbers are flipped (counted in reverse order).
```cpp
size_t i = 0;
for (auto& output : out_stack) {
  // NB: The stack is in "reverse" order, so when we pass the diagnostic
  // number we need to flip it based on size.
  state->graph->registerOutput(
      state->getOutput(output, out_stack.size() - i));
  i++;
}
```
5. Graph optimization
After tracing, some simple optimizations are applied to the Graph, including the following passes:
- Inline (optional): network definitions often contain many nested structures; for example, ResNet is composed of many BottleNeck modules. Calls to sub-Modules generate Nodes such as prim::CallMethod. Inline optimization inlines the sub-Module's Graph into the current Graph, eliminating CallMethod, CallFunction, and similar Nodes.
- FixupTraceScopeBlock: handles scope-related Nodes. For example, prim::TracedAttr[scope="__module.f.param"]() is split into a combination of several prim::GetAttr Nodes.
- NormalizeOps: Nodes with different names may have the same functionality, for example aten::absolute and aten::abs. NormalizeOps unifies the kinds of such Nodes under a single canonical name (here, aten::abs).
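Two of these passes are easy to imitate on strings. The toy functions below (hypothetical names; the alias table is a small illustrative subset, assumed rather than copied from PyTorch) rename aliased kinds to a canonical one, and split a TracedAttr scope into a chain of prim::GetAttr steps that starts from self.

```python
from typing import List

# Illustrative subset of an alias table (assumed; PyTorch's real table lives in its C++ sources).
ALIASES = {
    "aten::absolute": "aten::abs",
    "aten::absolute_": "aten::abs_",
    "aten::multiply": "aten::mul",
}


def normalize_ops(kinds: List[str]) -> List[str]:
    """Rename aliased kinds to their canonical name; other kinds pass through unchanged."""
    return [ALIASES.get(kind, kind) for kind in kinds]


def split_traced_attr(scope: str) -> List[str]:
    """Turn a scope such as '__module.f.param' into a chain of prim::GetAttr steps from self."""
    parts = scope.split(".")
    if parts[0] != "__module":  # assumption: scopes are rooted at the top-level module
        raise ValueError(f"unexpected scope {scope!r}")
    steps, current = [], "%self"
    for name in parts[1:]:
        out = f"%{name}"
        steps.append(f'{out} = prim::GetAttr[name="{name}"]({current})')
        current = out
    return steps


print(normalize_ops(["aten::absolute", "aten::add", "aten::multiply"]))
# ['aten::abs', 'aten::add', 'aten::mul']
print("\n".join(split_traced_attr("__module.f.param")))
```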
A more detailed analysis of these passes will be covered in a later article.
After the above steps, you get the traced result.
ONNX Export
ONNX model export also goes through the JIT trace process. The general steps are:
- Load the ops' symbolic functions, mainly the ones predefined in torch.
- Set up the environment, including opset_version, whether to fold constants, and so on.
- Use JIT trace to generate the Graph.
- Map the Nodes in the Graph to ONNX Nodes and apply the necessary optimizations.
- Export the model to the ONNX serialization format.
Next, we will go through these steps in order:
1. Load the symbolic functions
Strictly speaking, this step is already completed before export is called. symbolic_registry.py maintains a _symbolic_versions object, and the predefined symbolic modules (torch.onnx.symbolic_opset<N>) are loaded into it with importlib:
```python
import importlib
from typing import Any, Dict, Union
from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset

_symbolic_versions: Dict[Union[int, str], Any] = {}
# Import every stable opset module plus the main opset, keyed by version
for opset_version in _onnx_stable_opsets + [_onnx_main_opset]:
    module = importlib.import_module("torch.onnx.symbolic_opset{}".format(opset_version))
    _symbolic_versions[opset_version] = module
```
In _symbolic_versions, the key is opset_version and the value is the corresponding set of symbolic functions. A symbolic function maps an aten/prim Node to ONNX Nodes. Read torch/onnx/symbolic_opset<version>.py (for example, symbolic_opset9.py) for more details.
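The effect of keying symbolic functions by opset can be sketched with plain dictionaries. The registry below is hypothetical, and the lookup rule (take the newest opset not greater than the requested one that defines the op) is a simplified model of the behavior, since newer opset modules only need to define the functions whose export changed. The clamp entries reflect that ONNX Clip takes min/max as attributes before opset 11 and as inputs from opset 11 on.

```python
from typing import Callable, Dict, Optional

# Hypothetical registry: opset version -> {op name -> symbolic function}.
SYMBOLIC_VERSIONS: Dict[int, Dict[str, Callable[..., str]]] = {
    9: {
        "add": lambda a, b: f"onnx::Add({a}, {b})",
        "clamp": lambda x: f"onnx::Clip[min, max as attributes]({x})",
    },
    11: {
        # Only ops whose export changed need a new entry.
        "clamp": lambda x: f"onnx::Clip({x}, min, max)",
    },
}


def find_symbolic(op_name: str, opset_version: int) -> Optional[Callable[..., str]]:
    """Return the symbolic from the newest registered opset <= opset_version defining op_name."""
    for version in sorted(SYMBOLIC_VERSIONS, reverse=True):
        if version <= opset_version and op_name in SYMBOLIC_VERSIONS[version]:
            return SYMBOLIC_VERSIONS[version][op_name]
    return None


print(find_symbolic("add", 11)("%0", "%1"))  # onnx::Add(%0, %1), found in the opset 9 table
print(find_symbolic("clamp", 11)("%0"))      # the opset 11 variant
print(find_symbolic("clamp", 10)("%0"))      # falls back to the opset 9 variant
print(find_symbolic("gelu", 11))             # None: no symbolic registered
```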
2. Set up the environment
Environment information is adjusted according to the arguments passed to export, such as the opset version, whether initializers are exported as graph inputs, and whether constant folding is performed. Later optimizations run specific passes depending on these settings.
3. Graph tracing
This step runs the JIT tracing process introduced above; review that section if you need a refresher.
4. ToONNX
The Graph goes through many passes before it is actually used, and each pass transforms the Graph in some way. See torch/csrc/jit/passes for implementation details. Many of these passes do the same things as their counterparts in common compilers, so I won't cover them here. For TorchScript -> ONNX, the most important pass is ToONNX.
The Python interface of ToONNX is torch._C._jit_pass_onnx, which is implemented in onnx.cpp. It traverses all the Nodes in the Graph, generates the corresponding ONNX Nodes, and inserts them into a new Graph:
```cpp
auto k = old_node->kind();
if (k.is_caffe2()) {
  // Caffe2 operators are copied over as-is
  cloneNode(old_node);
} else if (k == prim::PythonOp) {
  // A custom Python function, e.g. a Function inheriting from torch.autograd.Function:
  // look up and call its symbolic function for the conversion
  callPySymbolicMethod(static_cast<ConcretePythonOp*>(old_node));
} else {
  // Anything else (usually an aten operator): call the symbolic functions loaded in step 1
  callPySymbolicFunction(old_node);
}
```
cloneNode does what its name says: it simply copies old_node and inserts the copy into the new Graph.
callPySymbolicFunction
This function is called when the Node's kind is one of PyTorch's built-in types.
It calls torch.onnx.utils._run_symbolic_function on the Python side to convert the Node and insert the result into the new Graph. We can try Python code like this:
```python
import torch
from torch.onnx.utils import _run_symbolic_function

graph = torch._C.Graph()                            # create a graph
[graph.addInput() for _ in range(2)]                # add two inputs
node = graph.create('aten::add', list(graph.inputs()))  # create a node
node = graph.insertNode(node)                       # insert the node
graph.registerOutput(node.output())                 # register the output
print(f'old graph:\n {graph}')

new_graph = torch._C.Graph()                        # create a new graph for ONNX
[new_graph.addInput() for _ in range(2)]            # add two inputs
_run_symbolic_function(new_graph, node, inputs=list(new_graph.inputs()), env={})
# Note: for torch > 1.8, _run_symbolic_function also takes a block argument
print(f'new graph:\n {new_graph}')
```
Then look at the printed results:
Old graph
```
graph(%0 : Tensor,
      %1 : Tensor):
  %2 : Tensor = aten::add(%0, %1)
  return (%2)
```
New graph
```
graph(%0 : Tensor,
      %1 : Tensor):
  %2 : Tensor = onnx::Add(%0, %1)
  return ()
```
As you can see, the aten::add node has been replaced with onnx::Add. How does this mapping work? Remember the _symbolic_versions recorded in the first step? _run_symbolic_function calls the _find_symbolic_in_registry function in torch.onnx.symbolic_registry to check whether _symbolic_versions contains a matching mapping. If it does, the conversion shown above is performed.
Note: the converted new Graph has no output Value (return ()), because registering the outputs is done in the C++ code of ToONNX; _run_symbolic_function is only responsible for mapping the Node.
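To see where the missing output comes from, here is a toy version of the surrounding conversion loop (hypothetical names; plain Python data instead of torch._C.Graph). Each symbolic function only emits new nodes through an op callback, which plays the role of g.op, while an env dictionary maps old Values to new ones. The graph outputs are registered only after the loop, by looking them up in env.

```python
def convert_graph(old, symbolics):
    """Toy ToONNX loop: convert each node via its symbolic, then register outputs through env."""
    new = {"inputs": list(old["inputs"]), "nodes": [], "outputs": []}
    env = {v: v for v in old["inputs"]}  # old Value -> new Value (inputs reused 1:1 here)

    def op(kind, *inputs):
        # Plays the role of g.op(...): append a node to the new graph, return its output.
        out = f"%{len(new['inputs']) + len(new['nodes'])}"
        new["nodes"].append((kind, list(inputs), out))
        return out

    for kind, inputs, output in old["nodes"]:
        new_inputs = [env[v] for v in inputs]
        env[output] = symbolics[kind](op, *new_inputs)  # a symbolic only emits nodes
    # Outputs are registered after the loop, by translating old outputs through env.
    new["outputs"] = [env[v] for v in old["outputs"]]
    return new


SYMBOLICS = {"aten::add": lambda op, a, b: op("onnx::Add", a, b)}
old_graph = {"inputs": ["%0", "%1"], "nodes": [("aten::add", ["%0", "%1"], "%2")], "outputs": ["%2"]}
new_graph = convert_graph(old_graph, SYMBOLICS)
print(new_graph["nodes"])    # [('onnx::Add', ['%0', '%1'], '%2')]
print(new_graph["outputs"])  # ['%2']
```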
callPySymbolicMethod
Computations that are not native to PyTorch are marked as PythonOp. There are three possible ways to handle such Nodes:
- If the PythonOp has an attribute named symbolic, it is used as the mapping function to generate the ONNX Node.
- If there is no symbolic attribute but a symbolic function for prim::PythonOp was registered in step 1, that function is used to generate the Node.
- If neither exists, the PythonOp Node is cloned directly into the new Graph.
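The three-way choice can be sketched as follows. The function convert_python_op and its argument layout are hypothetical, the example classes are plain Python classes rather than torch.autograd.Function subclasses, and the real registered-symbolic call signature in torch.onnx differs. The sketch only shows the order in which the options are tried.

```python
def convert_python_op(fn_cls, inputs, registry, op):
    """Return (strategy, result) for a prim::PythonOp that wraps autograd Function fn_cls."""
    symbolic = getattr(fn_cls, "symbolic", None)
    if symbolic is not None:
        # 1. The Function defines its own symbolic: use it as the mapping function.
        return "symbolic attribute", symbolic(op, *inputs)
    if "prim::PythonOp" in registry:
        # 2. Otherwise fall back to a symbolic registered for prim::PythonOp.
        return "registered symbolic", registry["prim::PythonOp"](op, fn_cls.__name__, *inputs)
    # 3. Nothing found: clone the PythonOp node unchanged.
    return "clone", ("prim::PythonOp", fn_cls.__name__, tuple(inputs))


def op(kind, *inputs):
    """Plays the role of g.op: describe the node that would be created."""
    return (kind, inputs)


class WithSymbolic:
    @staticmethod
    def symbolic(g, x):
        return g("custom_domain::foo", x)


class WithoutSymbolic:
    pass


print(convert_python_op(WithSymbolic, ["%0"], {}, op))
print(convert_python_op(WithoutSymbolic, ["%0"], {}, op))
```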
A symbolic function is simple to write; it basically calls the Graph interface bound to Python to create new nodes, for example:
```python
import torch

class CustomAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, val):
        return x + val

    @staticmethod
    def symbolic(g, x, val):
        # g.op(node_name, *inputs, **attributes) creates a node; a node can have many attributes.
        # Attribute names must carry a _<type> suffix: val is a float, so it uses the _f suffix.
        return g.op("custom_domain::add", x, val_f=val)
```
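The _<type> suffix convention can be illustrated with a small parser. The function split_attribute is hypothetical and covers only five common suffixes (i, f, s, t, g); the real helper inside torch.onnx may accept more. In this sketch, list values turn the kind into its plural form, such as ints.

```python
import re
from typing import Tuple

_ATTR_PATTERN = re.compile(r"^(.+)_([ifstg])$")
_KINDS = {"i": "int", "f": "float", "s": "string", "t": "tensor", "g": "graph"}


def split_attribute(key: str, value) -> Tuple[str, str]:
    """Split a keyword such as 'val_f' into the attribute name and its type."""
    m = _ATTR_PATTERN.match(key)
    if m is None:
        raise ValueError(f"attribute {key!r} needs a _<type> suffix such as _f or _i")
    name, kind = m.group(1), _KINDS[m.group(2)]
    if isinstance(value, (list, tuple)):
        kind += "s"  # list-valued attribute, e.g. ints or floats
    return name, kind


print(split_attribute("val_f", 1.0))      # ('val', 'float')
print(split_attribute("axes_i", [0, 1]))  # ('axes', 'ints')
```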
When this function is actually used, a custom_domain::add Node is generated. Of course, whether it can be used for inference depends on support in the inference engine.
With callPySymbolicFunction and callPySymbolicMethod, we get a new Graph made of ONNX Nodes (or Nodes in your custom domain). After that, some more passes optimize the ONNX Graph, which I won't elaborate on here.
5. Serialization
The Graph is complete at this point, but to be used by other backends it needs to be serialized and exported. Serialization is relatively simple: it mainly calls ONNX's protobuf interfaces to map each element of the Graph into ONNX's GraphProto. There isn't much to expand on; read the EncodeGraph, EncodeBlock and EncodeNode functions in export.cpp for more details.
The serialized proto is then written to a file according to the specific export_type.
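As a rough picture of what the encoder produces, the toy below builds plain dicts whose keys follow ONNX's GraphProto and NodeProto field names (name, input, node, output, op_type, domain). The dict form and the JSON dump are illustrative only; the real exporter fills protobuf messages in C++, and the function names here are hypothetical.

```python
import json


def encode_node(kind, inputs, outputs):
    namespace, op_type = kind.split("::")
    # Standard ONNX ops live in the default domain (empty string); custom ones keep theirs.
    domain = "" if namespace == "onnx" else namespace
    return {"op_type": op_type, "domain": domain, "input": list(inputs), "output": list(outputs)}


def encode_graph(name, inputs, nodes, outputs):
    return {
        "name": name,
        "input": [{"name": v} for v in inputs],
        "node": [encode_node(*n) for n in nodes],
        "output": [{"name": v} for v in outputs],
    }


graph = encode_graph(
    "example_graph",
    ["x", "y"],
    [("onnx::Add", ["x", "y"], ["z"]), ("custom_domain::add", ["z"], ["w"])],
    ["w"],
)
print(json.dumps(graph, indent=2))
```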
At this point, ONNX export is complete, and you can enjoy the speedups brought by various inference engines.
From the above, you should now have an idea of how JIT models are generated with trace and how tracing affects ONNX export. To make a model better suited for deployment, we can consider optimizing it; a common optimization paradigm will be introduced in a following article. Stay tuned.
MMDeploy has added support for TorchScript models, which are also built with trace. Please visit the MMDeploy GitHub repository to try it out.