时间:2021-05-22
一、TensorFlow常规模型加载方法
保存模型
tf.train.Saver()类,.save(sess, ckpt文件目录)方法
参数名称 功能说明 默认值 var_list Saver中存储变量集合 全局变量集合 reshape 加载时是否恢复变量形状 True sharded 是否将变量轮循放在所有设备上 True max_to_keep 保留最近检查点个数 5 restore_sequentially 是否按顺序恢复变量,模型较大时顺序恢复内存消耗小 True
var_list是字典形式{变量名字符串: 变量符号},相对应的restore也根据同样形式的字典将ckpt中的字符串对应的变量加载给程序中的符号。
如果Saver给定了字典作为加载方式,则按照字典来,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否则每个变量寻找自己的name属性在ckpt中的对应值进行加载。
加载模型
当我们基于checkpoint文件(ckpt)加载参数时,实际上我们使用Saver.restore取代了initializer的初始化
checkpoint文件会记录保存信息,通过它可以定位最新保存的模型:
ckpt = tf.train.get_checkpoint_state('./model/')print(ckpt.model_checkpoint_path).meta文件保存了当前图结构
.index文件保存了当前参数名
.data文件保存了当前参数值
tf.train.import_meta_graph函数给出model.ckpt-n.meta的路径后会加载图结构,并返回saver对象
ckpt = tf.train.get_checkpoint_state('./model/')tf.train.Saver函数会返回加载默认图的saver对象,saver对象初始化时可以指定变量映射方式,根据名字映射变量(『TensorFlow』滑动平均)
saver.restore函数给出model.ckpt-n的路径后会自动寻找参数名-值文件进行加载
saver.restore(sess,'./model/model.ckpt-0')saver.restore(sess,ckpt.model_checkpoint_path)1.不加载图结构,只加载参数
由于实际上我们参数保存的都是Variable变量的值,所以其他的参数值(例如batch_size)等,我们在restore时可能希望修改,但是图结构在train时一般就已经确定了,所以我们可以使用tf.Graph().as_default()新建一个默认图(建议使用上下文环境),利用这个新图修改和变量无关的参值大小,从而达到目的。
'''使用原网络保存的模型加载到自己重新定义的图上可以使用python变量名加载模型,也可以使用节点名'''import AlexNet as Netimport AlexNet_train as trainimport randomimport tensorflow as tf IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg' with tf.Graph().as_default() as g: x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3]) y = Net.inference_1(x, N_CLASS=5, train=False) with tf.Session() as sess: # 程序前面得有 Variable 供 save or restore 才不报错 # 否则会提示没有可保存的变量 saver = tf.train.Saver() ckpt = tf.train.get_checkpoint_state('./model/') img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read() img = sess.run(tf.expand_dims(tf.image.resize_images( tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0)) if ckpt and ckpt.model_checkpoint_path: print(ckpt.model_checkpoint_path) saver.restore(sess,'./model/model.ckpt-0') global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] res = sess.run(y, feed_dict={x: img}) print(global_step,sess.run(tf.argmax(res,1)))2.加载图结构和参数
注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。
3.简化版本
二、TensorFlow二进制模型加载方法
这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
TensorFlow模型保存/载入我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一
一、TensorFlow模型保存和提取方法1.TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对
使用tensorflow训练模型时,我们可以使用tensorflow自带的Save模块tf.train.Saver()来保存模型,使用方式很简单就是在训练完模型
1.介绍当我们使用pytorch来构建网络框架的时候,也会遇到和tensorflow(tensorflow__init__、build和call小结)类似的情况
TensorFlow保存模型代码importtensorflowastffromtensorflow.python.frameworkimportgraph_u