Small knowledge, big challenge! This article is participating in the creation activity of “Essential Tips for Programmers”.
Keras is a high-level neural network API written in pure Python that runs on top of Theano or TensorFlow. It is designed for fast experimentation, so you can turn ideas into results quickly. Reading a model's code is not always the easiest way to understand a complex network structure, but seeing the same structure as a diagram is usually more intuitive and faster. This article uses Keras to draw the network structure of a BiLSTM-CRF model.
I. Preparation
1. Install PyDot
pip install pydot
2. Install Graphviz
Download and install Graphviz from the official website: graphviz.org/
After installation, add the bin folder of the Graphviz installation directory to the system PATH environment variable so that pydot can find the Graphviz executables (such as dot).
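pydot only generates the graph description and calls the Graphviz programs; it does not bundle them. Before running plot_model, it can help to confirm that the `dot` executable is actually visible on PATH. The sketch below uses only the Python standard library; the helper name `check_graphviz` is hypothetical.

```python
import shutil
import subprocess


def check_graphviz():
    """Return the Graphviz version banner if `dot` is on PATH, else None.

    Hypothetical helper: it only inspects PATH and does not use pydot or Keras.
    """
    dot_path = shutil.which("dot")  # full path to the executable, or None
    if dot_path is None:
        return None
    # `dot -V` prints its version banner to stderr rather than stdout
    result = subprocess.run([dot_path, "-V"], capture_output=True, text=True)
    return (result.stderr or result.stdout).strip()


if __name__ == "__main__":
    version = check_graphviz()
    if version is None:
        print("Graphviz not found: add its bin folder to PATH and open a new terminal")
    else:
        print("Found:", version)
```

If the check fails right after editing PATH, note that already-open terminals and IDEs keep the old environment until they are restarted.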
II. Write the code
1. Import the required packages
load_model: loads the saved network model
CRF: the CRF layer from keras_contrib; it is a custom layer used in the model, so load_model needs it to rebuild the model
plot_model: draws the network model structure and saves it as an image
pyplot: displays the saved structure image
from keras.models import load_model
from keras_contrib.layers import CRF
from keras.utils.vis_utils import plot_model
import matplotlib.pyplot as plt
2. Generate the network model structure diagram
plot_model parameters:
to_file: path and file name of the output image
show_shapes: whether to display each layer's input and output shapes
show_layer_names: whether to display each layer's name
rankdir: layout direction of the diagram; "TB" draws the layers from top to bottom, "LR" from left to right
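plot_model builds its diagram by converting the model into a Graphviz (DOT) graph through pydot, so the effect of these four parameters is easiest to see in the DOT text itself. The following pure-Python sketch is a simplified illustration under stated assumptions, not Keras's own implementation: the function `layers_to_dot` and the layer names and shapes in `bilstm_crf` are hypothetical.

```python
def layers_to_dot(layers, show_shapes=True, show_layer_names=True, rankdir="TB"):
    """Build a Graphviz DOT string for a linear chain of layers.

    `layers` is a list of (name, class_name, output_shape) tuples.
    Simplified illustration only; Keras generates its own, richer graph.
    """
    lines = ["digraph model {", f'  rankdir="{rankdir}";', "  node [shape=record];"]
    for i, (name, cls, shape) in enumerate(layers):
        # show_layer_names controls whether the instance name prefixes the class name
        label = f"{name}: {cls}" if show_layer_names else cls
        # show_shapes appends the output shape as an extra record field
        if show_shapes:
            label += f" | output: {shape}"
        lines.append(f'  n{i} [label="{label}"];')
    # connect each layer to the next one; rankdir decides the drawing direction
    for i in range(len(layers) - 1):
        lines.append(f"  n{i} -> n{i + 1};")
    lines.append("}")
    return "\n".join(lines)


# Hypothetical BiLSTM-CRF layer chain (names and shapes are made up)
bilstm_crf = [
    ("input_1", "InputLayer", "(None, 100)"),
    ("embedding", "Embedding", "(None, 100, 128)"),
    ("bidirectional", "Bidirectional(LSTM)", "(None, 100, 256)"),
    ("crf", "CRF", "(None, 100, 7)"),
]
print(layers_to_dot(bilstm_crf, rankdir="LR"))
```

Saving the printed text to a file such as model.dot and running `dot -Tpng model.dot -o model.png` renders it with the Graphviz installation from the preparation step; switching rankdir between "TB" and "LR" turns the chain from a vertical column into a horizontal row.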
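# path of the saved model file
model_path = "./model/ch_ner_model.h5"
# CRF is a custom layer, so it must be passed via custom_objects; compile=False loads the model without compiling it, which is enough for plotting
model = load_model(model_path, custom_objects={'CRF': CRF}, compile=False)
# show_layer_names takes a bool; set it to True to label each layer with its name
plot_model(model, to_file='./model/nerbilstm.png', show_shapes=True, show_layer_names=False, rankdir="TB")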
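Boolean options of plot_model such as show_layer_names should be passed as real booleans. A quoted string like 'False' is a non-empty string, and Python treats every non-empty string as true, so (assuming the option is checked with an ordinary truth test, as is typical in Python code) it would leave the option enabled. A quick check with a hypothetical helper, `describe_flag`:

```python
def describe_flag(value):
    """Report how Python's truth test treats a value passed as a boolean flag."""
    return f"{value!r} is treated as {bool(value)}"


for flag in (False, "False", "", 0):
    print(describe_flag(flag))
```

Only the string 'False' comes out as True here, which is why the unquoted boolean False is the value to use.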
3. Display the network model structure diagram
Use matplotlib.pyplot to load the generated image and display it.
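# figsize takes a (width, height) tuple in inches
plt.figure(figsize=(10, 10))
img = plt.imread("./model/nerbilstm.png")
plt.imshow(img)
# hide the axis ticks and frame around the image
plt.axis("off")
plt.show()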
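A fixed square canvas works, but imshow keeps the image's aspect ratio by default, so a tall diagram produced with rankdir="TB" ends up narrow inside the square with wide blank margins. One option is to size the figure from the image itself. The helper `figsize_for_image` below is hypothetical and depends only on the image's pixel dimensions:

```python
def figsize_for_image(height_px, width_px, max_inches=10.0):
    """Return a matplotlib figsize (width, height) in inches that keeps the image's aspect ratio.

    The longer side is scaled to `max_inches`; the other side follows proportionally.
    """
    if height_px <= 0 or width_px <= 0:
        raise ValueError("image dimensions must be positive")
    longer = max(height_px, width_px)
    return (max_inches * width_px / longer, max_inches * height_px / longer)


# A tall top-to-bottom diagram: 2000 px high, 500 px wide
print(figsize_for_image(2000, 500))  # (2.5, 10.0)
```

To use it, load the image before creating the figure: plt.imread returns an array whose shape starts with (height, width), so `h, w = img.shape[:2]` followed by `plt.figure(figsize=figsize_for_image(h, w))` gives a canvas that matches the diagram.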