tensorflow训练中出现nan问题的解决


Posted in Python onFebruary 10, 2018

深度学习中对于网络的训练是参数更新的过程,需要注意一种情况就是输入数据未做归一化时,如果前向传播结果已经是[0,0,0,1,0,0,0,0]这种形式,而真实结果是[1,0,0,0,0,0,0,0,0],此时由于得出的结论不惧有概率性,而是错误的估计值,此时反向传播会使得权重和偏置值变的无穷大,导致数据溢出,也就出现了nan的问题。

解决办法:

1、对输入数据进行归一化处理,如将输入的图片数据除以255将其转化成0-1之间的数据;

2、对于层数较多的情况,各层都做batch_nomorlization;

3、对设置Weights权重使用tf.truncated_normal(0, 0.01, [3,3,1,64])生成,同时值的均值为0,方差要小一些;

4、激活函数可以使用tanh;

5、减小学习率lr。

实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data',one_hot = True)

def add_layer(input_data,in_size, out_size,activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size,out_size]))
  Biases = tf.Variable(tf.zeros([1, out_size])+0.1)
  Wx_plus_b = tf.add(tf.matmul(input_data, Weights), Biases)
  if activation_function==None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  #return outputs#, Weights
  return {'outdata':outputs, 'w':Weights}

def get_accuracy(t_y):
#  global l1
#  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(l1['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  global prediction
  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  return accu

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

#l1 = add_layer(X, 784, 10, tf.nn.softmax)
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(l1['outdata']), reduction_indices= [1]))
#l1 = add_layer(X, 784, 1024, tf.nn.relu)

l1 = add_layer(X, 784, 1024, None)
prediction = add_layer(l1['outdata'], 1024, 10, tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(prediction['outdata']), reduction_indices= [1]))

optimizer = tf.train.GradientDescentOptimizer(0.000001)
train = optimizer.minimize(cross_entropy)


newW = tf.Variable(tf.random_normal([1024,10]))
newOut = tf.matmul(l1['outdata'],newW)
newSoftMax = tf.nn.softmax(newOut)

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  #print(sess.run(l1_Weights))
  for i in range(2):
    X_train, y_train = mnist.train.next_batch(1)
    X_train = X_train/255  #需要进行归一化处理
    #print(sess.run(l1['w'],feed_dict={X:X_train}))
    #print(sess.run(prediction['w'],feed_dict={X:X_train, Y:y_train}))
    #print(sess.run(l1['outdata'],feed_dict={X:X_train, Y:y_train}).shape)
    print(sess.run(prediction['outdata'],feed_dict={X:X_train, Y:y_train}))
    print(sess.run(newOut, feed_dict={X:X_train}))
    print(sess.run(newSoftMax, feed_dict={X:X_train}))
    print(y_train)
    #print(sess.run(l1['outdata'], feed_dict={X:X_train}))
    sess.run(train, feed_dict={X:X_train, Y:y_train})
    if i%100 == 0:
      #print(sess.run(cross_entropy, feed_dict={X:X_train, Y:y_train}))
      accuracy = get_accuracy(mnist.test.labels)
      print(sess.run(accuracy,feed_dict={X:mnist.test.images}))
    
    #if i%100==0:
    #print(sess.run(prediction, feed_dict={X:X_train}))
    #print(sess.run(cross_entropy, feed_dict={X:X_train,Y:y_train}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python版微信红包分配算法
May 04 Python
python下paramiko模块实现ssh连接登录Linux服务器
Jun 03 Python
Python基于pillow判断图片完整性的方法
Sep 18 Python
利用Python找出序列中出现最多的元素示例代码
Dec 08 Python
Django CBV与FBV原理及实例详解
Aug 12 Python
python数据持久存储 pickle模块的基本使用方法解析
Aug 30 Python
通过Python编写一个简单登录功能过程解析
Sep 04 Python
PyCharm下载和安装详细步骤
Dec 17 Python
python如何更新包
Jun 11 Python
如何解决pycharm调试报错的问题
Aug 06 Python
关于python中remove的一些坑小结
Jan 04 Python
Python编程中内置的NotImplemented类型的用法
Mar 23 Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
You might like
PHP基础学习小结
2011/04/17 PHP
PHP递归复制、移动目录的自定义函数分享
2014/11/18 PHP
yii实现图片上传及缩略图生成的方法
2014/12/04 PHP
php通过array_shift()函数移除数组第一个元素的方法
2015/03/18 PHP
基于php的CMS中展示文章类实例分析
2015/06/18 PHP
laravel 使用auth编写登录的方法
2019/09/30 PHP
jquery方法+js一般方法+js面向对象方法实现拖拽效果
2012/08/30 Javascript
js实现日期级联效果
2014/01/23 Javascript
JSON与XML优缺点对比分析
2015/07/17 Javascript
使用angularjs创建简单表格
2016/01/21 Javascript
jQuery过滤特殊字符及JS字符串转为数字
2016/05/26 Javascript
jQuery的层级查找方式分析
2016/06/16 Javascript
移动端点击图片放大特效PhotoSwipe.js插件实现
2016/08/25 Javascript
AngularJS使用带属性值的ng-app指令实现自定义模块自动加载的方法
2017/01/04 Javascript
JS组件系列之MVVM组件 vue 30分钟搞定前端增删改查
2017/04/28 Javascript
ReactNative页面跳转Navigator实现的示例代码
2017/08/02 Javascript
BootStrap Table实现server分页序号连续显示功能(当前页从上一页的结束序号开始)
2017/09/12 Javascript
JS三级联动代码格式实例详解
2019/12/30 Javascript
JS继承定义与使用方法简单示例
2020/02/19 Javascript
在Webpack中用url-loader处理图片和字体的问题
2020/04/28 Javascript
JavaScript 实现拖拽效果组件功能(兼容移动端)
2020/11/11 Javascript
nodejs+express最简易的连接数据库的方法
2020/12/23 NodeJs
使用python删除nginx缓存文件示例(python文件操作)
2014/03/26 Python
用pycharm开发django项目示例代码
2018/10/24 Python
对pandas的层次索引与取值的新方法详解
2018/11/06 Python
详解Django项目中模板标签及模板的继承与引用(网站中快速布置广告)
2019/03/27 Python
python实现简易淘宝购物
2019/11/22 Python
以SQLite和PySqlite为例来学习Python DB API
2020/02/05 Python
html5+css3之动画在webapp中的应用
2014/11/21 HTML / CSS
海蓝之谜(LA MER)澳大利亚官方商城:全球高端奢华护肤品牌
2017/10/27 全球购物
Shopee菲律宾:在线购买和出售
2019/11/25 全球购物
安全大检查实施方案
2014/02/22 职场文书
本科生求职信
2014/06/17 职场文书
学生实习证明模板汇总
2014/09/25 职场文书
财务统计员岗位职责
2015/04/14 职场文书
大学生心理健康活动总结
2015/05/08 职场文书