0%

Tensorflow持久化原理及数据格式

Tensorflow持久化原理及数据格式


Tensorflow是一个通过图的形式来表述计算的编程系统,Tensorflow中所有的计算都会被表达成计算图上的节点。Tensorflow通过元图(MetaGraph)来记录计算图中的信息,以及运行计算图中节点所需要的元数据。以下代码给出了MetaGraphDef类型的定义

1
2
3
4
5
6
7
message MetaGraphDef{
MeatInfoDef meta_info_def = 1;
GraphDef graph_def = 2;
SaverDef saver_def = 3;
map<string,CollectionDef> collection_def = 4;
map<string,SignatureDef> signature_def = 5;
}

保存MetaGraphDef信息的文件默认以.meta为后缀名,在之前的例子中文件test.ckpt.meta中存储的就是元图的数据。由于得到的是二进制文件不方便查看。为了方便调试,Tensorflow提供了export_meta_graph函数,这个函数支持以json格式导出MetaGraphDef。下面为实现的代码

1
2
3
4
5
6
7
8
9
import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name ="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name ="v2")
result1 = v1 + v2

saver = tf.train.Saver()
#通过export_meta_graph函数导出Tensorflow的计算图的元图,并保存为json格式
saver.export_meta_graph("test/test.ckpt.json", as_text=True)

meta_info_def 属性

meta_info_def属性通过MetaInfoDef定义,它记录了Tensorflow计算图中的元数据以及Tensorflow程序中所有使用到的运算方法的信息。下面的MetaInfoDef的定义

1
2
3
4
5
6
7
message MetaInfoDef{
#saver没有特殊指定,默认属性都为空。meta_info_def属性里只有stripped_op_list属性不能为空。
string meta_graph_version = 1;#该属性不能为空
OpList stripped_op_list = 2;#该属性记录了计算图中使用到的所有运算方法的信息,该函数只记录运算信息,不记录计算的次数
google.protobuf.Any any_info = 3;
repeated string tags = 4;
}

元数据包括计算图的版本号(meta_graph_version属性)以及用户指定的一些标签(tags属性)。


OpList类型是一个OpDef类型的列表,以下代码给出OpDef类型的定义:

1
2
3
4
5
6
7
8
9
10
11
12
13
message opDef{
string name = 1;#定义了运算的名称
repeated ArgDef input_arg = 2; #定义了输入,属性是列表
repeated ArgDef output_arg =3; #定义了输出,属性是列表
repeated AttrDef attr = 4;#给出了其他运算的参数信息
string summary = 5;
string description = 6;
OpDeprecation deprecation = 8;
bool is_commutative = 18;
bool is_aggregate = 16
bool is_stateful = 17;
bool allows_uninitialized_input = 19;
};

下面给出一个比较有代表性的运算来辅助说明OpDef的数据结构。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
op {
name: "Add"
input_arg{
name: "x"
type_attr:"T"
}
input_arg{
name: "y"
type_attr:"T"
}
output_arg{
name: "z"
type_attr:"T"
}
attr{
name:"T"
type:"type"
allow_values{
list{
type:DT_HALF
type:DT_FLOAT
...
}
}
}
}

上面给出的是名称为Add的运算。这个运算的输入有两个,输出有一个,输入输出属性均指定了属性typr_attr,并且这个属性的值为T。在OpDef的attr的属性中。必须要出现名称(name)为T的属性。以上样例中,这个属性指定了运算输入输出允许的参数类型 (allowed_values)。


graph_def属性

graph_def属性主要记录了计算图中的节点信息。Tensorflow计算图中的一个节点对应Tensorflow中的一个运算。因为meta_info_def中已经包含所有运算的具体信息,所以graph_def属性指关注运算的连接结构。GraphDef主要包含了一个NodeDef类型的列表。以下代码给出GraphDef和NodeDef类型中包含的信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
message GraphDef{
#GraphDef的主要信息存储在node属性中,他记录了Tensorflow计算图上所有的节点信息。
repeated NodeDef node = 1;
VersionDef versions = 4; #主要储存了Tensorflow的版本号
};

message NodeDef{
#NodeDef类型中有一个名称属性name,他是一个节点的唯一标识符,在程序中,通过节点的名称来获得相应的节点。
string name = 1;

'''
op属性给出了该节点使用的Tensorflow运算方法的名称。
通过这个名称可以在TensorFlow计算图元图的meta_info_def属性中找到该运算的具体信息。
'''
string op = 2;

'''
input属性是一个字符串列表,他定义了运算的输入。每个字符串饿的取值格式为弄的:src_output
node部分给出节点名称,src_output表明了这个输入是指定节点的第几个输出。
src_output=0时可以省略src_output部分
'''
repeated string input = 3;

#制定了处理这个运算的设备,可以是本地或者远程的CPU or GPU。属性为空时自动选择
string device = 4;

#制定了和当前运算有关的配置信息
map<string, AttrValue> attr = 5;
};

下面列举test.ckpt.meta.json具体介绍graph_def属性

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
graph def {
node {
name: "v1"
op: "Variable"
attr {
key:"_output_shapes"
value {
list{ shape { dim { size: 1 } } }
}
}
}
attr {
key :"dtype"
value {
type: DT_FLOAT
}
}
...
}
node {
name :"add"
op :"Add"
input :"v1/read" #read指读取变量v1的值
input: "v2/read"
...
}
node {
name: "save/control_dependency" #指系统在完成tensorflow模型持久化过程中自动生成一个运算。
op:"Identity"
...
}
versions {
producer :9 #给出了文件使用时的Tensorflow版本号。
}
}

saver_def属性

1
2
3
4
5
6
7
8
9
10
11
12
13
14
message SaverDef {
string filename_tensor_name = 1;
string save_tensor_name = 2;
string restore_op_name = 3;
int32 max_to_keep = 4;
bool sharded = 5;
float keep_checkpoint_every_n_hours = 6;
enum CheckpointFormatVersion {
LEGACY = 0;
V1 = 1;
V2 = 2;
}
CheckpointFormatVersion version = 7;
}

下面给出test.ckpt.meta.json文件中saver_def属性的内容。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
saver_def {
filename_tensor_name :"save/Const:0”
#给出了保存文件的张量名,这个张量就是节点save/Const的第一个输出。

save_tensor_name :"save/control_dependency: 0
#给出了持久化模型运算所对应的节点名称

restore_op_name: "save/restore_all"
#和持久性模型运算对应的是加载模型的运算的名称

max_to_keep:5
keep_checkpoint_every_n_hours :10000.0
'''
上面两个属性设定了tf.train.Saver类清理之前保存的模型的策略。比如当max_to_keep为5时,第六次调用
saver.save时,第一次保存的模型就会被自动删除,通过设置keep_checkpoint_every_n_hours,每n小
时可以在max_to_keep的基础上保存一个模型
'''

collection_def属性

collection_def属性是一个集合名称到集合内容的映射,其中集合的名称为字符串,而集合内容为CollectionDef Protocol Buffer。以下代码给出CollectionDef类型的定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
message CollectionDef {
message Nodelist {
#用于维护计算图上的节点集合
repeated string value = 1;
}

message BytesList {
#维护字符串或者系列化之后的Procotol Buffer的集合。例如张量是通过Protocol Buffer表示的,而张量的集合是通过BytesList维护的。
repeated bytes value = 1 ;
}

message Int64List {
repeated int64 value = 1[packed = true];
}
message FloatList {
repeated float value = 1[packed = true] ;
}
message AnyList {
repeated google.protobuf.Any value= 1;
}
oneof kind {
NodeList node_list = 1;
BytesList bytes_lista = 2;
Int64List int64_list = 3;
Floatlist float_list = 4;
AnyList any_list = 5;
}
}

下面给出了test.ckpt.meta.json文件中的collection_def属性的内容

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
collection_def {
#可训练变量的集合
key: "trainable_variables"
value {
bytes_list {
value; "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
value: "\n\004v2:0\022\tv2/Assign\032\cv2/read:0"
}
}
}
collection_def {
#所有变量的集合
key: "variables"
value {
bytes_list {
value:"\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
value:"\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
}
}
}