tensorflow训练中出现nan问题的解决


Posted in Python onFebruary 10, 2018

深度学习中对于网络的训练是参数更新的过程,需要注意一种情况就是输入数据未做归一化时,如果前向传播结果已经是[0,0,0,1,0,0,0,0]这种形式,而真实结果是[1,0,0,0,0,0,0,0,0],此时由于得出的结论不惧有概率性,而是错误的估计值,此时反向传播会使得权重和偏置值变的无穷大,导致数据溢出,也就出现了nan的问题。

解决办法:

1、对输入数据进行归一化处理,如将输入的图片数据除以255将其转化成0-1之间的数据;

2、对于层数较多的情况,各层都做batch_nomorlization;

3、对设置Weights权重使用tf.truncated_normal(0, 0.01, [3,3,1,64])生成,同时值的均值为0,方差要小一些;

4、激活函数可以使用tanh;

5、减小学习率lr。

实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data',one_hot = True)

def add_layer(input_data,in_size, out_size,activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size,out_size]))
  Biases = tf.Variable(tf.zeros([1, out_size])+0.1)
  Wx_plus_b = tf.add(tf.matmul(input_data, Weights), Biases)
  if activation_function==None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  #return outputs#, Weights
  return {'outdata':outputs, 'w':Weights}

def get_accuracy(t_y):
#  global l1
#  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(l1['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  global prediction
  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  return accu

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

#l1 = add_layer(X, 784, 10, tf.nn.softmax)
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(l1['outdata']), reduction_indices= [1]))
#l1 = add_layer(X, 784, 1024, tf.nn.relu)

l1 = add_layer(X, 784, 1024, None)
prediction = add_layer(l1['outdata'], 1024, 10, tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(prediction['outdata']), reduction_indices= [1]))

optimizer = tf.train.GradientDescentOptimizer(0.000001)
train = optimizer.minimize(cross_entropy)


newW = tf.Variable(tf.random_normal([1024,10]))
newOut = tf.matmul(l1['outdata'],newW)
newSoftMax = tf.nn.softmax(newOut)

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  #print(sess.run(l1_Weights))
  for i in range(2):
    X_train, y_train = mnist.train.next_batch(1)
    X_train = X_train/255  #需要进行归一化处理
    #print(sess.run(l1['w'],feed_dict={X:X_train}))
    #print(sess.run(prediction['w'],feed_dict={X:X_train, Y:y_train}))
    #print(sess.run(l1['outdata'],feed_dict={X:X_train, Y:y_train}).shape)
    print(sess.run(prediction['outdata'],feed_dict={X:X_train, Y:y_train}))
    print(sess.run(newOut, feed_dict={X:X_train}))
    print(sess.run(newSoftMax, feed_dict={X:X_train}))
    print(y_train)
    #print(sess.run(l1['outdata'], feed_dict={X:X_train}))
    sess.run(train, feed_dict={X:X_train, Y:y_train})
    if i%100 == 0:
      #print(sess.run(cross_entropy, feed_dict={X:X_train, Y:y_train}))
      accuracy = get_accuracy(mnist.test.labels)
      print(sess.run(accuracy,feed_dict={X:mnist.test.images}))
    
    #if i%100==0:
    #print(sess.run(prediction, feed_dict={X:X_train}))
    #print(sess.run(cross_entropy, feed_dict={X:X_train,Y:y_train}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解析xml模块封装代码
Feb 07 Python
wxPython窗口的继承机制实例分析
Sep 28 Python
Python中os.path用法分析
Jan 15 Python
django 常用orm操作详解
Sep 13 Python
Python实现的简单读写csv文件操作示例
Jul 12 Python
pycharm重置设置,恢复默认设置的方法
Oct 22 Python
python实现控制电脑鼠标和键盘,登录QQ的方法示例
Jul 06 Python
django 控制页面跳转的例子
Aug 06 Python
python实现异常信息堆栈输出到日志文件
Dec 26 Python
python实现删除列表中某个元素的3种方法
Jan 15 Python
如何基于Python Matplotlib实现网格动画
Jul 20 Python
详解Python中*args和**kwargs的使用
Apr 07 Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
You might like
PHP中在数据库中保存Checkbox数据(2)
2006/10/09 PHP
PHP+MySQL之Insert Into数据插入用法分析
2015/09/27 PHP
jQuery 1.8 Release版本发布了
2012/08/14 Javascript
JS动态获取当前时间,并写到特定的区域
2013/05/03 Javascript
js 实现 input type="file" 文件上传示例代码
2013/08/07 Javascript
jquery中的查找parents与closest方法之间的区别
2013/12/02 Javascript
JQuery中的html()、text()、val()区别示例介绍
2014/09/01 Javascript
轻松创建nodejs服务器(8):非阻塞是如何实现的
2014/12/18 NodeJs
JavaScript汉诺塔问题解决方法
2015/04/21 Javascript
JS根据生日算年龄的方法
2015/05/05 Javascript
jQuery animate和CSS3相结合实现缓动追逐效果附源码下载
2016/04/18 Javascript
如何使用jquery修改css中带有!important的样式属性
2016/04/28 Javascript
自己动手制作基于jQuery的Web页面加载进度条插件
2016/06/03 Javascript
JS实现iframe编辑器光标位置插入内容的方法(兼容IE和Firefox)
2016/06/24 Javascript
select下拉框插件jquery.editable-select详解
2017/01/22 Javascript
BootStrap selectpicker后台动态绑定数据
2017/06/01 Javascript
深入学习js函数的隐式参数 arguments 和 this
2019/06/24 Javascript
Vue绑定用户接口实现代码示例
2020/11/04 Javascript
angular8.5集成TinyMce5的使用和详细配置(推荐)
2020/11/16 Javascript
Python3 中把txt数据文件读入到矩阵中的方法
2018/04/27 Python
使用python爬虫获取黄金价格的核心代码
2018/06/13 Python
numpy linalg模块的具体使用方法
2019/05/26 Python
使用Python求解带约束的最优化问题详解
2020/02/11 Python
Python编程快速上手——疯狂填词程序实现方法分析
2020/02/29 Python
Python运算符+与+=的方法实例
2021/02/18 Python
HTML5之HTML元素扩展(下)—增强的Form表单元素值得关注
2013/01/31 HTML / CSS
草莓网美国官网:Strawberrynet USA
2016/12/11 全球购物
元旦获奖感言
2014/03/08 职场文书
文明之星事迹材料
2014/05/09 职场文书
家长给学校的建议书
2014/05/15 职场文书
安全承诺书格式
2014/05/21 职场文书
资源环境与城乡规划管理专业自荐书
2014/09/26 职场文书
华清池导游词
2015/02/02 职场文书
2015新生加入学生会自荐书
2015/03/24 职场文书
2016年暑期见闻作文
2015/11/25 职场文书
python 实现图片特效处理
2022/04/03 Python