Recently, a reader commented on my article about building TensorFlow Lite 2.0 on Android:
How do I feed in an image and get an array back?
I think this is a pain point for many beginners: few of them have walked the complete path from building and training a model, to converting it to TensorFlow Lite, to actually using it on Android.
So I sent him a demo I had written earlier, and I decided to take the time to turn that demo into an article, hoping it can help more people who are just getting started.
There are already plenty of articles on implementing handwriting recognition with TensorFlow, but I still think it is worth going through again; after all, it is a classic entry example for artificial intelligence.
I will not dwell on the details of the handwriting recognition algorithm; what I care about is the whole process from model to application. If you want to study the algorithm itself, please learn it on your own.
Interested readers can follow my blog series on artificial intelligence (updating...); I am also still learning this material, so let's learn and communicate together.
1. Basic knowledge of handwriting
1.1 Exploring the MNIST data set
The MNIST data set is derived from data collected by the National Institute of Standards and Technology (NIST). The training set consists of handwritten digits from 250 different people, 50 percent high school students and 50 percent Census Bureau employees. The test set contains handwritten digit data in the same proportions.
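A quick way to check the size of the data set (a minimal sketch using the same tf.keras loader as below):
from tensorflow.keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()
print(X_train.shape)  # (60000, 28, 28): 60,000 training images, 28 x 28 pixels each
print(X_test.shape)   # (10000, 28, 28): 10,000 test images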
What does each image in the data set look like? Here are the first four training images, plotted with the following code:
# Plot ad hoc mnist instances
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
# load (downloaded if needed) the MNIST dataset
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# plot 4 images as gray scale
plt.subplot(221)
plt.imshow(X_train[0], cmap=plt.get_cmap("gray"))
plt.subplot(222)
plt.imshow(X_train[1], cmap=plt.get_cmap("gray"))
plt.subplot(223)
plt.imshow(X_train[2], cmap=plt.get_cmap("gray"))
plt.subplot(224)
plt.imshow(X_train[3], cmap=plt.get_cmap("gray"))
# show the plot
plt.show()
But how is an image actually stored? If you print one out, you can see that a digit such as 0 is stored as a 28 x 28 matrix of pixel intensities: positions with value 0 are the black background, and the non-zero values form the gray strokes of the digit. In other words, each image is a grayscale pixel matrix (a single channel, not RGB).
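To see this for yourself, here is a minimal sketch (assuming the data loaded by the code above is still in scope) that prints the raw matrix of the first training image:
import numpy as np

# Print the raw 28 x 28 pixel matrix of the first training image;
# 0 is the black background, values up to 255 are the gray strokes
np.set_printoptions(linewidth=150)
print(X_train[0])
print("label:", y_train[0])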
1.2 Basic introduction to CNN
The algorithm used for handwriting recognition is a CNN (convolutional neural network), which is widely used in computer vision. The classic CNN handwriting recognition diagram describes the entire recognition process; I will not go into the details here, and if I get the chance I will write a separate article on the algorithm. The neural network structure used in this article is as follows:
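Concretely, the structure built in the next section produces the following tensor shapes (a quick back-of-the-envelope check based on the layer parameters used below):
# Input:    28 x 28 x 1   grayscale image
# Conv2D:   24 x 24 x 32  (32 filters of size 5 x 5, valid padding: 28 - 5 + 1 = 24)
# MaxPool:  12 x 12 x 32  (2 x 2 pooling halves the width and height)
# Dropout:  12 x 12 x 32  (no shape change)
# Flatten:  4608          (12 * 12 * 32)
# Dense:    128           (ReLU)
# Dense:    10            (softmax over the digits 0-9)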
1.3 Handwriting recognition based on TensorFlow
This article uses the Keras interface in TensorFlow, which is well suited to beginners: it makes building a neural network feel like stacking building blocks. The code is as follows; note the comments.
import numpy
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
import pathlib
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# reshape to be [samples][width][height][channels]
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')
# normalize inputs from 0-255 to 0-1
X_train = X_train / 255
X_test = X_test / 255
print(X_train.shape)
# one hot encode outputs
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
print(X_train[0])
num_classes = y_test.shape[1]
def baseline_model():
    # create model
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(5, 5),
                     input_shape=(28, 28, 1),  # single-channel grayscale images
                     activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(num_classes, activation='softmax'))
    # Compile model: the labels are one-hot encoded and the last layer already
    # applies softmax, so use plain categorical cross-entropy
    model.compile(loss='categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    return model
model = baseline_model()
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=200, verbose=2)
# Final evaluation of the model
scores = model.evaluate(X_test, y_test, verbose=0)
print("CNN Error: %.2f%%" % (100 - scores[1] * 100) # Upgrade the network training process # Next need to convert it to tensorFlow Lite model for easy use in Android. converter = tf.lite.TFLiteConverter.from_keras_model(model) tflite_model = converter.convert() tflite_model_file = pathlib.Path('saved_model/model.tflite')
tflite_model_file.write_bytes(tflite_model)
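Before moving to Android, it is worth sanity-checking the converted file with the Python TFLite interpreter. A minimal sketch (assuming X_test from the training script above is still in scope, along with the file path used above):
import numpy as np
import tensorflow as tf

# Load the converted model and allocate its tensors
interpreter = tf.lite.Interpreter(model_path='saved_model/model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details[0]['shape'])  # expect [1, 28, 28, 1]

# Run one test image through the TFLite model
sample = X_test[0:1].astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], sample)
interpreter.invoke()
probs = interpreter.get_tensor(output_details[0]['index'])
print("predicted digit:", np.argmax(probs))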
2. Implement handwriting recognition on Android
If you do not know how to configure the Android environment, please refer to my earlier hand-in-hand guide to building TensorFlow Lite 2.0 on Android.
2.1 Loading the model
Place your trained TensorFlow Lite file in the assets folder of your Android project. (Note that the file must be stored uncompressed in the APK, otherwise openFd() in the code below will fail; this is usually handled with aaptOptions { noCompress "tflite" } in build.gradle, as covered in the environment article above.)
public class TF {
    private static Context mContext;
    Interpreter mInterpreter;
    private static TF instance;

    public static TF newInstance(Context context) {
        mContext = context;
        if (instance == null) {
            instance = new TF();
        }
        return instance;
    }

    Interpreter get() {
        try {
            if (Objects.isNull(mInterpreter))
                mInterpreter = new Interpreter(loadModelFile(mContext));
        } catch (IOException e) {
            e.printStackTrace();
        }
        return mInterpreter;
    }

    // Memory-map the model file from the assets folder
    private MappedByteBuffer loadModelFile(Context context) throws IOException {
        AssetFileDescriptor fileDescriptor = context.getAssets().openFd("model.tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }
}
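A minimal usage sketch (the calling code is illustrative, e.g. inside an Activity):
// Obtain the shared interpreter and inspect the model's input shape
Interpreter interpreter = TF.newInstance(getApplicationContext()).get();
int[] shape = interpreter.getInputTensor(0).shape();  // expect [1, 28, 28, 1]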
2.2 Custom drawing View
public class HandWriteView extends View {
    Path mPath = new Path();
    Paint mPaint;
    Bitmap mBitmap;
    Canvas mCanvas;

    public HandWriteView(Context context) {
        super(context);
        init();
    }

    public HandWriteView(Context context, AttributeSet attrs) {
        super(context, attrs);
        init();
    }

    void init() {
        // White strokes on a black background, matching the MNIST convention
        mPaint = new Paint();
        mPaint.setColor(Color.WHITE);
        mPaint.setStyle(Paint.Style.STROKE);
        mPaint.setStrokeJoin(Paint.Join.ROUND);
        mPaint.setStrokeCap(Paint.Cap.ROUND);
        mPaint.setStrokeWidth(30);
    }

    @Override
    protected void onDraw(Canvas canvas) {
        super.onDraw(canvas);
        // Mirror the path into an off-screen bitmap so it can be fed to the model
        mBitmap = Bitmap.createBitmap(getWidth(), getHeight(), Bitmap.Config.ARGB_8888);
        mCanvas = new Canvas(mBitmap);
        mCanvas.drawColor(Color.BLACK);
        canvas.drawPath(mPath, mPaint);
        mCanvas.drawPath(mPath, mPaint);
    }

    @Override
    public boolean onTouchEvent(MotionEvent event) {
        switch (event.getAction()) {
            case MotionEvent.ACTION_DOWN:
                mPath.moveTo(event.getX(), event.getY());
                break;
            case MotionEvent.ACTION_MOVE:
                mPath.lineTo(event.getX(), event.getY());
                break;
            case MotionEvent.ACTION_UP:
            case MotionEvent.ACTION_CANCEL:
                break;
        }
        postInvalidate();
        return true;
    }

    Bitmap getBitmap() {
        // Hand the current drawing to the caller and clear the path for the next digit
        mPath.reset();
        return mBitmap;
    }
}
2.3 Converting a Bitmap to a network format
The network expects input of shape 28 * 28 * 1: 28 is the width and height of the image, and 1 is a single grayscale channel. An Android bitmap, however, stores ARGB pixels, so before feeding it to the network we need to scale the bitmap down to 28 * 28 and convert its pixels to normalized grayscale floats. (This is also why the view above draws white strokes on a black canvas: it matches the white-on-black convention of the MNIST training data.)
private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
    // Read the model's expected input shape, e.g. [1, 28, 28, 1]
    int[] inputShape = TF.newInstance(getApplicationContext()).get().getInputTensor(0).shape();
    int inputImageWidth = inputShape[1];
    int inputImageHeight = inputShape[2];
    Bitmap bs = Bitmap.createScaledBitmap(bitmap, inputImageWidth, inputImageHeight, true);
    mImageView.setImageBitmap(bs);
    // 4 bytes per float, one float per pixel
    ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4 * inputImageHeight * inputImageWidth);
    byteBuffer.order(ByteOrder.nativeOrder());
    int[] pixels = new int[inputImageWidth * inputImageHeight];
    bs.getPixels(pixels, 0, bs.getWidth(), 0, 0, bs.getWidth(), bs.getHeight());
    for (int pixelValue : pixels) {
        int r = (pixelValue >> 16 & 0xFF);
        int g = (pixelValue >> 8 & 0xFF);
        int b = (pixelValue & 0xFF);
        // Convert RGB to grayscale and normalize the pixel value to [0..1]
        float normalizedPixelValue = (r + g + b) / 3.0f / 255.0f;
        byteBuffer.putFloat(normalizedPixelValue);
    }
    return byteBuffer;
}
2.4 Outputting the recognition result
The recognition result is read off from the probabilities for the digits 0-9: the digit with the highest probability is the recognized result.
float[][] input = new float[1][10];
TF.newInstance(getApplicationContext()).get().run(convertBitmapToByteBuffer(mHandWriteView.getBitmap()), input);
// Find the digit with the highest probability
int result = -1;
float value = 0f;
for (int j = 0; j < 10; j++) {
    if (input[0][j] > value) {
        value = input[0][j];
        result = j;
    }
    Log.i("TAG", "result: " + j + " " + input[0][j]);
}
if (input[0][result] < 0.2f) {
    mTextView.setText("Result: not recognized");
} else {
    mTextView.setText("The result is: " + result);
}
Recognition results:
If you need the full project, please click demo to download it.
3. Summary
So much for the main flow of developing an AI app. The key is the algorithm: to get a more accurate model, besides choosing a better network structure, you can rotate, augment, or whiten the training data to improve its diversity; a small example follows.
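For instance, a minimal augmentation sketch using Keras' ImageDataGenerator (the parameters are illustrative, not tuned, and the variables come from the training script above):
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Randomly rotate, shift, and zoom the training images to increase diversity
datagen = ImageDataGenerator(rotation_range=10,
                             width_shift_range=0.1,
                             height_shift_range=0.1,
                             zoom_range=0.1)
model.fit(datagen.flow(X_train, y_train, batch_size=200),
          validation_data=(X_test, y_test), epochs=10)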
Everyone is welcome to exchange ideas!