This post discusses how to implement an LSTM network with Pytorch, and also examines the code organization and architecture design of the Pytorch library itself.

LSTM

LSTM is a recurrent neural network suited to modeling sequential inputs. This article by Chris Olah explains in detail how an LSTM unit works and is recommended reading.

Two ideas

Gate: a gate that controls how information flows

Cell: memory pool

Comparison with ordinary RNN



A normal RNN has only one self-updating hidden state unit.



LSTM adds a memory pool (the Cell), updates the information in it in a controlled way through several gates, and derives the hidden state from the information in the memory pool.
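
For reference, here are the standard LSTM update equations in LaTeX notation, matching (up to how the biases are split) the LSTMCell() code examined later ($x_t$ is the input, $h_t$ the hidden state, $c_t$ the cell, $\sigma$ the sigmoid, $\odot$ the elementwise product):

$$
\begin{aligned}
f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\
o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\
g_t &= \tanh(W_g [x_t, h_{t-1}] + b_g) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$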

From Scratch

Here is the code for implementing an LSTM manually, by subclassing the base class nn.Module.

import torch.nn as nn
import torch
from torch.autograd import Variable

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, cell_size, output_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.cell_size = cell_size
        # Each gate gets its own linear layer; sharing a single layer would make
        # the forget/input/output gates compute identical values.
        self.forget_gate = nn.Linear(input_size + hidden_size, cell_size)
        self.input_gate = nn.Linear(input_size + hidden_size, cell_size)
        self.output_gate = nn.Linear(input_size + hidden_size, cell_size)
        self.cell_gate = nn.Linear(input_size + hidden_size, cell_size)
        self.output = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden, cell):
        combined = torch.cat((input, hidden), 1)
        f_gate = self.sigmoid(self.forget_gate(combined))   # what to keep in the cell
        i_gate = self.sigmoid(self.input_gate(combined))    # what to write into the cell
        o_gate = self.sigmoid(self.output_gate(combined))   # what to expose as hidden state
        cell_helper = self.tanh(self.cell_gate(combined))   # candidate cell update
        cell = torch.add(torch.mul(cell, f_gate), torch.mul(cell_helper, i_gate))
        hidden = torch.mul(self.tanh(cell), o_gate)
        output = self.softmax(self.output(hidden))
        return output, hidden, cell

    def initHidden(self):
        return Variable(torch.zeros(1, self.hidden_size))

    def initCell(self):
        return Variable(torch.zeros(1, self.cell_size))

A few key points (a short usage sketch follows the list):

  1. Keeping the Tensor sizes consistent: as written, `cell_size` must equal `hidden_size`, since the hidden state is derived from the cell and then fed into the output layer.
  2. The order in which information is passed between the input, the hidden state and the cell.
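
A minimal usage sketch, assuming the class above and hypothetical sizes (with `cell_size` equal to `hidden_size`):

import torch
from torch.autograd import Variable

lstm = LSTM(input_size=10, hidden_size=20, cell_size=20, output_size=5)
hidden, cell = lstm.initHidden(), lstm.initCell()

sequence = Variable(torch.randn(8, 1, 10))      # 8 time steps, batch of 1, 10 features
for t in range(sequence.size(0)):               # feed the steps to the unit in order
    output, hidden, cell = lstm(sequence[t], hidden, cell)

print(output.size())                            # (1, 5): log-probabilities after the last step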

Pytorch Module

The Pytorch library itself encapsulates far more functionality in its LSTM implementation, and the way its classes and functions are organized is instructive (a quick usage example of the built-in `nn.LSTM` follows the list below). My understanding of the implementation rests on the following two points:

  1. The Cell, Layer and Stack levels are decoupled, and each level abstracts away part of the parameters (structure).
  2. Function handle passing: each level processes its arguments and returns a `forward()` function handle.
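
Before diving into the source, here is a typical use of the built-in module, with hypothetical sizes (default `(seq_len, batch, feature)` layout):

import torch
import torch.nn as nn
from torch.autograd import Variable

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

input = Variable(torch.randn(5, 3, 10))       # (seq_len, batch, input_size)
h0 = Variable(torch.zeros(2, 3, 20))          # (num_layers, batch, hidden_size)
c0 = Variable(torch.zeros(2, 3, 20))

output, (hn, cn) = lstm(input, (h0, c0))
print(output.size())                          # (5, 3, 20): top-layer hidden state at every step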

Let’s start with the code; see GitHub for the full source.

LSTM class

File: nn/modules/rnn.py

# nn/modules/rnn.py (simplified)
class RNNBase(Module):
    def __init__(self, mode, input_size, hidden_size, num_layers=1):
        pass

    def forward(self, input, hx=None):
        if hx is None:
            hx = torch.autograd.Variable()  # simplified: really a zero state of the right size
        if self.mode == 'LSTM':
            hx = (hx, hx)                   # an LSTM carries a (hidden, cell) pair
        func = self._backend.RNN()          # !!!
        output, hidden = func(input, self.all_weights, hx)  # !!!
        return output, hidden

class LSTM(RNNBase):
    def __init__(self, *args, **kwargs):
        super(LSTM, self).__init__('LSTM', *args, **kwargs)
  1. The `LSTM` class is just a thin wrapper around the `RNNBase` class, passing `'LSTM'` as the mode.
  2. In the base class `nn.Module`, `__call__()` is defined to call `forward()`, so the real functionality lives in `_backend.RNN()` (a toy illustration of this dispatch follows the list).
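
A toy illustration of the `__call__` -> `forward` dispatch (not the real `nn.Module` code, just the idea):

class ToyModule(object):
    # Minimal stand-in for nn.Module: calling the object runs forward().
    def __call__(self, *args, **kwargs):
        # The real nn.Module.__call__ also runs hooks; only the dispatch is kept here.
        return self.forward(*args, **kwargs)

class ToyLSTM(ToyModule):
    def forward(self, input, hx=None):
        # In the real LSTM, this is where _backend.RNN() builds and calls
        # the actual forward function.
        return input, hx

m = ToyLSTM()
output, hidden = m("some input")   # __call__ dispatches to forward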

AutogradRNN function

Next, look for `_backend.RNN`. File: nn/backends/thnn.py

# nn/backends/thnn.py (simplified)
def _initialize_backend():
    from .._functions.rnn import RNN, LSTMCell

So `_backend` turns out to be just another level of indirection: an index from names to the functions that do the actual work (a rough sketch of this pattern follows).
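
A rough sketch of the idea with made-up names (not the real backend code): the backend object simply maps a name like `'RNN'` to the function defined in `_functions/rnn.py`.

class ToyBackend(object):
    # A tiny registry: attribute access is looked up in a dict of functions.
    def __init__(self):
        self.function_classes = {}

    def register_function(self, name, fn):
        self.function_classes[name] = fn

    def __getattr__(self, name):
        fn = self.function_classes.get(name)
        if fn is None:
            raise AttributeError(name)
        return fn

def RNN(*args, **kwargs):          # stand-in for the RNN function in _functions/rnn.py
    pass

backend = ToyBackend()
backend.register_function('RNN', RNN)
func = backend.RNN                 # the same kind of lookup that self._backend.RNN performs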

Finally we find the RNN() function. File: nn/_functions/rnn.py

# nn/_functions/rnn.py (simplified)
def RNN(*args, **kwargs):
    def forward(input, *fargs, **fkwargs):
        func = AutogradRNN(*args, **kwargs)
        return func(input, *fargs, **fkwargs)
    return forward

def AutogradRNN(mode, input_size, hidden_size, num_layers=1):
    cell = LSTMCell
    rec_factory = Recurrent
    layer = (rec_factory(cell),)
    func = StackedRNN(layer, num_layers)
    def forward(input, weight, hidden):
        nexth, output = func(input, hidden, weight)
        return output, nexth
    return forward
  1. `RNN()` is a decorator: depending on whether the `cudnn` library is available, it decides to call `AutogradRNN()` or `CudnnRNN()`; here we only look at `AutogradRNN()` (a toy example of this "return a forward handle" pattern follows the list).
  2. `AutogradRNN()` picks the `LSTMCell`, uses the `Recurrent()` function to wrap the Cell into a Layer, and then passes the Layer into the `StackedRNN()` function.
  3. What `RNN()` and `AutogradRNN()` return is a handle to a `forward()` function.
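
The "process the arguments, return a `forward` handle" pattern in isolation (a toy example, not PyTorch code):

def make_scaler(factor):
    # Capture the configuration (the "weights" of this toy) now ...
    def forward(x):
        # ... and process the actual data later, every time the handle is called.
        return x * factor
    return forward

double = make_scaler(2)   # configure once, like RNN(mode, sizes, ...)
print(double(10))         # 20: call the returned forward handle with data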

Here is the Recurrent() function:

def Recurrent(inner):
    def forward(input, hidden, weight):
        output = []
        steps = range(input.size(0))            # iterate over the time steps, in order
        for i in steps:
            hidden = inner(input[i], hidden, *weight)
            output.append(hidden[0])            # for LSTM, hidden is a (h, c) pair
        return hidden, output
    return forward
  1. The `Recurrent()` function implements the "recurrence": it applies the Cell once per time step of the input, iterating the hidden state and threading through the parameters.
  2. In other words, `Recurrent()` composes the Cell (`inner`) into a Layer (a toy usage sketch follows).
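
To see this in isolation, here is a toy "cell" driven by the `Recurrent()` definition above (assuming `import torch`; the cell has no weights, it just shows the iteration):

import torch

def toy_cell(x, hidden):
    # Mimics the (h, c) pair an LSTM cell returns.
    h, c = hidden
    return (h + x, c + x)

layer = Recurrent(toy_cell)                     # compose the cell into a layer
inputs = torch.arange(4).float().view(4, 1)     # four "time steps"
hidden, outputs = layer(inputs, (torch.zeros(1), torch.zeros(1)), ())

# hidden is the final (h, c); outputs collects hidden[0] at every step.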

StackedRNN() function

def StackedRNN(inners, num_layers):
    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(input, hidden, weight):
        next_hidden = []
        hidden = list(zip(*hidden))              # pair up (h, c) per layer
        for i in range(num_layers):
            all_output = []
            for j, inner in enumerate(inners):
                l = i * num_directions + j       # index of this layer/direction
                hy, output = inner(input, hidden[l], weight[l])
                next_hidden.append(hy)
                all_output.append(output)
            input = torch.cat(all_output, input.dim() - 1)
        next_h, next_c = zip(*next_hidden)
        next_hidden = (torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
                       torch.cat(next_c, 0).view(total_layers, *next_c[0].size()))
        return next_hidden, input
    return forward
  1. The `StackedRNN()` function stacks the Layers (`inners`): the output of one layer becomes the input of the next, and all per-layer hidden states are collected (a toy sketch of the stacking step follows).
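
The stacking step on its own, with made-up toy layers (independent of the real code, just to show the data flow):

import torch

def make_toy_layer(scale):
    # A "layer" maps (input, hidden) to (new_hidden, output).
    def forward(input, hidden):
        output = input * scale
        return hidden + 1, output
    return forward

layers = [make_toy_layer(2.0), make_toy_layer(3.0)]    # two stacked layers

x = torch.ones(4, 1)                                   # pretend: output of the time loop
hiddens = [torch.zeros(1), torch.zeros(1)]
next_hiddens = []
for layer, h in zip(layers, hiddens):
    h, x = layer(x, h)          # the output of layer i becomes the input of layer i + 1
    next_hiddens.append(h)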

Finally, a basic LSTM cell is computed by the LSTMCell() function.

LSTMCell() function

def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    if input.is_cuda:
        # On the GPU, the gate math is handled by a fused kernel.
        igates = F.linear(input, w_ih)
        hgates = F.linear(hidden[0], w_hh)
        state = fusedBackend.LSTMFused()
        return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh)

    hx, cx = hidden
    # All four gates are computed in one matrix multiply, then split apart.
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = F.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)   # controlled update of the cell
    hy = outgate * F.tanh(cy)                      # hidden state derived from the cell

    return hy, cy

The code above is exactly the basic LSTM update formula, so our journey is complete (a quick shape check of this function follows).
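
A quick check of the function with explicit, hypothetical sizes; note that `w_ih` and `w_hh` stack the weights of all four gates, hence the factor of 4 that `gates.chunk(4, 1)` undoes:

import torch
import torch.nn.functional as F    # the cell above uses F.linear, F.sigmoid, F.tanh

input_size, hidden_size, batch = 10, 20, 3

input = torch.randn(batch, input_size)
hx = torch.randn(batch, hidden_size)
cx = torch.randn(batch, hidden_size)

w_ih = torch.randn(4 * hidden_size, input_size)    # input-to-hidden weights, 4 gates stacked
w_hh = torch.randn(4 * hidden_size, hidden_size)   # hidden-to-hidden weights, 4 gates stacked

hy, cy = LSTMCell(input, (hx, cx), w_ih, w_hh)
print(hy.size(), cy.size())                        # both (3, 20)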

Summary

There is nothing that can’t be solved by adding another layer of abstraction. If not, then add another layer.

To repeat my understanding of the above code:

  1. The Cell, Layer and Stack levels are decoupled, and each level abstracts away part of the parameters (structure).
  2. Function handle passing: each level processes its arguments and returns a `forward()` function handle.

Like peeling an onion, we reach the core and find that the information being handled is the input, the hidden state, and the parameters of the LSTM unit's control gates. By layering its abstractions, Pytorch handles different parameters at different levels, which keeps the code extensible and the abstraction layers decoupled.

@ddlee

This article follows the Creative Commons Attribution-ShareAlike 4.0 International License.

This means you may republish this article as long as you credit the author by name and attach this license.

If you want regular updates on my blog posts, feel free to subscribe to the Dongdong monthly report.

Link to this article: blog.ddlee.cn/posts/7b453…

