【导读】TensorFlow的生态圈极其强大,覆盖了科研、工程中的各种流程,其中一些特别好用的模块和技巧可以使你的工作效率大幅度提升,也可以让你的产品变得非常稳定。本文介绍其中的一些鲜为人知却又十分实用的知识。
一. GraphDef才是正确地模型保存的方法
大部分用户保存TensorFlow模型的方法是tf.train.Saver.save,这是众多科研代码中用来保存模型的方法,保存之后的模型如下图所示。
实际上这种保存的方法,是给模型训练做checkpoint用的,也就是说为了让你能够随时保存实验过程,随时恢复实验用的(防止断电、死机导致实验丢失)。
如果你希望为TensorFlow保存一个能够用于产品用的模型,并且这个模型能够被C/C++/Java/NodeJS等调用(类似Caffe模型),你需要了解GraphDef。用GraphDef方式保存的模型是一个独立地Protobuf文件,看一下维基百科对Protobuf的解释:
Protocol Buffers是一种序列化数据结构的协议。对于透过管线(pipeline)或存储数据进行通信的程序开发上是很有用的。这个方法包含一个接口描述语言,描述一些数据结构,并提供程序工具根据这些描述产生代码,用于将这些数据结构产生或解析数据流。
也就是说Protobuf文件是一种无视语种的数据描述文件,存成Protobuf文件,模型可以被Protobuf支持的各大语种(C/C++/Java/NodeJS等)读取。
TensorFlow模型的正确保存方式如下:
最终,我们只会得到一个model.pb文件:
model.pb存储的是压缩版的frozen_graph_def,上面我们用print函数将frozen_graph_def 输出的结果如下,这可以看到,这是一个标准的图结构的数据(也就是静态图),不仅包含了节点,还包含了节点中的数据。
node {
name: 'x'
op: 'Placeholder'
attr {
key: 'dtype'
value {
type: DT_FLOAT
}
}
attr {
key: 'shape'
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: 'y'
op: 'Const'
attr {
key: 'dtype'
value {
type: DT_FLOAT
}
}
attr {
key: 'value'
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 10.0
}
}
}
}
node {
name: 'y/read'
op: 'Identity'
input: 'y'
attr {
key: 'T'
value {
type: DT_FLOAT
}
}
attr {
key: '_class'
value {
list {
s: 'loc:@y'
}
}
}
}
node {
name: 'add'
op: 'Add'
input: 'x'
input: 'y/read'
attr {
key: 'T'
value {
type: DT_FLOAT
}
}
}
node {
name: 'z'
op: 'Log'
input: 'add'
attr {
key: 'T'
value {
type: DT_FLOAT
}
}
}
library {
}
为什么在保存GraphDef前要调用tf.graph_util.convert_variables_to_constants方法,我们发现在调用tf.graph_util.convert_variables_to_constants方法时,程序有一行输出:
其实默认状态下,静态图的数据是被同时保存在GraphDef和Session中的,图结构、常量的值等被存储在GraphDef中,而变量的值被存储在Session中,这也是为什么每次用静态图都要在Session中使用的原因。
tf.graph_util.convert_variables_to_constants方法将Session中的变量转换到GraphDef中以常量形式存储,由于没有了变量,得到的GraphDef中包含了静态图的所有信息,即包含了整个模型,保存GraphDef即保存了整个模型。
现在我们可以用C/C++/Java/NodeJS等来读取并执行保存的GraphDef文件,以Java为例(需要Maven导入java版tensorflow api),整个流程和Python API很像,读取图,开启Session,并将读取的图放入Session,指定输入,获取输出:
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.io.FileInputStream;
import java.io.IOException;
public class DemoImportGraph {
public static void main(String[] args) throws IOException {
try (Graph graph = new Graph()) {
//导入图
byte[] graphBytes = IOUtils.toByteArray(new FileInputStream('model.pb'));
graph.importGraphDef(graphBytes);
//根据图建立Session
try(Session session = new Session(graph)){
//相当于TensorFlow Python中的sess.run(z, feed_dict = {'x': 10.0})
float z = session.runner()
.feed('x', Tensor.create(10.0f))
.fetch('z').run().get(0).floatValue();
System.out.println(z);
}
}
}
}
所以,TensorFlow模型并非只能被Python调用。按照GraphDef方式保存为Protobuf模型后,可以被任何TensorFlow提供了API的语种调用。
详情可以参考:
http://www.zhuanzhi.ai/document/5f2d760783fb7a0d49e971140a1c4561
二. 可以在Keras中使用TensorFlow,也可以在TensorFlow中使用Keras
TensorFlow是最终要的内核之一,在默认的使用TensorFlow作为内核的情况下,Keras的各种层、包括模型的执行,都是依赖TensorFlow的各种操作、Session等去完成的,在Keras中使用TensorFlow是众所周知的,然而在TensorFlow中使用Keras确是一个不常见的情况。其实Keras早就进入了TensorFlow的核心库(tf.keras),而且成为了官方较为推荐使用tf.keras进行模型的构建,看一下TensorFlow 1.9官网教程首页的示例代码,
原先在TensorFlow需要几十行才能构建的模型和流程,用tf.keras模块十几行就可以搞定了。
三. TensorFlow Hub中有许多可以直接使用的模型
TensorFlow Hub是TensorFlow官方提供的用于模型发布、复用的工具。例如下面的代码可以获取句子的Embedding,我们只需要给出TensorFlow Hub模型发布的url以及输入,通过简单的几行调用即可完成原先需要数百还才能完成的工作。另外,指定url的方式相比于自己下载模型的方式便利了许多。
import tensorflow as tf
import tensorflow_hub as hub
with tf.Graph().as_default():
module_url = 'https://tfhub.dev/google/nnlm-en-dim128-with-normalization/1'
embed = hub.Module(module_url)
embeddings = embed(['A long sentence.', 'single-word',
'http://example.com'])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())
print(sess.run(embeddings))
参考链接:
https://www.tensorflow.org/hub/
四. 在静态图中也可以像动态图那样写条件判断语句
原先在静态图中是无法使用Python的if语句来为静态图定义条件判断结构的,需要使用特殊的tf.cond操作来定义一个条件判断节点,非常的麻烦,近期TensorFlow新出的AutoGraph功能可以让用户按照Python的if语句来定义结构,然后利用AutoGraph注解将其转换为相应的静态图结构,这样可以大幅度降低静态图构建的难度:
参考链接:
https://www.tensorflow.org/guide/autograph
联系客服