Tensorflow持久化原理及数据格式
1 2 3 4 5 6 7
| message MetaGraphDef{ MeatInfoDef meta_info_def = 1; GraphDef graph_def = 2; SaverDef saver_def = 3; map<string,CollectionDef> collection_def = 4; map<string,SignatureDef> signature_def = 5; }
|
1 2 3 4 5 6 7 8 9
| import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name ="v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name ="v2") result1 = v1 + v2
saver = tf.train.Saver()
saver.export_meta_graph("test/test.ckpt.json", as_text=True)
|
1 2 3 4 5 6 7
| message MetaInfoDef{ string meta_graph_version = 1; OpList stripped_op_list = 2; google.protobuf.Any any_info = 3; repeated string tags = 4; }
|
OpList类型是一个OpDef类型的列表,以下代码给出OpDef类型的定义:
1 2 3 4 5 6 7 8 9 10 11 12 13
| message opDef{ string name = 1; repeated ArgDef input_arg = 2; repeated ArgDef output_arg =3; repeated AttrDef attr = 4; string summary = 5; string description = 6; OpDeprecation deprecation = 8; bool is_commutative = 18; bool is_aggregate = 16 bool is_stateful = 17; bool allows_uninitialized_input = 19; };
|
下面给出一个比较有代表性的运算来辅助说明OpDef的数据结构。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
| op { name: "Add" input_arg{ name: "x" type_attr:"T" } input_arg{ name: "y" type_attr:"T" } output_arg{ name: "z" type_attr:"T" } attr{ name:"T" type:"type" allow_values{ list{ type:DT_HALF type:DT_FLOAT ... } } } }
|
上面给出的是名称为Add的运算。这个运算的输入有两个,输出有一个,输入输出属性均指定了属性typr_attr,并且这个属性的值为T。在OpDef的attr的属性中。必须要出现名称(name)为T的属性。以上样例中,这个属性指定了运算输入输出允许的参数类型 (allowed_values)。
graph_def属性
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
| message GraphDef{ repeated NodeDef node = 1; VersionDef versions = 4; };
message NodeDef{ string name = 1;
''' op属性给出了该节点使用的Tensorflow运算方法的名称。 通过这个名称可以在TensorFlow计算图元图的meta_info_def属性中找到该运算的具体信息。 ''' string op = 2; ''' input属性是一个字符串列表,他定义了运算的输入。每个字符串饿的取值格式为弄的:src_output node部分给出节点名称,src_output表明了这个输入是指定节点的第几个输出。 src_output=0时可以省略src_output部分 ''' repeated string input = 3; string device = 4;
map<string, AttrValue> attr = 5; };
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| graph def { node { name: "v1" op: "Variable" attr { key:"_output_shapes" value { list{ shape { dim { size: 1 } } } } } } attr { key :"dtype" value { type: DT_FLOAT } } ... } node { name :"add" op :"Add" input :"v1/read" input: "v2/read" ... } node { name: "save/control_dependency" op:"Identity" ... } versions { producer :9 } }
|
saver_def属性
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| message SaverDef { string filename_tensor_name = 1; string save_tensor_name = 2; string restore_op_name = 3; int32 max_to_keep = 4; bool sharded = 5; float keep_checkpoint_every_n_hours = 6; enum CheckpointFormatVersion { LEGACY = 0; V1 = 1; V2 = 2; } CheckpointFormatVersion version = 7; }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| saver_def { filename_tensor_name :"save/Const:0” #给出了保存文件的张量名,这个张量就是节点save/Const的第一个输出。
save_tensor_name :"save/control_dependency: 0” restore_op_name: "save/restore_all" max_to_keep:5 keep_checkpoint_every_n_hours :10000.0 ''' 上面两个属性设定了tf.train.Saver类清理之前保存的模型的策略。比如当max_to_keep为5时,第六次调用 saver.save时,第一次保存的模型就会被自动删除,通过设置keep_checkpoint_every_n_hours,每n小 时可以在max_to_keep的基础上保存一个模型 '''
|
collection_def属性
collection_def属性是一个集合名称到集合内容的映射,其中集合的名称为字符串,而集合内容为CollectionDef Protocol Buffer。以下代码给出CollectionDef类型的定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| message CollectionDef { message Nodelist { repeated string value = 1; }
message BytesList { repeated bytes value = 1 ; }
message Int64List { repeated int64 value = 1[packed = true]; } message FloatList { repeated float value = 1[packed = true] ; } message AnyList { repeated google.protobuf.Any value= 1; } oneof kind { NodeList node_list = 1; BytesList bytes_lista = 2; Int64List int64_list = 3; Floatlist float_list = 4; AnyList any_list = 5; } }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| collection_def { key: "trainable_variables" value { bytes_list { value; "\n\004v1:0\022\tv1/Assign\032\tv1/read:0" value: "\n\004v2:0\022\tv2/Assign\032\cv2/read:0" } } } collection_def { key: "variables" value { bytes_list { value:"\n\004v1:0\022\tv1/Assign\032\tv1/read:0" value:"\n\004v2:0\022\tv2/Assign\032\tv2/read:0" } } }
|