This post discusses implementing an LSTM network in Pytorch, as well as the code organization and architecture design of the Pytorch library's own LSTM implementation.
LSTM
LSTM is a recurrent neural network suited to modeling sequential inputs. Chris Olah's article explains in detail how an LSTM unit works and is recommended reading.
Two key ideas
- Gate: a gate that controls how information flows
- Cell: a memory pool that stores state
Comparison with ordinary RNN
An ordinary RNN has only a single self-updating hidden state.
LSTM adds a memory pool (the cell) and, through several gates, updates the information in the cell in a controlled way; the hidden state is then derived from the contents of the cell.
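For reference, a standard way to write the LSTM update equations (common notation; $x_t$ is the input, $h_t$ the hidden state, $c_t$ the cell state, $\sigma$ the sigmoid, $\odot$ elementwise multiplication):

$$
\begin{aligned}
f_t &= \sigma(W_f\,[h_{t-1}, x_t] + b_f) &\quad& \text{forget gate} \\
i_t &= \sigma(W_i\,[h_{t-1}, x_t] + b_i) && \text{input gate} \\
o_t &= \sigma(W_o\,[h_{t-1}, x_t] + b_o) && \text{output gate} \\
\tilde{c}_t &= \tanh(W_c\,[h_{t-1}, x_t] + b_c) && \text{candidate cell content} \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t && \text{cell update} \\
h_t &= o_t \odot \tanh(c_t) && \text{hidden state}
\end{aligned}
$$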
From Scratch
Here is the code for an LSTM implemented by hand, inheriting from the base class nn.Module.
```python
import torch
import torch.nn as nn
from torch.autograd import Variable


class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, cell_size, output_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.cell_size = cell_size
        # each gate gets its own linear layer (they must not share weights);
        # note this implementation assumes cell_size == hidden_size
        self.forget_gate = nn.Linear(input_size + hidden_size, cell_size)
        self.input_gate = nn.Linear(input_size + hidden_size, cell_size)
        self.output_gate = nn.Linear(input_size + hidden_size, cell_size)
        self.cell_gate = nn.Linear(input_size + hidden_size, cell_size)
        self.output = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden, cell):
        combined = torch.cat((input, hidden), 1)
        # gates: how much to forget, how much new input to admit, how much to expose
        f_gate = self.sigmoid(self.forget_gate(combined))
        i_gate = self.sigmoid(self.input_gate(combined))
        o_gate = self.sigmoid(self.output_gate(combined))
        # candidate cell content
        cell_helper = self.tanh(self.cell_gate(combined))
        # new cell state: keep part of the old cell, add part of the candidate
        cell = torch.add(torch.mul(cell, f_gate), torch.mul(cell_helper, i_gate))
        # new hidden state: gated view of the cell state
        hidden = torch.mul(self.tanh(cell), o_gate)
        output = self.softmax(self.output(hidden))
        return output, hidden, cell

    def initHidden(self):
        return Variable(torch.zeros(1, self.hidden_size))

    def initCell(self):
        return Variable(torch.zeros(1, self.cell_size))
```
A few key points to watch:
- the sizes of the tensors at each step
- the order in which the output, hidden state and cell state are passed from one step to the next

The usage sketch below illustrates both.
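A minimal usage sketch of the class above (all sizes are made up for illustration, and cell_size is set equal to hidden_size, which this implementation requires):

```python
import torch
from torch.autograd import Variable

n_steps, input_size, hidden_size, output_size = 5, 10, 20, 4

lstm = LSTM(input_size, hidden_size, hidden_size, output_size)
hidden, cell = lstm.initHidden(), lstm.initCell()   # each of size (1, hidden_size)

for _ in range(n_steps):
    x = Variable(torch.randn(1, input_size))        # one time step of input
    # each call consumes the previous hidden/cell states and returns the new ones
    output, hidden, cell = lstm(x, hidden, cell)

print(output.size())   # (1, output_size)
```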
Pytorch Module
The Pytorch library itself wraps much more functionality around its LSTM implementation, and the way its classes and functions are organized is instructive. My understanding of the implementation comes down to two points:
- The cell, layer and stack levels are decoupled, and each level abstracts away part of the parameters (structure).
- Function handle passing: each level processes its arguments and returns a `forward` function handle.
Let's start with the source code (see GitHub for the full version).
LSTM class
File: `nn/modules/rnn.py`
```python
# nn/modules/rnn.py (simplified excerpt)
class RNNBase(Module):
    def __init__(self, mode, input_size, output_size):
        pass

    def forward(self, input, hx=None):
        if hx is None:
            hx = torch.autograd.Variable()
            if self.mode == 'LSTM':
                hx = (hx, hx)  # LSTM keeps a (hidden, cell) pair
        func = self._backend.RNN()  # !!!
        output, hidden = func(input, self.all_weights, hx)  # !!!
        return output, hidden


class LSTM(RNNBase):
    def __init__(self, *args, **kwargs):
        super(LSTM, self).__init__('LSTM', *args, **kwargs)
```
The `LSTM` class is just a thin wrapper around `RNNBase` that fixes `mode='LSTM'`. In the base class `nn.Module`, `__call__()` is defined to invoke `forward()`, so the real work happens inside `_backend.RNN()`.
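As a concrete illustration of that dispatch, here is a short usage sketch of the built-in `nn.LSTM` (sizes are made up; on the old Variable-era API the tensors would be wrapped in `Variable`):

```python
import torch
import torch.nn as nn

seq_len, batch, input_size, hidden_size, num_layers = 7, 3, 10, 20, 2

lstm = nn.LSTM(input_size, hidden_size, num_layers)
x = torch.randn(seq_len, batch, input_size)
h0 = torch.zeros(num_layers, batch, hidden_size)
c0 = torch.zeros(num_layers, batch, hidden_size)

# lstm(...) goes through nn.Module.__call__, which calls forward(),
# which in turn calls the _backend.RNN() handle
output, (hn, cn) = lstm(x, (h0, c0))
print(output.size())          # (seq_len, batch, hidden_size)
print(hn.size(), cn.size())   # (num_layers, batch, hidden_size) each
```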
AutogradRNN function
Next, we look up `_backend.RNN`. File: `nn/backends/thnn.py`
```python
# nn/backends/thnn.py (excerpt)
def _initialize_backend():
    from .._functions.rnn import RNN, LSTMCell
```
So `_backend` is just another layer of indirection.
There we finally find the `RNN()` function. File: `nn/_functions/rnn.py`
```python
# nn/_functions/rnn.py (simplified excerpt)
def RNN(*args, **kwargs):
    def forward(input, *fargs, **fkwargs):
        func = AutogradRNN(*args, **kwargs)
        return func(input, *fargs, **fkwargs)
    return forward


def AutogradRNN(mode, input_size, hidden_size, num_layers=1):
    cell = LSTMCell
    rec_factory = Recurrent
    layer = (rec_factory(cell),)          # Cell -> Layer
    func = StackedRNN(layer, num_layers)  # Layer -> Stack

    def forward(input, weight, hidden):
        nexth, output = func(input, hidden, weight)
        return output, nexth
    return forward
```
`RNN()` is a dispatcher: depending on whether the cudnn library is available, it decides whether to call `AutogradRNN()` or `CudnnRNN()`; here we only follow `AutogradRNN()`. `AutogradRNN()` picks `LSTMCell` as the cell, wraps it with the `Recurrent()` function to form a `Layer`, and then passes the `Layer` into the `StackedRNN()` function. Both `RNN()` and `AutogradRNN()` return a `forward()` function handle.
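The pattern in miniature (a toy sketch, not library code; `ToyRNN`, `scale_rnn` and `fused_scale_rnn` are made-up names): configuration arguments are consumed up front, a backend is chosen, and what comes back is a closure that only needs data:

```python
def scale_rnn(scale):
    def forward(x):
        return scale * x          # "slow" reference path
    return forward

def fused_scale_rnn(scale):
    def forward(x):
        return x * scale          # stand-in for an optimized backend
    return forward

def ToyRNN(scale, use_fast_path=False):
    # dispatch between two backends, like RNN() choosing AutogradRNN() or
    # CudnnRNN(), then hand back the forward closure
    factory = fused_scale_rnn if use_fast_path else scale_rnn
    func = factory(scale)

    def forward(x):
        return func(x)
    return forward

step = ToyRNN(2.0)
print(step(3.0))  # 6.0
```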
Here is the `Recurrent()` function:
```python
def Recurrent(inner):
    def forward(input, hidden, weight):
        output = []
        steps = range(input.size(0) - 1, -1, -1)
        for i in steps:
            # apply the cell at each time step, threading the hidden state through
            hidden = inner(input[i], hidden, *weight)
            output.append(hidden[0])
        return hidden, output
    return forward
```
The `Recurrent()` function implements the recurrence itself: driven by the length of the input, it applies the `Cell` step by step, iterating the hidden state under the shared parameters. In other words, `Recurrent()` combines a `Cell` (the `inner` argument) into a `Layer`.
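A quick way to see this composition at work is to feed a toy cell through the `Recurrent()` closure above (the `add_cell` below is made up; it only has to follow the `inner(input[t], hidden, *weight)` calling convention and return a tuple, since `Recurrent()` collects `hidden[0]`):

```python
import torch

def add_cell(x, hidden, w):
    # toy "cell": the new state is the old state plus a weighted input
    h, = hidden
    return (h + w * x,)

layer = Recurrent(add_cell)        # Cell -> Layer
inputs = torch.ones(4, 1)          # 4 time steps
hidden, outputs = layer(inputs, (torch.zeros(1),), [2.0])
print(hidden, len(outputs))        # final state after 4 steps, plus 4 collected outputs
```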
StackedRNN() function
```python
def StackedRNN(inners, num_layers):
    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(input, hidden, weight):
        next_hidden = []
        hidden = list(zip(*hidden))  # regroup (h, c) per layer
        for i in range(num_layers):
            all_output = []
            for j, inner in enumerate(inners):
                l = i * num_directions + j
                hy, output = inner(input, hidden[l], weight[l])
                next_hidden.append(hy)
                all_output.append(output)
            # the output of this layer becomes the input of the next one
            input = torch.cat(all_output, input.dim() - 1)
        next_h, next_c = zip(*next_hidden)
        next_hidden = (torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
                       torch.cat(next_c, 0).view(total_layers, *next_c[0].size()))
        return next_hidden, input
    return forward
```
The `StackedRNN()` function stacks the `Layer`s (the `inner`s) into a multi-layer network.
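This per-layer `weight` indexing corresponds to `nn.LSTM`'s `all_weights` (the attribute handed to the function handle in `RNNBase.forward()` above): one `[w_ih, w_hh, b_ih, b_hh]` group per layer, with the four gates packed along the first dimension. A small inspection sketch (sizes made up):

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

for l, layer_weights in enumerate(lstm.all_weights):
    # each layer owns [w_ih, w_hh, b_ih, b_hh]; the leading dimension is
    # 4 * hidden_size because the four gate matrices are packed together
    print(l, [tuple(w.size()) for w in layer_weights])
# 0 [(80, 10), (80, 20), (80,), (80,)]
# 1 [(80, 20), (80, 20), (80,), (80,)]
```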
Finally, one step of the basic LSTM cell is computed by the LSTMCell() function.
LSTMCell() function
```python
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    if input.is_cuda:
        # on the GPU, hand the gates to the fused kernel
        igates = F.linear(input, w_ih)
        hgates = F.linear(hidden[0], w_hh)
        state = fusedBackend.LSTMFused()
        return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh)

    hx, cx = hidden
    # all four gates are computed with one matrix multiply, then split apart
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = F.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * F.tanh(cy)
    return hy, cy
```
The code above is exactly the standard LSTM update rule, so our journey ends here.
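As a small sanity check of the excerpt above (assuming PyTorch's documented gate ordering of input, forget, cell, output in the packed weight matrices), the same update can be reproduced with the weights pulled out of an `nn.LSTMCell`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

cell = nn.LSTMCell(input_size=10, hidden_size=20)
x = torch.randn(3, 10)
hx = (torch.zeros(3, 20), torch.zeros(3, 20))

h_ref, c_ref = cell(x, hx)   # library result

# manual result, following the LSTMCell() excerpt above
gates = F.linear(x, cell.weight_ih, cell.bias_ih) + F.linear(hx[0], cell.weight_hh, cell.bias_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
c_manual = torch.sigmoid(forgetgate) * hx[1] + torch.sigmoid(ingate) * torch.tanh(cellgate)
h_manual = torch.sigmoid(outgate) * torch.tanh(c_manual)

print(torch.allclose(h_ref, h_manual), torch.allclose(c_ref, c_manual))  # True True
```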
Summary
There is no problem that can't be solved by adding a layer of abstraction; if there is, add another one.
To repeat my reading of the code above:
- The cell, layer and stack levels are decoupled, and each level abstracts away part of the parameters (structure).
- Function handle passing: each level processes its arguments and returns a `forward` function handle.
Like peeling an onion, we peel away layer after layer and find at the core that the information being handled is just the input, the hidden state, and the parameters of the LSTM's gates. Through these layers of abstraction, Pytorch handles different parameters at different levels, keeping the design extensible and the abstraction layers decoupled.