Flask is fun and easy to set up, as the Flask website states. This Python microframework provides a powerful way of annotating Python functions with REST endpoints. Here, Flask is used to publish an ML model API that third-party business applications can access.
This example is based on XGBoost.
For better code maintenance, it is recommended to use a separate Jupyter notebook in which the ML model API will be published. Import the Flask module along with Flask-CORS:
from flask import Flask, jsonify, request
from flask_cors import CORS, cross_origin
import pickle
import pandas as pd
The model was trained on the Pima Indians Diabetes dataset. The CSV data can be downloaded from raw.githubusercontent.com/jbrownlee/D… To construct the Pandas DataFrame passed as input to the model's prediction function, define an array of dataset column names:
# Get headers for payload
headers = ['times_pregnant', 'glucose', 'blood_pressure', 'skin_fold_thick', 'serum_insuling', 'mass_index', 'diabetes_pedigree', 'age']
Load the previously trained and saved model using Pickle:
# Use pickle to load in the pre-trained model
with open('diabetes-model.pkl', 'rb') as f:
    model = pickle.load(f)
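The same save/load round-trip can be sketched without the real model file. Note the `StubModel` class below is a hypothetical placeholder standing in for the trained XGBoost classifier; only the pickle pattern itself mirrors the article:

```python
import pickle

class StubModel:
    """Hypothetical stand-in for the trained XGBoost classifier."""
    def predict_proba(self, rows):
        # Return a fixed [P(class 0), P(class 1)] pair per input row
        return [[0.7, 0.3] for _ in rows]

# Serialize and deserialize exactly as done with the real model file
blob = pickle.dumps(StubModel())
model = pickle.loads(blob)
print(model.predict_proba([[1, 106, 70, 28, 135, 34.2, 0.142, 22]]))  # [[0.7, 0.3]]
```

With the real model, `pickle.dumps`/`pickle.loads` would simply be replaced by the file-based `pickle.dump`/`pickle.load` shown above.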
It’s always a good habit to run a quick test and check that the model works. Construct a DataFrame from the column-name array and a data array (use new data, not rows that already exist in the training or test set). Call two functions: model.predict and model.predict_proba. model.predict_proba is generally preferable, as it returns the probability of class 0/1, which helps interpret the result over a range (for example, 0.25 to 0.75). Build the Pandas DataFrame with the sample payload, then run the model prediction:
# Test model with data frame
input_variables = pd.DataFrame([[1, 106, 70, 28, 135, 34.2, 0.142, 22]],
                               columns=headers,
                               dtype=float,
                               index=['input'])

# Get the model's prediction
prediction = model.predict(input_variables)
print("Prediction: ", prediction)
prediction_proba = model.predict_proba(input_variables)
print("Probabilities: ", prediction_proba)
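To make the 0.25 to 0.75 range mentioned above concrete, a small helper could band the class-1 probability. This is a hypothetical sketch; the threshold values and labels are illustrative, not part of the article:

```python
def interpret(p, low=0.25, high=0.75):
    """Map a predicted probability of class 1 into a coarse band (illustrative thresholds)."""
    if p < low:
        return "unlikely diabetic"
    if p > high:
        return "likely diabetic"
    return "uncertain"

print(interpret(0.12))  # -> unlikely diabetic
print(interpret(0.50))  # -> uncertain
print(interpret(0.91))  # -> likely diabetic
```

Reporting a band rather than a bare 0/1 label gives the consuming application room to route borderline cases differently.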
Make sure CORS is enabled in the Flask API; otherwise, calls will not work from other hosts. Annotate the function to be exposed through the REST API, providing the endpoint name and the supported REST method (POST in this case). Retrieve the payload data from the request, construct the Pandas DataFrame, and execute the model's predict_proba function:
app = Flask(__name__)
CORS(app)

@app.route("/katana-ml/api/v1.0/diabetes", methods=['POST'])
def predict():
    payload = request.json['data']
    values = [float(i) for i in payload.split(',')]
    input_variables = pd.DataFrame([values], columns=headers, dtype=float, index=['input'])

    # Get the model's prediction
    prediction_proba = model.predict_proba(input_variables)
    prediction = (prediction_proba[0])[1]
    ret = '{"prediction":' + str(float(prediction)) + '}'
    return ret

# Run the REST interface; port=5000 for direct testing
if __name__ == "__main__":
    app.run(debug=False, host='0.0.0.0', port=5000)
The response JSON string is constructed and returned as the result of the function. Flask runs in a Docker container, which is why 0.0.0.0 is used as the host. Port 5000 is mapped to an external port, which allows calls from outside the container.
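A caller would POST a JSON body with a `data` field containing comma-separated feature values, matching the route defined above. A minimal client sketch using only the standard library; the helper names (`build_payload`, `call_endpoint`) and the `localhost:5000` URL are assumptions for illustration:

```python
import json
from urllib import request as urlrequest

API_URL = "http://localhost:5000/katana-ml/api/v1.0/diabetes"  # hypothetical host/port

def build_payload(values):
    # The endpoint expects {"data": "comma,separated,feature,values"}
    return json.dumps({"data": ",".join(str(v) for v in values)}).encode()

def call_endpoint(values, url=API_URL):
    req = urlrequest.Request(url, data=build_payload(values),
                             headers={"Content-Type": "application/json"})
    with urlrequest.urlopen(req) as resp:
        # The service returns a JSON string like {"prediction": 0.35}
        return json.loads(resp.read())["prediction"]

# Build (but do not send) a sample request body
print(build_payload([1, 106, 70, 28, 135, 34.2, 0.142, 22]))
```

Calling `call_endpoint(...)` against a running container would return the class-1 probability as a float.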
Although the Flask interface can be started directly in a Jupyter notebook, it is recommended to convert the notebook to a Python script and run it as a service from the command line. Convert it with the jupyter nbconvert command:
jupyter nbconvert --to script diabetes_redSamurai_endpoint_db.ipynb
The Python script with the Flask endpoint can be started as a background process with the PM2 process manager. This allows the endpoint to run as a service, alongside other processes on different ports. PM2 startup command:
pm2 start diabetes_redsamurai_endpoint_db.py
pm2 monit helps to display information about the running processes:
ML model classification via a REST API call from Postman to the Flask service endpoint: