tensorflow 训练好的模型怎么调用?

技术讨论 xiaoxiaohui ⋅ 于 3天前 ⋅ 469 阅读

tensorflow训练好的模型怎么调用?

搭建了个简单的tensorflow 神经网络,训练完毕之后,如何使用代码调用模型来进行识别?


直接说结论:

1、实验室环境下,直接saver和restore即可。

2、生产环境:

(1)部署在移动终端上的(例如ios、android),场景:图像识别等。用freeze_graph合成pb和ckpt文件,然后用optimize_for_inference、quantize_graph进行优化。再用TensorFlowInferenceInterface调用(这个,不知道ios和android是否相同)。
(2)部署在服务端提供服务使用的,场景:推荐系统等。使用tensorflow serving进行模型服务化。

------------------------------------------------------------

下边是基于部署在服务端提供服务的方式,查阅资料时tensorflow和tensorflow serving都是1.3版本。

_推荐本tensorflow实践的工具书。上手学习比guide更系统一些,不过工程化方面还是需要多实践、优化。

《Scikit-Learn与TensorFlow机器学习实用指南》
(此书电子版后续会上传到极市社区)


在读goooole的paper的时候经常看到下边这张图。三个虚框已经把google的系统典型流程描述得很清楚。Data Generation这步,有非常多的学问这里木有经验,略过。我们来看Model Training和Model Serving两部分。也正是题主的问题的核心。
file

注:整个系统流程都为线上生产流程非实验室环境。


前面几位答友的知识点已经都提到了,这里也就总结整理了下,没有新知识:

1、Previous Models为训练好的模型,即Model Trainer的训练结果。通常在实验室环境中完成一个模型并验证其能发布到线上使用后,通过模型保存扔到生产环境的这里提供给线上系统使用。对应的代码实现:

# Export inference model.
output_path = os.path.join(
          tf.compat.as_bytes(FLAGS.output_dir),
          tf.compat.as_bytes(str(FLAGS.model_version)))
print 'Exporting trained model to', output_path
 ...
builder = tf.saved_model.builder.SavedModelBuilder(output_path)
 ...
builder.save()

目录里是类似这样的文件:(没什么神秘的,看save的手册即可)
file

2、Model Trainer,模型训练。只要训练集准备好,就可以对模型进行训练。通常需要有个触发的条件,例如晚上1点,或者数据集抽样完成等,只要能把你的模型运行起来就可以。那这里就涉及两点1)加载Previous Model,2)验证模型,如果满足你的要跟则保存模型。加载模型的代码实现:

# Restore variables from training checkpoint.
variable_averages = tf.train.ExponentialMovingAverage(inception_model.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)



3、Model Verifier,不多说,每个模型都要实现的。即accuracy,通常只有accuracy达到我们预计的值为才执行。对应的代码类似:

train_accuracy = accuracy.eval(feed_dict={
      x:batch[0], y_: batch[1], keep_prob: 1.0})
print "step %d, training accuracy %g"%(i, train_accuracy) 

4、关键一步,Model verfierg到Model Servers。模型保存训练并达到我们的要求后,把它保存了下来。因为是生产环境,为了保障线上实时运行的稳定性,需要让训练中的模型和线上系统进行隔离,需要使用model_version+AB分流来解决这个问题。这里就开始用到Tensorflow Serving这个家伙了,即把你的模型给服务化,通过gRPC方式的HTTP提供实时调用。当然,移动端本地化的不需要这样,需要合成pb文件后直接本地调用。

模型服务化的命令:

下载完Tensorflow Serving,编译的命令,具体看官网。

bazel build -c opt //tensorflow_serving/model_servers:tensorflow_model_server

模型服务化,后边那个“/models/mnist_mode”为前边保存模型的目录

bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=mnist --model_base_path=/models/mnist_model/

如果能顺利到这步,剩下的事情就是通过9000端口调用你的模型了。

(这步有很多坑,tensorflow serving变化很快,建议能docker的尽量docker否则,嘿嘿嘿...)

5、Apps Rec Engine线上系统,即调用Model Servers的Clients端,做个gRPC发请求调用就可以了。


文章来源:我踏马不看知呼@知乎

相关推荐:

GitHub:TensorFlow 最全资料集锦
TensorFlow 真的要被 PyTorch 比下去了吗?
【电子书】Deep Learning with TensorFlow 2 and Keras, 2th

成为第一个点赞的人吧 :bowtie:
回复数量: 0
暂无回复~
您需要登陆以后才能留下评论!