This tutorial is an introduction to TorchScript, an intermediate representation of a PyTorch model (a subclass of nn.Module) that can then be run in a high-performance environment such as C++.

In this tutorial, we will introduce:

  1. The basics of model authoring in PyTorch, including:
    • Modules
    • Defining forward functions
    • Composing modules into a hierarchy of modules
  2. Specific methods for converting PyTorch modules to TorchScript, our high-performance deployment runtime:
    • Tracing an existing module
    • Using scripting to directly compile a module
    • How to compose both approaches
    • Saving and loading TorchScript modules

We hope that after you complete this tutorial, you will proceed to the follow-up tutorial, which walks you through an example of actually calling a TorchScript model from C++.

import torch  # This is all you need to import to use both PyTorch and TorchScript!
print(torch.__version__)

  • The output
1.3.0

1. PyTorch model basics

Let’s start by defining a simple Module. A Module is the basic unit of composition in PyTorch. It contains:

  • A constructor, which prepares the module for invocation
  • A set of Parameters and sub-Modules. These are initialized by the constructor and can be used by the module during invocation.
  • A forward function. This is the code that is run when the module is invoked.

Let’s look at a small example:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

  • The output
(tensor([[0.5139, 0.6451, 0.3697, 0.7738],
        [0.7936, 0.5864, 0.8063, 0.9324],
        [0.6479, 0.8408, 0.8062, 0.7263]]), tensor([[0.5139, 0.6451, 0.3697, 0.7738],
        [0.7936, 0.5864, 0.8063, 0.9324],
        [0.6479, 0.8408, 0.8062, 0.7263]]))

So, we have:

  1. Created a class that subclasses torch.nn.Module.
  2. Defined a constructor. The constructor doesn’t do much, other than call the constructor of super.
  3. Defined a forward function, which takes two inputs and returns two outputs. The actual contents of the forward function are not really important, but it is a sort of fake RNN cell; that is, it is a function that is applied in a loop.

We instantiated the module and made x and h, which are just 3×4 matrices of random values. Then we invoked the cell with my_cell(x, h). This in turn calls our forward function.
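To make that call path concrete, here is a minimal check, assuming a standalone copy of the first MyCell (its no-op constructor is omitted, since nn.Module already supplies one). It shows that calling the instance and calling forward directly give the same result, and that the result is just tanh(x + h).

```python
import torch

class MyCell(torch.nn.Module):
    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h

torch.manual_seed(0)
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)

# Calling the module goes through nn.Module.__call__, which dispatches to forward()
out_call = my_cell(x, h)
out_forward = my_cell.forward(x, h)

print(torch.equal(out_call[0], out_forward[0]))  # True
print(out_call[0].shape)                          # torch.Size([3, 4])
```

Prefer my_cell(x, h) over my_cell.forward(x, h) in real code: __call__ also runs any registered hooks, which a direct forward call skips.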

Let’s do something more interesting:

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

  • The output
MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[0.3941, 0.4160, -0.1086, 0.8432],
        [0.5604, 0.4003, 0.5009, 0.6842],
        [0.7084, 0.7147, 0.1818, 0.8296]], grad_fn=<TanhBackward>), tensor([[0.3941, 0.4160, -0.1086, 0.8432],
        [0.5604, 0.4003, 0.5009, 0.6842],
        [0.7084, 0.7147, 0.1818, 0.8296]], grad_fn=<TanhBackward>))

We have redefined our module MyCell, but this time we have added a self.linear attribute, and we invoke self.linear in the forward function.

What exactly is happening here? torch.nn.Linear is a Module from the PyTorch standard library. Just like MyCell, it can be invoked using the call syntax. We are building a hierarchy of Modules.

Calling print on a Module gives a visual representation of that Module’s submodule hierarchy. In our example, we can see our Linear submodule and its parameters.

By composing Modules in this way, we can author models out of reusable components in a succinct and readable way.
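You can also inspect the hierarchy from code instead of reading the printed summary. This sketch calls the standard nn.Module methods named_children() and named_parameters() on a self-contained copy of the second MyCell. Note that each parameter name is prefixed with the attribute name of the submodule that owns it.

```python
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()

# Direct submodules, registered by assigning a Module to self.<name>
children = [name for name, _ in my_cell.named_children()]
print(children)  # ['linear']

# Parameters are collected recursively, with dotted names
params = {name: tuple(p.shape) for name, p in my_cell.named_parameters()}
print(params)  # {'linear.weight': (4, 4), 'linear.bias': (4,)}
```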

You may have noticed grad_fn in the output. This is a detail of PyTorch’s method of automatic differentiation, called autograd. In short, this system allows us to compute derivatives through potentially complex programs. The design allows for a great deal of flexibility in model authoring.
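Here is a small illustration of autograd. It is a sketch and not part of the original tutorial: it runs the same computation as MyCell.forward with a standalone Linear layer, then backpropagates a scalar through it. The exact grad_fn class name varies between PyTorch versions (TanhBackward in older releases, TanhBackward0 in newer ones), so only the prefix should be relied on.

```python
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(4, 4)
x, h = torch.rand(3, 4), torch.rand(3, 4)

# The forward pass records each op; grad_fn points at the last one (tanh)
new_h = torch.tanh(linear(x) + h)
print(type(new_h.grad_fn).__name__)  # TanhBackward or TanhBackward0, by version

# Reduce to a scalar, then replay the recorded ops backwards
new_h.sum().backward()
print(linear.weight.grad.shape)  # torch.Size([4, 4])
print(linear.bias.grad.shape)    # torch.Size([4])
```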

Now, let’s examine that flexibility:

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

  • The output
MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[0.0850, 0.2812, 0.5188, 0.8523],
        [0.1233, 0.3948, 0.6615, 0.7466],
        [0.7072, 0.6103, 0.6953, 0.7047]], grad_fn=<TanhBackward>), tensor([[0.0850, 0.2812, 0.5188, 0.8523],
        [0.1233, 0.3948, 0.6615, 0.7466],
        [0.7072, 0.6103, 0.6953, 0.7047]], grad_fn=<TanhBackward>))

We’ve redefined our MyCell class again, but this time we’ve also defined MyDecisionGate. This module uses control flow. Control flow consists of things like loops and if statements.

Many frameworks take the approach of computing symbolic derivatives given a full program representation. In PyTorch, however, we use a gradient tape. We record operations as they occur, and replay them backwards when computing derivatives. This way, the framework does not have to explicitly define derivatives for every construct in the language.
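You can observe the gradient tape directly with MyDecisionGate. Only the branch that actually runs is recorded, so the gradient flips sign along with the input. The input values below are arbitrary choices for illustration.

```python
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

gate = MyDecisionGate()

# Positive sum: the "return x" branch runs, so d(sum)/dx is +1 everywhere
x_pos = torch.tensor([1.0, 2.0], requires_grad=True)
gate(x_pos).sum().backward()
print(x_pos.grad)  # tensor([1., 1.])

# Negative sum: the "return -x" branch runs, so the tape holds a negation
x_neg = torch.tensor([-1.0, -2.0], requires_grad=True)
gate(x_neg).sum().backward()
print(x_neg.grad)  # tensor([-1., -1.])
```

No derivative rule for Python's if statement is needed here. The tape only ever sees the tensor operations that were actually executed.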

2. TorchScript basics

Now let’s take our running example and see how TorchScript can be applied to it.

In short, TorchScript provides tools to capture the definition of your model, even in light of PyTorch’s flexible and dynamic nature. Let’s begin by examining what we call tracing.

2.1 Tracing modules

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

  • The output
TracedModule[MyCell](
  original_name=MyCell
  (linear): TracedModule[Linear](original_name=Linear)
)

We’ve rewound a bit and taken the second version of our MyCell class. As before, we instantiate it, but this time we call torch.jit.trace, passing in the Module along with example inputs that the network might see.

What exactly has this done? It has invoked the Module, recorded the operations that occurred while the Module was running, and created an instance of torch.jit.ScriptModule (of which TracedModule is an instance).
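One way to confirm both claims is to check the type of the traced object, then compare it against the eager module on fresh inputs of the same shape. This minimal sketch uses a self-contained copy of the second MyCell and checks against torch.jit.ScriptModule with isinstance.

```python
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

torch.manual_seed(0)
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)

# trace() runs forward once on the example inputs and records the ops it sees
traced_cell = torch.jit.trace(my_cell, (x, h))
print(isinstance(traced_cell, torch.jit.ScriptModule))  # True

# On fresh inputs of the same shape, the recording reproduces the eager result
x2, h2 = torch.rand(3, 4), torch.rand(3, 4)
print(torch.allclose(my_cell(x2, h2)[0], traced_cell(x2, h2)[0]))  # True
```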

TorchScript records its definitions in an Intermediate Representation (or IR), commonly referred to in deep learning as a graph. We can examine the graph with the .graph property:

print(traced_cell.graph)

  • The output
graph(%self : ClassType<MyCell>,
      %input : Float(3, 4),
      %h : Float(3, 4)):
  %1 : ClassType<Linear> = prim::GetAttr[name="linear"](%self)
  %weight : Tensor = prim::GetAttr[name="weight"](%1)
  %bias : Tensor = prim::GetAttr[name="bias"](%1)
  %6 : Float(4, 4) = aten::t(%weight), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %7 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %8 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %9 : Float(3, 4) = aten::addmm(%bias, %input, %6, %7, %8), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %10 : int = prim::Constant[value=1](), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %11 : Float(3, 4) = aten::add(%9, %h, %10), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %12 : Float(3, 4) = aten::tanh(%11), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %13 : (Float(3, 4), Float(3, 4)) = prim::TupleConstruct(%12, %12)
  return (%13)

However, this is a very low-level representation, and most of the information in the graph is not useful to end users. Instead, we can use the .code property to get a Python-syntax interpretation of the code:

print(traced_cell.code)

  • The output
import __torch__
import __torch__.torch.nn.modules.linear
def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.linear
  weight = _0.weight
  bias = _0.bias
  _1 = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)
  _2 = torch.tanh(torch.add(_1, h, alpha=1))
  return (_2, _2)
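Both representations can also be inspected from code rather than by eye, for example in a unit test. This sketch uses a self-contained copy of the second MyCell and lists the IR operations with Graph.nodes() and Node.kind(). The exact set of node kinds varies across PyTorch versions (newer releases may show aten::linear where this output shows aten::addmm), so the check only looks for aten::tanh.

```python
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

traced_cell = torch.jit.trace(MyCell(), (torch.rand(3, 4), torch.rand(3, 4)))

# Each node in the IR graph has a kind such as "aten::tanh" or "prim::GetAttr"
kinds = [node.kind() for node in traced_cell.graph.nodes()]
print("aten::tanh" in kinds)  # True

# .code is a plain Python string, so it can be searched or logged
print("torch.tanh" in traced_cell.code)  # True
```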

So why did we do all this? There are several reasons:

  1. TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, so many requests can be processed on the same instance simultaneously.
  2. This format allows us to save the whole model to disk and load it into another environment, such as a server written in a language other than Python.
  3. TorchScript gives us a representation on which we can perform compiler optimizations, for more efficient execution.
  4. TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.

As we can see, calling traced_cell yields the same result as the Python module:

print(my_cell(x, h))
print(traced_cell(x, h))

  • The output
(tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
        [-0.5033, 0.4471, 0.8264, 0.2135],
        [0.3430, 0.5561, 0.6794, 0.2273]], grad_fn=<TanhBackward>), tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
        [-0.5033, 0.4471, 0.8264, 0.2135],
        [0.3430, 0.5561, 0.6794, 0.2273]], grad_fn=<TanhBackward>))
(tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
        [-0.5033, 0.4471, 0.8264, 0.2135],
        [0.3430, 0.5561, 0.6794, 0.2273]], grad_fn=<DifferentiableGraphBackward>), tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
        [-0.5033, 0.4471, 0.8264, 0.2135],
        [0.3430, 0.5561, 0.6794, 0.2273]], grad_fn=<DifferentiableGraphBackward>))

3. Using scripting to convert modules

There is a reason we used version two of our module, and not the one with the control-flow-laden submodule. Let’s examine that now:

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)

  • The output
import __torch__.___torch_mangle_0
import __torch__
import __torch__.torch.nn.modules.linear.___torch_mangle_1
def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.linear
  weight = _0.weight
  bias = _0.bias
  x = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)
  _1 = torch.tanh(torch.add(x, h, alpha=1))
  return (_1, _1)

Looking at the .code output, we can see that the if-else branch is nowhere to be found! Why? Tracing does exactly what we said it would: it runs the code, records the operations that happen, and constructs a ScriptModule that does exactly that. Unfortunately, things like control flow are erased.
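The erased branch is easy to demonstrate. In this sketch (an illustration, not from the original tutorial), MyDecisionGate is traced with a negative-sum input, so only the negation is recorded. Replaying the trace on a positive input then disagrees with eager mode. Expect PyTorch to emit a TracerWarning during tracing about converting a tensor to a Python boolean. That warning is exactly the signal for this problem.

```python
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

gate = MyDecisionGate()
neg, pos = -torch.ones(2, 2), torch.ones(2, 2)

# Tracing with a negative-sum input records only the "return -x" branch
traced_gate = torch.jit.trace(gate, (neg,))

eager_out = gate(pos)          # eager mode re-evaluates the if: returns pos unchanged
traced_out = traced_gate(pos)  # the trace replays the recorded negation
print(eager_out.sum().item(), traced_out.sum().item())  # 4.0 -4.0
```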

How can we faithfully represent this module in TorchScript? We provide a script compiler, which directly analyzes your Python source code and transforms it into TorchScript. Let’s convert MyDecisionGate using the script compiler:

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)

  • The output
import __torch__.___torch_mangle_3
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_4
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.linear
  _1 = _0.weight
  _2 = _0.bias
  if torch.eq(torch.dim(x), 2):
    _3 = torch.__isnot__(_2, None)
  else:
    _3 = False
  if _3:
    bias = ops.prim.unchecked_unwrap_optional(_2)
    ret = torch.addmm(bias, x, torch.t(_1), beta=1, alpha=1)
  else:
    output = torch.matmul(x, torch.t(_1))
    if torch.__isnot__(_2, None):
      bias0 = ops.prim.unchecked_unwrap_optional(_2)
      output0 = torch.add_(output, bias0, alpha=1)
    else:
      output0 = output
    ret = output0
  _4 = torch.gt(torch.sum(ret, dtype=None), 0)
  if bool(_4):
    _5 = ret
  else:
    _5 = torch.neg(ret)
  new_h = torch.tanh(torch.add(_5, h, alpha=1))
  return (new_h, new_h)

We have now faithfully captured the behavior of our program in TorchScript. Now, let’s try running the program:

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell(x, h)
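To confirm that the if-else really survived scripting, you can call the scripted gate with inputs of both signs. This minimal sketch uses a standalone copy of MyDecisionGate, and the input values are arbitrary.

```python
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

scripted_gate = torch.jit.script(MyDecisionGate())

pos, neg = torch.ones(2, 2), -torch.ones(2, 2)

# The compiled code keeps the if/else, so the branch is chosen on every call
print(torch.equal(scripted_gate(pos), pos))   # True
print(torch.equal(scripted_gate(neg), -neg))  # True
print("if" in scripted_gate.code)             # True
```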

3.1 Mixed Scripting and Tracing

Some situations call for tracing rather than scripting (for example, a module makes many architectural decisions based on constant Python values that we would like not to appear in TorchScript). In such cases, scripting can be composed with tracing: torch.jit.script will inline the code of a traced module, and tracing will inline the code of a scripted module.

  • An example of the first case:
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

  • The output
import __torch__
import __torch__.___torch_mangle_5
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_6
def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    _0 = self.cell
    _1 = torch.select(xs, 0, i)
    _2 = _0.linear
    weight = _2.weight
    bias = _2.bias
    _3 = torch.addmm(bias, _1, torch.t(weight), beta=1, alpha=1)
    _4 = torch.gt(torch.sum(_3, dtype=None), 0)
    if bool(_4):
      _5 = _3
    else:
      _5 = torch.neg(_3)
    _6 = torch.tanh(torch.add(_5, h0, alpha=1))
    y0, h0 = _6, _6
  return (y0, h0)

  • An example of the second case:
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4),))  # tuple of example inputs
print(traced.code)

  • The output
import __torch__
import __torch__.___torch_mangle_9
import __torch__.___torch_mangle_7
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_8
def forward(self,
    argument_1: Tensor) -> Tensor:
  _0 = self.loop
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  h0 = h
  for i in range(torch.size(argument_1, 0)):
    _1 = _0.cell
    _2 = torch.select(argument_1, 0, i)
    _3 = _1.linear
    weight = _3.weight
    bias = _3.bias
    _4 = torch.addmm(bias, _2, torch.t(weight), beta=1, alpha=1)
    _5 = torch.gt(torch.sum(_4, dtype=None), 0)
    if bool(_5):
      _6 = _4
    else:
      _6 = torch.neg(_4)
    h0 = torch.tanh(torch.add(_6, h0, alpha=1))
  return torch.relu(h0)

This way, scripting and tracing can each be used where the situation calls for it, and they can also be used together.
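The same composition works for free functions as well as modules. In this sketch, the names cell_step and run_loop are hypothetical. A straight-line step function is traced, then called from a scripted function whose loop length depends on the input, mirroring MyRNNLoop above. The final check compares the result against a plain Python loop.

```python
import torch

def cell_step(x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    return torch.tanh(x + h)

# Trace the straight-line step once, using example inputs
traced_step = torch.jit.trace(cell_step, (torch.rand(3, 4), torch.rand(3, 4)))

# Script the loop so that its length is decided at run time by xs.size(0)
@torch.jit.script
def run_loop(xs: torch.Tensor) -> torch.Tensor:
    h = torch.zeros(3, 4)
    for i in range(xs.size(0)):
        h = traced_step(xs[i], h)
    return h

torch.manual_seed(0)
xs = torch.rand(5, 3, 4)

# Reference result from an ordinary eager-mode loop
h_ref = torch.zeros(3, 4)
for i in range(xs.size(0)):
    h_ref = cell_step(xs[i], h_ref)

print(torch.allclose(run_loop(xs), h_ref))  # True
```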

4. Saving and loading models

We provide APIs to save TorchScript modules to disk and load them from disk in an archive format. This format includes code, parameters, attributes, and debug information, which means the archive is a freestanding representation of the model that can be loaded in an entirely separate process. Let’s save and load our wrapped RNN module:

traced.save('wrapped_rnn.zip')

loaded = torch.jit.load('wrapped_rnn.zip')

print(loaded)
print(loaded.code)

  • The output
ScriptModule(
  original_name=WrapRNN
  (loop): ScriptModule(
    original_name=MyRNNLoop
    (cell): ScriptModule(
      original_name=MyCell
      (dg): ScriptModule(original_name=MyDecisionGate)
      (linear): ScriptModule(original_name=Linear)
    )
  )
)
import __torch__
import __torch__.___torch_mangle_9
import __torch__.___torch_mangle_7
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_8
def forward(self,
    argument_1: Tensor) -> Tensor:
  _0 = self.loop
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  h0 = h
  for i in range(torch.size(argument_1, 0)):
    _1 = _0.cell
    _2 = torch.select(argument_1, 0, i)
    _3 = _1.linear
    weight = _3.weight
    bias = _3.bias
    _4 = torch.addmm(bias, _2, torch.t(weight), beta=1, alpha=1)
    _5 = torch.gt(torch.sum(_4, dtype=None), 0)
    if bool(_5):
      _6 = _4
    else:
      _6 = torch.neg(_4)
    h0 = torch.tanh(torch.add(_6, h0, alpha=1))
  return torch.relu(h0)

As you can see, serialization preserves the module hierarchy and the code we’ve been examining throughout. The model can also be loaded, for example, into C++ for Python-free execution.
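torch.jit.save and torch.jit.load also accept file-like objects. This is convenient for tests, or for shipping a model without touching disk. The sketch below uses a self-contained copy of the second MyCell and round-trips a scripted module through an in-memory io.BytesIO buffer. It then checks that the loaded copy computes the same output.

```python
import io
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

torch.manual_seed(0)
x, h = torch.rand(3, 4), torch.rand(3, 4)
scripted = torch.jit.script(MyCell())

# Serialize into an in-memory buffer instead of a file path
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)

# Rewind before reading, then deserialize into a new ScriptModule
buffer.seek(0)
loaded = torch.jit.load(buffer)

print(isinstance(loaded, torch.jit.ScriptModule))          # True
print(torch.allclose(scripted(x, h)[0], loaded(x, h)[0]))  # True
```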

Further reading

We’ve completed our tutorial! For a more involved demonstration, check out the NeurIPS demo on converting machine translation models with TorchScript: https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ

