tensorflow 固定部分参数训练,只训练部分参数的实例

时间:2021-05-22

在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练。

1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如

def inference(input_): """ this is where you put your graph. the following is just an example. """ conv1 = tf.layers.conv2d(input_) conv2 = tf.layers.conv2d(conv1) return conv2 input_ = tf.placeholder()output = inference(input_)...calculate_loss_op = ...train_op = ...... with tf.Session() as sess: sess.run([loss, train_op], feed_dict={input_: train_data}) if validation == True: sess.run([loss], feed_dict={input_: validate_date})

这种方式很简单,也很直接了然。

2.但是,如果处理的数据量很大的时候,使用 tf.placeholder() 来载入数据会严重地拖慢训练的进度,因此,常用tfrecords文件来读取数据。

此时,很容易想到,将不同的值传入inference()函数中进行计算。

train_batch, label_batch = decode_train()val_train_batch, val_label_batch = decode_validation() train_result = inference(train_batch)...loss = ..train_op = ...... if validation == True: val_result = inference(val_train_batch) val_loss = .. with tf.Session() as sess: sess.run([loss, train_op]) if validation == True: sess.run([val_result, val_loss])

这种方式看似能够直接调用inference()来对验证数据进行前向传播计算,但是,实则会在原图上添加上许多新的结点,这些结点的参数都是需要重新初始化的,也是就是说,验证的时候并不是使用训练的权重。

3.用一个tf.placeholder来控制是否训练、验证。

def inference(input_): ... ... ... return inference_result train_batch, label_batch = decode_train()val_batch, val_label = decode_validation() is_training = tf.placeholder(tf.bool, shape=()) x = tf.cond(is_training, lambda: train_batch, lambda: val_batch)y = tf.cond(is_training, lambda: train_label, lambda: val_label) logits = inference(x)loss = cal_loss(logits, y)train_op = optimize(loss) with tf.Session() as sess: loss, _ = sess.run([loss, train_op], feed_dict={is_training: True}) if validation == True: loss = sess.run(loss, feed_dict={is_training: False})

使用这种方式就可以在一个大图里创建一个分支条件,从而通过控制placeholder来控制是否进行验证。

以上这篇tensorflow 固定部分参数训练,只训练部分参数的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。

相关文章