tensorflow中固定部分参数训练和只训练部分参数的示例分析-创新互联
这篇文章主要为大家展示了“tensorflow中固定部分参数训练和只训练部分参数的示例分析”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“tensorflow中固定部分参数训练和只训练部分参数的示例分析”这篇文章吧。
成都创新互联公司作为成都网站建设公司,专注重庆网站建设公司、网站设计,有关成都企业网站定制方案、改版、费用等问题,行业涉及PE包装袋等多个领域,已为上千家企业服务,得到了客户的尊重与认可。在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练。
1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如
def inference(input_): """ this is where you put your graph. the following is just an example. """ conv1 = tf.layers.conv2d(input_) conv2 = tf.layers.conv2d(conv1) return conv2 input_ = tf.placeholder() output = inference(input_) ... calculate_loss_op = ... train_op = ... ... with tf.Session() as sess: sess.run([loss, train_op], feed_dict={input_: train_data}) if validation == True: sess.run([loss], feed_dict={input_: validate_date})
这种方式很简单,也很直接了然。
2.但是,如果处理的数据量很大的时候,使用 tf.placeholder() 来载入数据会严重地拖慢训练的进度,因此,常用tfrecords文件来读取数据。
此时,很容易想到,将不同的值传入inference()函数中进行计算。
train_batch, label_batch = decode_train() val_train_batch, val_label_batch = decode_validation() train_result = inference(train_batch) ... loss = .. train_op = ... ... if validation == True: val_result = inference(val_train_batch) val_loss = .. with tf.Session() as sess: sess.run([loss, train_op]) if validation == True: sess.run([val_result, val_loss])
这种方式看似能够直接调用inference()来对验证数据进行前向传播计算,但是,实则会在原图上添加上许多新的结点,这些结点的参数都是需要重新初始化的,也是就是说,验证的时候并不是使用训练的权重。
3.用一个tf.placeholder来控制是否训练、验证。
def inference(input_): ... ... ... return inference_result train_batch, label_batch = decode_train() val_batch, val_label = decode_validation() is_training = tf.placeholder(tf.bool, shape=()) x = tf.cond(is_training, lambda: train_batch, lambda: val_batch) y = tf.cond(is_training, lambda: train_label, lambda: val_label) logits = inference(x) loss = cal_loss(logits, y) train_op = optimize(loss) with tf.Session() as sess: loss, _ = sess.run([loss, train_op], feed_dict={is_training: True}) if validation == True: loss = sess.run(loss, feed_dict={is_training: False})
使用这种方式就可以在一个大图里创建一个分支条件,从而通过控制placeholder来控制是否进行验证。
以上是“tensorflow中固定部分参数训练和只训练部分参数的示例分析”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注创新互联成都网站设计公司行业资讯频道!
另外有需要云服务器可以了解下创新互联scvps.cn,海内外云服务器15元起步,三天无理由+7*72小时售后在线,公司持有idc许可证,提供“云服务器、裸金属服务器、高防服务器、香港服务器、美国服务器、虚拟主机、免备案服务器”等云主机租用服务以及企业上云的综合解决方案,具有“安全稳定、简单易用、服务可用性高、性价比高”等特点与优势,专为企业上云打造定制,能够满足用户丰富、多元化的应用场景需求。
网站栏目:tensorflow中固定部分参数训练和只训练部分参数的示例分析-创新互联
网页地址:http://scyanting.com/article/dggieh.html