我正在使用 Flask 运行一个 Web 服务器,当我尝试使用 vgg16(这是 keras 预训练的 VGG16 模型的全局变量)时,会出现错误。我不知道为什么会出现这个错误,也不知道它是否与 Tensorflow 后端有关。
这是我的代码:
vgg16 = VGG16(weights='imagenet', include_top=True)
def getVGG16Prediction(img_path):
global vgg16
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
pred = vgg16.predict(x)
return x, sort(decode_predictions(pred, top=3)[0])
@app.route("/uploadMultipleImages", methods=["POST"])
def uploadMultipleImages():
uploaded_files = request.files.getlist("file[]")
for file in uploaded_files:
path = os.path.join(STATIC_PATH, file.filename)
pInput, result = getVGG16Prediction(path)
Here is the full error:
如有任何意见或建议,我们将不胜感激。谢谢。
看一眼avital
的回答这个 github 问题 https://github.com/keras-team/keras/issues/2397。这里引用相关部分:
加载或构建模型后,立即保存 TensorFlow 图:
graph = tf.get_default_graph()
在另一个线程中(或者可能在异步事件处理程序中),执行以下操作:
global graph
with graph.as_default():
(... do inference here ...)
我对此进行了一些修改,并将图表存储在我的应用程序的配置对象中,而不是使其成为全局的。
The TensorFlow 文档 https://www.tensorflow.org/api_docs/python/tf/get_default_graph for get_default_graph
解释了为什么这是必要的:
注意:默认图表是当前线程的属性。如果您创建一个新线程,并希望在该线程中使用默认图,则必须在该线程的函数中显式添加 with g.as_default(): 。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)