This tutorial is an introduction to TorchScript, an intermediate representation of a PyTorch model (a subclass of nn.Module) that can then be run in a high-performance environment such as C++.
In this tutorial, we will introduce:
- The basics of model authoring in PyTorch, including:
  - Modules
  - Defining forward functions
  - Composing modules into a hierarchy of modules
- Specific methods for converting PyTorch modules to TorchScript, our high-performance deployment runtime:
  - Tracing an existing module
  - Using scripting to directly compile a module
  - How to compose both approaches
- Saving and loading TorchScript modules
We hope that after you complete this tutorial, you will proceed to go through the follow-on tutorial, which will walk you through an example of actually calling a TorchScript model from C++.
import torch  # This is all the imports needed to use both PyTorch and TorchScript!
print(torch.__version__)

Output:

1.3.0
1. PyTorch model basics
Let’s start by defining a simple Module. A Module is the basic unit of composition in PyTorch. It contains:

- A constructor, which prepares the module for invocation
- A set of Parameters and sub-Modules. These are initialized by the constructor and can be used by the module during invocation.
- A forward function. This is the code that is run when the module is invoked.

Let’s take a quick look at a small example:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

Output:

(tensor([[0.5139, 0.6451, 0.3697, 0.7738],
        [0.7936, 0.5864, 0.8063, 0.9324],
        [0.6479, 0.8408, 0.8062, 0.7263]]), tensor([[0.5139, 0.6451, 0.3697, 0.7738],
        [0.7936, 0.5864, 0.8063, 0.9324],
        [0.6479, 0.8408, 0.8062, 0.7263]]))
So, we have:

- Created a class that subclasses torch.nn.Module.
- Defined a constructor. The constructor doesn’t do much, just calls the constructor for super.
- Defined a forward function, which takes two inputs and returns two outputs. The actual contents of the forward function are not really important, but it’s sort of a fake RNN cell. That is, it’s a function that is applied on a loop.

We instantiated the module and created x and h, which are just 3x4 matrices of random values. Then we invoked the cell with my_cell(x, h). This in turn calls our forward function.
Let’s do something more interesting:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

Output:

MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.3941, 0.4160, -0.1086, 0.8432],
        [ 0.5604, 0.4003, 0.5009, 0.6842],
        [ 0.7084, 0.7147, 0.1818, 0.8296]], grad_fn=<TanhBackward>), tensor([[ 0.3941, 0.4160, -0.1086, 0.8432],
        [ 0.5604, 0.4003, 0.5009, 0.6842],
        [ 0.7084, 0.7147, 0.1818, 0.8296]], grad_fn=<TanhBackward>))
We’ve redefined the module MyCell, but this time we’ve added a self.linear attribute, and we invoke self.linear in the forward function.

What is actually going on here? torch.nn.Linear is a Module from the PyTorch standard library. Just like MyCell, it can be invoked using the call syntax. We are building a hierarchy of Modules.

Calling print on a Module gives a visual representation of that module’s subclass hierarchy. In our example, we can see our Linear subclass and its parameters.

By composing Modules in this way, we can succinctly and readably author models with reusable components.
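As a quick illustration (a minimal sketch of my own, not from the original tutorial code), the hierarchy also exposes the parameters of every submodule, which is what optimizers ultimately consume:

# Walk the module hierarchy and list the learnable parameters it contains.
# Here the only parameters come from the nested Linear submodule.
for name, param in my_cell.named_parameters():
    print(name, tuple(param.shape))  # linear.weight (4, 4), linear.bias (4,)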
You may have noticed the grad_fn entries in the output. This is a detail of PyTorch’s method of automatic differentiation, called autograd. In short, this system allows us to compute derivatives through potentially complex programs. This design gives us a great deal of flexibility in model authoring.
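To make that concrete, here is a minimal sketch (my addition, assuming my_cell, x, and h from above are still in scope) of asking autograd for gradients through the cell:

loss = my_cell(x, h)[0].sum()            # reduce the cell's output to a scalar
loss.backward()                          # replay the recorded operations in reverse
print(my_cell.linear.weight.grad.shape)  # torch.Size([4, 4])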
Now, let’s examine its flexibility:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

Output:

MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[0.0850, 0.2812, 0.5188, 0.8523],
        [0.1233, 0.3948, 0.6615, 0.7466],
        [0.7072, 0.6103, 0.6953, 0.7047]], grad_fn=<TanhBackward>), tensor([[0.0850, 0.2812, 0.5188, 0.8523],
        [0.1233, 0.3948, 0.6615, 0.7466],
        [0.7072, 0.6103, 0.6953, 0.7047]], grad_fn=<TanhBackward>))
We’ve redefined our MyCell class once again, but here we’ve also defined MyDecisionGate. This module utilizes control flow. Control flow consists of things like loops and if-statements.

Many frameworks take the approach of computing symbolic derivatives given a full program representation. However, in PyTorch, we use a gradient tape. We record operations as they occur, and replay them backwards when computing derivatives. In this way, the framework does not have to explicitly define derivatives for every construct in the language.
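A small sketch of what that means in practice (my own illustration, not part of the original tutorial): because operations are recorded as they run, gradients flow through whichever branch actually executed:

inp = torch.randn(3, 4, requires_grad=True)
out = MyDecisionGate()(inp).sum()
out.backward()
print(inp.grad)  # all 1s or all -1s, depending on which branch the forward pass took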
2. TorchScript basics
Now let’s take our running example and see how we can apply TorchScript.

In short, TorchScript provides tools to capture the definition of your model, even in light of PyTorch’s flexible and dynamic nature. Let’s begin by examining what we call tracing.
2.1 Tracing modules
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
Output:
TracedModule[MyCell](
original_name=MyCell
(linear): TracedModule[Linear](original_name=Linear)
)Copy the code
We’ve rewound a bit and taken the second version of our MyCell class. As before, we’ve instantiated it, but this time we’ve called torch.jit.trace, passed in the Module, and passed in example inputs the network might see.

What exactly has this done? It has invoked the Module, recorded the operations that occurred when the Module was run, and created an instance of torch.jit.ScriptModule (of which TracedModule is an instance).
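A quick check (a sketch I’ve added, assuming the tracing snippet above has been run) that the traced object really is a ScriptModule:

print(isinstance(traced_cell, torch.jit.ScriptModule))  # True: tracing produced a ScriptModule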
TorchScript records its definition in an intermediate representation (or IR), often called a graph in deep learning. We can check the graph with the.graph attribute:
print(traced_cell.graph)

Output:

graph(%self : ClassType<MyCell>,
      %input : Float(3, 4),
      %h : Float(3, 4)):
  %1 : ClassType<Linear> = prim::GetAttr[name="linear"](%self)
  %weight : Tensor = prim::GetAttr[name="weight"](%1)
  %bias : Tensor = prim::GetAttr[name="bias"](%1)
  %6 : Float(4, 4) = aten::t(%weight), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %7 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %8 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %9 : Float(3, 4) = aten::addmm(%bias, %input, %6, %7, %8), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %10 : int = prim::Constant[value=1](), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %11 : Float(3, 4) = aten::add(%9, %h, %10), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %12 : Float(3, 4) = aten::tanh(%11), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %13 : (Float(3, 4), Float(3, 4)) = prim::TupleConstruct(%12, %12)
  return (%13)
However, this is a very low-level representation, and most of the information contained in the graph is not useful for end users. Instead, we can use the .code property to give a Python-syntax interpretation of the code:
print(traced_cell.code)

Output:
import __torch__
import __torch__.torch.nn.modules.linear
def forward(self,
input: Tensor,
h: Tensor) -> Tuple[Tensor, Tensor]:
_0 = self.linear
weight = _0.weight
bias = _0.bias
_1 = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)
_2 = torch.tanh(torch.add(_1, h, alpha=1))
return (_2, _2)
So why are we doing all this? There are several reasons for this:
- TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. That interpreter does not acquire the Global Interpreter Lock, so many requests can be processed on the same instance simultaneously.
- This format allows us to save the whole model to disk and load it into another environment, such as a server written in a language other than Python (see the sketch just after this list).
- TorchScript gives us a representation in which we can do compiler optimizations on the code to provide more efficient execution.
- TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.
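As a small illustration of the second point (a sketch I’ve added, with a hypothetical filename; the save/load API is covered properly in section 4 below):

traced_cell.save('my_cell.pt')           # write a self-contained archive to disk
reloaded = torch.jit.load('my_cell.pt')  # load it back; the Python class definition is not needed
print(reloaded(x, h))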
As we can see, calling traced_cell yields the same result as the Python module:
print(my_cell(x, h))
print(traced_cell(x, h))

Output:

(tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
        [-0.5033, 0.4471, 0.8264, 0.2135],
        [ 0.3430, 0.5561, 0.6794, 0.2273]], grad_fn=<TanhBackward>), tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
        [-0.5033, 0.4471, 0.8264, 0.2135],
        [ 0.3430, 0.5561, 0.6794, 0.2273]], grad_fn=<TanhBackward>))
(tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
        [-0.5033, 0.4471, 0.8264, 0.2135],
        [ 0.3430, 0.5561, 0.6794, 0.2273]], grad_fn=<DifferentiableGraphBackward>), tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
        [-0.5033, 0.4471, 0.8264, 0.2135],
        [ 0.3430, 0.5561, 0.6794, 0.2273]], grad_fn=<DifferentiableGraphBackward>))
3. Using scripting to convert modules

There’s a reason we used version two of our module, and not the version with the control-flow-laden submodule. Let’s examine that now:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)
Output:
import __torch__.___torch_mangle_0
import __torch__
import __torch__.torch.nn.modules.linear.___torch_mangle_1
def forward(self,
input: Tensor,
h: Tensor) -> Tuple[Tensor, Tensor]:
_0 = self.linear
weight = _0.weight
bias = _0.bias
x = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)
_1 = torch.tanh(torch.add(x, h, alpha=1))
return (_1, _1)
Looking at the .code output, we can see that the if-else branch is nowhere to be found! Why? Tracing does exactly what we said it would: run the code, record the operations that happen, and construct a ScriptModule that does exactly that. Unfortunately, things like control flow are erased.
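Before fixing this, a quick sketch (my addition, not from the original tutorial) of how the erased branch can bite: feed the traced cell an input chosen so that the eager module would likely take the other branch, and the two can silently disagree:

x_neg = torch.full((3, 4), -10.0)  # chosen to (very likely) drive the gate's sum() negative
print(my_cell(x_neg, h))           # eager module re-evaluates the if-statement
print(traced_cell(x_neg, h))       # traced module replays only the branch recorded during tracing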
How can we faithfully represent this module in TorchScript? We provide a script compiler, which does direct analysis of your Python source code to transform it into TorchScript. Let’s convert MyDecisionGate using the script compiler:
scripted_gate = torch.jit.script(MyDecisionGate())
my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)
Output:
import __torch__.___torch_mangle_3
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_4
def forward(self,
x: Tensor,
h: Tensor) -> Tuple[Tensor, Tensor]:
_0 = self.linear
_1 = _0.weight
_2 = _0.bias
if torch.eq(torch.dim(x), 2):
_3 = torch.__isnot__(_2, None)
else:
_3 = False
if _3:
bias = ops.prim.unchecked_unwrap_optional(_2)
ret = torch.addmm(bias, x, torch.t(_1), beta=1, alpha=1)
else:
output = torch.matmul(x, torch.t(_1))
if torch.__isnot__(_2, None):
bias0 = ops.prim.unchecked_unwrap_optional(_2)
output0 = torch.add_(output, bias0, alpha=1)
else:
output0 = output
ret = output0
_4 = torch.gt(torch.sum(ret, dtype=None), 0)
if bool(_4):
_5 = ret
else:
_5 = torch.neg(ret)
new_h = torch.tanh(torch.add(_5, h, alpha=1))
return (new_h, new_h)
We have now faithfully captured the behavior of our program in TorchScript. Now, let’s try running the program:
# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell(x, h)
3.1 Mixing Scripting and Tracing
Some situations call for using tracing rather than scripting (e.g. a module has many architectural decisions that are made based on constant Python values, and we would like them not to appear in TorchScript). In this case, scripting can be composed with tracing: torch.jit.script will inline the code of a traced module, and tracing will inline the code of a scripted module.
- An example of the first case:
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
Output:
import __torch__
import __torch__.___torch_mangle_5
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_6
def forward(self,
xs: Tensor) -> Tuple[Tensor, Tensor]:
h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
y = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
y0 = y
h0 = h
for i in range(torch.size(xs, 0)):
_0 = self.cell
_1 = torch.select(xs, 0, i)
_2 = _0.linear
weight = _2.weight
bias = _2.bias
_3 = torch.addmm(bias, _1, torch.t(weight), beta=1, alpha=1)
_4 = torch.gt(torch.sum(_3, dtype=None), 0)
if bool(_4):
_5 = _3
else:
_5 = torch.neg(_3)
_6 = torch.tanh(torch.add(_5, h0, alpha=1))
y0, h0 = _6, _6
return (y0, h0)
- An example of the second case:
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
Output:
import __torch__
import __torch__.___torch_mangle_9
import __torch__.___torch_mangle_7
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_8
def forward(self,
argument_1: Tensor) -> Tensor:
_0 = self.loop
h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
h0 = h
for i in range(torch.size(argument_1, 0)):
_1 = _0.cell
_2 = torch.select(argument_1, 0, i)
_3 = _1.linear
weight = _3.weight
bias = _3.bias
_4 = torch.addmm(bias, _2, torch.t(weight), beta=1, alpha=1)
_5 = torch.gt(torch.sum(_4, dtype=None), 0)
if bool(_5):
_6 = _4
else:
_6 = torch.neg(_4)
h0 = torch.tanh(torch.add(_6, h0, alpha=1))
return torch.relu(h0)
This way, scripting and tracing can each be used when the situation calls for it, and they can be used together.
4. Saving and loading models
We provide APIs to save and load TorchScript modules to/from disk in an archive format. This format includes code, parameters, attributes, and debug information, meaning that the archive is a freestanding representation of the model that can be loaded in an entirely separate process. Let’s save and load our wrapped RNN module:
traced.save('wrapped_rnn.zip')
loaded = torch.jit.load('wrapped_rnn.zip')
print(loaded)
print(loaded.code)
Output:
ScriptModule(
original_name=WrapRNN
(loop): ScriptModule(
original_name=MyRNNLoop
(cell): ScriptModule(
original_name=MyCell
(dg): ScriptModule(original_name=MyDecisionGate)
(linear): ScriptModule(original_name=Linear)
)
)
)
import __torch__
import __torch__.___torch_mangle_9
import __torch__.___torch_mangle_7
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_8
def forward(self,
argument_1: Tensor) -> Tensor:
_0 = self.loop
h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
h0 = h
for i in range(torch.size(argument_1, 0)):
_1 = _0.cell
_2 = torch.select(argument_1, 0, i)
_3 = _1.linear
weight = _3.weight
bias = _3.bias
_4 = torch.addmm(bias, _2, torch.t(weight), beta=1, alpha=1)
_5 = torch.gt(torch.sum(_4, dtype=None), 0)
if bool(_5):
_6 = _4
else:
_6 = torch.neg(_4)
h0 = torch.tanh(torch.add(_6, h0, alpha=1))
return torch.relu(h0)
As you can see, serialization preserves the module hierarchy and the code we’ve been examining throughout. Importantly, the model can also be loaded into C++ for execution without any dependence on Python.
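As a final sanity check (a small sketch I’ve added, assuming the save/load snippet above has been run), the loaded module behaves just like the original:

xs = torch.rand(10, 3, 4)
print(loaded(xs).shape)  # torch.Size([3, 4]), the same as traced(xs).shape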
Further reading
We’ve completed our tutorial! For a more involved demonstration, check out the NeurIPS demo for converting machine translation models using TorchScript: https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ
Total running time of the script: (0 minutes 0.247 seconds)