Flask快速搭建轻量级图像识别服务器
Python中文社区
共 3552字,需浏览 8分钟
·
2020-12-24 07:27
笔者选择Flask作为开发后台。Flask是一个基于Python编写的Web应用框架,相对于Django及Tornado,Flask可以说是一个轻量级选手,灵活、可扩展性强(主要是容易上手),很适用于开发中小型网站或应用服务。
# 使用以下命令进行安装
pip install Flask
要部署的模型则是基于Keras深度学习框架下(以TensorFlow为后端)训练保存的模型,其格式为HDF5,需同时含有网络结构信息与权重信息。因此需要在服务器端安装TensorFlow及Keras,若需要使用显卡加速,则还需安装CUDA及cuDNN,并升级TensorFlow为GPU版本。
# 使用显卡加速时,根据ID指定特定的显卡
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# 调整显存使用率,方便服务器快速启动
import tensorflow as tf
config=tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth=True
sess=tf.compat.v1.Session(config=config)
# 加载其他所需组件
from keras.preprocessing.image import img_to_array
from keras.models import load_model
from keras.applications import imagenet_utils
from PIL import Image
import numpy as np
import flask
import io
# 新建一个应用服务实例及定义保存模型的变量名
app = flask.Flask(__name__)
model = None
# 加载事先训练好的模型,注意这里函数名为load_model2旨在为避免与keras的load_model重名而报错
def load_model2():
# 设为全局变量,确保调用正确
global model
model =load_model('model_name.h5')
# 图片预处理过程
def prepare_image(image, target):
# 如果文件通道非RGB排列则调整为RGB
if image.mode != "RGB":
image = image.convert("RGB")
# 尺寸缩放及数据格式调整
image = image.resize(target)
image = img_to_array(image)
image = np.expand_dims(image, axis=0)
image = imagenet_utils.preprocess_input(image)
# 返回处理后结果
return image
# 使用函数装饰器对新建的服务实例进行地址与请求方式的扩展设置
@app.route("/predict", methods=["POST"])
# 图像识别判断函数,若将load_model写入则每次发送请求都会加载一次模型
def predict():
# 初始化状态值
data = {"success": False}
# 判断请求方式是否为POST,若为真则读入图片
if flask.request.method == "POST":
if flask.request.files.get("image"):
image = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image))
# 对图像进行预处理
image = prepare_image(image, target=(224, 224))
# 使用模型进行识别
preds = model.predict(image)
# 将识别结果进行转换
fct=np.argmax(preds, axis=1)+1
# 创建返回结果的字典
data["predictions"] = []
# 保存识别结果及对应的概率
r = {"能见度等级": str(fct[0]), "probability": str(preds[0][fct[0]-1])}
# 将结果存放入字典中
data["predictions"].append(r)
# 改变状态值
data["success"] = True
# 服务端显示结果
print(data)
# 将结果以JSON格式返回
return flask.jsonify(data)
if __name__ == "__main__":
# 加载模型
load_model2()
# 启动服务(若出现thread_local错误,请在run中添加threaded=False,或升级keras版本)
app.run()
# 加载所需组件
import requests
# 设置URL地址
KERAS_REST_API_URL = "http://localhost:5000/predict"
# 设置图像文件地址
IMAGE_PATH = "test.jpg"
# 读取图像并将其保存在字典中进行发送
image = open(IMAGE_PATH, "rb").read()
payload = {"image": image}
# 以POST发送请求,并将返回的JSON格式进行解析
r = requests.post(KERAS_REST_API_URL, files=payload).json()
# 确保服务端程序执行正确,并显示结果,否则提示请求失败
if r["success"]:
print(r["predictions"])
else:
print("Request failed")
---— 5.结果展示 —---
#首先启动服务
python keras_serve.py
#由于在本地调式未启用WSGI服务,会有warning提示,另在生产环境中切勿使用debug模式。
开启应用后,服务端界面
#其次发送请求
python client_request.py
发送请求后,用户端界面
服务端界面
更多阅读
特别推荐
点击下方阅读原文加入社区会员
评论