tensorflow训练中出现nan问题的解决


Posted in Python onFebruary 10, 2018

深度学习中对于网络的训练是参数更新的过程,需要注意一种情况就是输入数据未做归一化时,如果前向传播结果已经是[0,0,0,1,0,0,0,0]这种形式,而真实结果是[1,0,0,0,0,0,0,0,0],此时由于得出的结论不惧有概率性,而是错误的估计值,此时反向传播会使得权重和偏置值变的无穷大,导致数据溢出,也就出现了nan的问题。

解决办法:

1、对输入数据进行归一化处理,如将输入的图片数据除以255将其转化成0-1之间的数据;

2、对于层数较多的情况,各层都做batch_nomorlization;

3、对设置Weights权重使用tf.truncated_normal(0, 0.01, [3,3,1,64])生成,同时值的均值为0,方差要小一些;

4、激活函数可以使用tanh;

5、减小学习率lr。

实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data',one_hot = True)

def add_layer(input_data,in_size, out_size,activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size,out_size]))
  Biases = tf.Variable(tf.zeros([1, out_size])+0.1)
  Wx_plus_b = tf.add(tf.matmul(input_data, Weights), Biases)
  if activation_function==None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  #return outputs#, Weights
  return {'outdata':outputs, 'w':Weights}

def get_accuracy(t_y):
#  global l1
#  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(l1['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  global prediction
  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  return accu

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

#l1 = add_layer(X, 784, 10, tf.nn.softmax)
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(l1['outdata']), reduction_indices= [1]))
#l1 = add_layer(X, 784, 1024, tf.nn.relu)

l1 = add_layer(X, 784, 1024, None)
prediction = add_layer(l1['outdata'], 1024, 10, tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(prediction['outdata']), reduction_indices= [1]))

optimizer = tf.train.GradientDescentOptimizer(0.000001)
train = optimizer.minimize(cross_entropy)


newW = tf.Variable(tf.random_normal([1024,10]))
newOut = tf.matmul(l1['outdata'],newW)
newSoftMax = tf.nn.softmax(newOut)

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  #print(sess.run(l1_Weights))
  for i in range(2):
    X_train, y_train = mnist.train.next_batch(1)
    X_train = X_train/255  #需要进行归一化处理
    #print(sess.run(l1['w'],feed_dict={X:X_train}))
    #print(sess.run(prediction['w'],feed_dict={X:X_train, Y:y_train}))
    #print(sess.run(l1['outdata'],feed_dict={X:X_train, Y:y_train}).shape)
    print(sess.run(prediction['outdata'],feed_dict={X:X_train, Y:y_train}))
    print(sess.run(newOut, feed_dict={X:X_train}))
    print(sess.run(newSoftMax, feed_dict={X:X_train}))
    print(y_train)
    #print(sess.run(l1['outdata'], feed_dict={X:X_train}))
    sess.run(train, feed_dict={X:X_train, Y:y_train})
    if i%100 == 0:
      #print(sess.run(cross_entropy, feed_dict={X:X_train, Y:y_train}))
      accuracy = get_accuracy(mnist.test.labels)
      print(sess.run(accuracy,feed_dict={X:mnist.test.images}))
    
    #if i%100==0:
    #print(sess.run(prediction, feed_dict={X:X_train}))
    #print(sess.run(cross_entropy, feed_dict={X:X_train,Y:y_train}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python即时网络爬虫项目启动说明详解
Feb 23 Python
python实现跨excel的工作表sheet之间的复制方法
May 03 Python
wxPython的安装与使用教程
Aug 31 Python
python三引号输出方法
Feb 27 Python
Python向excel中写入数据的方法
May 05 Python
pycharm new project变成灰色的解决方法
Jun 27 Python
python3 中的字符串(单引号、双引号、三引号)以及字符串与数字的运算
Jul 18 Python
python mqtt 客户端的实现代码实例
Sep 25 Python
django实现类似触发器的功能
Nov 15 Python
python实现用类读取文件数据并计算矩形面积
Jan 18 Python
python 等差数列末项计算方式
May 03 Python
Python 如何实现文件自动去重
Jun 02 Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
You might like
Linux下进行MYSQL编程时插入中文乱码的解决方案
2007/03/15 PHP
比较简单的百度网盘文件直链PHP代码
2013/03/24 PHP
PHP截断标题且兼容utf8和gb2312编码
2013/09/22 PHP
微信支付扫码支付php版
2016/07/22 PHP
document.getElementById方法在Firefox与IE中的区别
2010/05/18 Javascript
拖动布局之保存布局页面cookies篇
2010/10/29 Javascript
收集的10个免费的jQuery相册
2011/02/26 Javascript
Javascript节点关系实例分析
2015/05/15 Javascript
JS验证IP,子网掩码,网关和MAC的方法
2015/07/02 Javascript
jquery+CSS3模拟Path2.0动画菜单效果代码
2015/08/31 Javascript
jquery ztree实现树的搜索功能
2016/02/25 Javascript
node模块机制与异步处理详解
2016/03/13 Javascript
JS匿名函数类生成方式实例分析
2016/11/26 Javascript
详解vue.js的事件处理器v-on:click
2017/06/27 Javascript
使用Node.js实现RESTful API的示例
2017/08/01 Javascript
JS实现定时任务每隔N秒请求后台setInterval定时和ajax请求问题
2017/10/15 Javascript
详解layui弹窗父子窗口之间传参数的方法
2018/01/16 Javascript
JavaScript实现邮箱后缀提示功能的示例代码
2018/12/13 Javascript
Element-UI中关于table表格的那些骚操作(小结)
2019/08/15 Javascript
OpenLayers实现图层切换控件
2020/09/25 Javascript
jenkins自动构建发布vue项目的方法步骤
2021/01/04 Vue.js
python传递参数方式小结
2015/04/17 Python
Python编程中使用Pillow来处理图像的基础教程
2015/11/20 Python
Python基于贪心算法解决背包问题示例
2017/11/27 Python
详解Python sys.argv使用方法
2019/05/10 Python
使用python爬取微博数据打造一颗“心”
2019/06/28 Python
python实现提取COCO,VOC数据集中特定的类
2020/03/10 Python
利用scikitlearn画ROC曲线实例
2020/07/02 Python
Pytest如何使用skip跳过执行测试
2020/08/13 Python
用HTML5.0制作网页的教程
2010/05/30 HTML / CSS
餐厅保洁员岗位职责
2015/04/10 职场文书
逃出克隆岛观后感
2015/06/09 职场文书
2015秋季田径运动会广播稿
2015/08/19 职场文书
学习《中小学教师职业道德规范》心得体会
2016/01/18 职场文书
[有人@你]你有一封绿色倡议书,请查收!
2019/07/18 职场文书
clear 万能清除浮动(clearfix:after)
2023/05/21 HTML / CSS