首页 > 编程 > Python > 正文

详解TensorFlow查看ckpt中变量的几种方法

2020-02-15 21:54:34
字体:
来源:转载
供稿:网友

查看TensorFlow中checkpoint内变量的几种方法

查看ckpt中变量的方法有三种:

    在有model的情况下,使用tf.train.Saver进行restore 使用tf.train.NewCheckpointReader直接读取ckpt文件,这种方法不需要model。 使用tools里的freeze_graph来读取ckpt

注意:

    如果模型保存为.ckpt的文件,则使用该文件就可以查看.ckpt文件里的变量。ckpt路径为 model.ckpt 如果模型保存为.ckpt-xxx-data (图结构)、.ckpt-xxx.index (参数名)、.ckpt-xxx-meta (参数值)文件,则需要同时拥有这三个文件才行。并且ckpt的路径为 model.ckpt-xxx

1. 基于model来读取ckpt文件里的变量

1.首先建立model
2.从ckpt中恢复变量

with tf.Graph().as_default() as g:   #建立model  images, labels = cifar10.inputs(eval_data=eval_data)   logits = cifar10.inference(images)   top_k_op = tf.nn.in_top_k(logits, labels, 1)   #从ckpt中恢复变量  sess = tf.Session()  saver = tf.train.Saver() #saver = tf.train.Saver(...variables...) # 恢复部分变量时,只需要在Saver里指定要恢复的变量  save_path = 'ckpt的路径'  saver.restore(sess, save_path) # 从ckpt中恢复变量

注意:基于model来读取ckpt中变量时,model和ckpt必须匹配。

2. 使用tf.train.NewCheckpointReader直接读取ckpt文件里的变量,使用tools.inspect_checkpoint里的print_tensors_in_checkpoint_file函数打印ckpt里的东西

#使用NewCheckpointReader来读取ckpt里的变量from tensorflow.python import pywrap_tensorflowcheckpoint_path = os.path.join(model_dir, "model.ckpt")reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) #tf.train.NewCheckpointReadervar_to_shape_map = reader.get_variable_to_shape_map()for key in var_to_shape_map:  print("tensor_name: ", key)  #print(reader.get_tensor(key))
#使用print_tensors_in_checkpoint_file打印ckpt里的内容from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_fileprint_tensors_in_checkpoint_file(file_name, #ckpt文件名字                 tensor_name, # 如果为None,则默认为ckpt里的所有变量                 all_tensors, # bool 是否打印所有的tensor,这里打印出的是tensor的值,一般不推荐这里设置为False                 all_tensor_names) # bool 是否打印所有的tensor的name#上面的打印ckpt的内部使用的是pywrap_tensorflow.NewCheckpointReader所以,掌握NewCheckpointReader才是王道

3.使用tools里的freeze_graph来读取ckpt

from tensorflow.python.tools import freeze_graphfreeze_graph(input_graph, #=some_graph_def.pb       input_saver,        input_binary,        input_checkpoint, #=model.ckpt       output_node_names, #=softmax       restore_op_name,        filename_tensor_name,        output_graph, #='./tmp/frozen_graph.pb'       clear_devices,        initializer_nodes,        variable_names_whitelist='',        variable_names_blacklist='',        input_meta_graph=None,        input_saved_model_dir=None,        saved_model_tags='serve',        checkpoint_version=2)#freeze_graph_test.py讲述了怎么使用freeze_grapg。            
发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表