tensorflow训练中出现nan问题的解决


Posted in Python onFebruary 10, 2018

深度学习中对于网络的训练是参数更新的过程,需要注意一种情况就是输入数据未做归一化时,如果前向传播结果已经是[0,0,0,1,0,0,0,0]这种形式,而真实结果是[1,0,0,0,0,0,0,0,0],此时由于得出的结论不惧有概率性,而是错误的估计值,此时反向传播会使得权重和偏置值变的无穷大,导致数据溢出,也就出现了nan的问题。

解决办法:

1、对输入数据进行归一化处理,如将输入的图片数据除以255将其转化成0-1之间的数据;

2、对于层数较多的情况,各层都做batch_nomorlization;

3、对设置Weights权重使用tf.truncated_normal(0, 0.01, [3,3,1,64])生成,同时值的均值为0,方差要小一些;

4、激活函数可以使用tanh;

5、减小学习率lr。

实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data',one_hot = True)

def add_layer(input_data,in_size, out_size,activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size,out_size]))
  Biases = tf.Variable(tf.zeros([1, out_size])+0.1)
  Wx_plus_b = tf.add(tf.matmul(input_data, Weights), Biases)
  if activation_function==None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  #return outputs#, Weights
  return {'outdata':outputs, 'w':Weights}

def get_accuracy(t_y):
#  global l1
#  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(l1['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  global prediction
  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  return accu

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

#l1 = add_layer(X, 784, 10, tf.nn.softmax)
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(l1['outdata']), reduction_indices= [1]))
#l1 = add_layer(X, 784, 1024, tf.nn.relu)

l1 = add_layer(X, 784, 1024, None)
prediction = add_layer(l1['outdata'], 1024, 10, tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(prediction['outdata']), reduction_indices= [1]))

optimizer = tf.train.GradientDescentOptimizer(0.000001)
train = optimizer.minimize(cross_entropy)


newW = tf.Variable(tf.random_normal([1024,10]))
newOut = tf.matmul(l1['outdata'],newW)
newSoftMax = tf.nn.softmax(newOut)

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  #print(sess.run(l1_Weights))
  for i in range(2):
    X_train, y_train = mnist.train.next_batch(1)
    X_train = X_train/255  #需要进行归一化处理
    #print(sess.run(l1['w'],feed_dict={X:X_train}))
    #print(sess.run(prediction['w'],feed_dict={X:X_train, Y:y_train}))
    #print(sess.run(l1['outdata'],feed_dict={X:X_train, Y:y_train}).shape)
    print(sess.run(prediction['outdata'],feed_dict={X:X_train, Y:y_train}))
    print(sess.run(newOut, feed_dict={X:X_train}))
    print(sess.run(newSoftMax, feed_dict={X:X_train}))
    print(y_train)
    #print(sess.run(l1['outdata'], feed_dict={X:X_train}))
    sess.run(train, feed_dict={X:X_train, Y:y_train})
    if i%100 == 0:
      #print(sess.run(cross_entropy, feed_dict={X:X_train, Y:y_train}))
      accuracy = get_accuracy(mnist.test.labels)
      print(sess.run(accuracy,feed_dict={X:mnist.test.images}))
    
    #if i%100==0:
    #print(sess.run(prediction, feed_dict={X:X_train}))
    #print(sess.run(cross_entropy, feed_dict={X:X_train,Y:y_train}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python入门篇之函数
Oct 20 Python
浅析Python中元祖、列表和字典的区别
Aug 17 Python
python web.py开发httpserver解决跨域问题实例解析
Feb 12 Python
Python面向对象之静态属性、类方法与静态方法分析
Aug 24 Python
利用Python将文本中的中英文分离方法
Oct 31 Python
python简单实现AES加密和解密
Mar 28 Python
python实现知乎高颜值图片爬取
Aug 12 Python
Python如何在循环内使用list.remove()
Jun 01 Python
Python如何合并多个字典或映射
Jul 24 Python
Selenium 安装和简单使用的实现
Dec 04 Python
关于python中remove的一些坑小结
Jan 04 Python
tensorflow中的梯度求解及梯度裁剪操作
May 26 Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
You might like
经典的星际争霸,满是回忆的BGM
2020/04/09 星际争霸
PHP 数字左侧自动补0
2008/03/31 PHP
PHP 冒泡排序 二分查找 顺序查找 二维数组排序算法函数的详解
2013/06/25 PHP
php将url地址转化为完整的a标签链接代码(php为url地址添加a标签)
2014/01/17 PHP
php使用curl模拟多线程实现批处理功能示例
2019/07/25 PHP
CLASS_CONFUSION JS混淆 全源码
2007/12/12 Javascript
JavaScript 验证浏览器是否支持javascript的方法小结
2009/05/17 Javascript
jquery下jstree简单应用 - v1.0
2011/04/14 Javascript
Javascript中3个需要注意的运算符
2015/04/02 Javascript
javascript实现连续赋值
2015/08/10 Javascript
果断收藏9个Javascript代码高亮脚本
2016/01/06 Javascript
微信小程序 条件渲染详解
2016/10/09 Javascript
微信小程序 数据交互与渲染实例详解
2017/01/21 Javascript
angular实现表单验证及提交功能
2017/02/01 Javascript
几行js代码实现自适应
2017/02/24 Javascript
JS实现复选框的全选和批量删除功能
2017/04/05 Javascript
原生JS实现图片懒加载(lazyload)实例
2017/06/13 Javascript
基于bootstrap写的一点localStorage本地储存
2017/11/21 Javascript
JavaScript累加、迭代、穷举、递归等常用算法实例小结
2018/05/08 Javascript
npm 下载指定版本的组件方法
2018/05/17 Javascript
angular 实现同步验证器跨字段验证的方法
2019/04/11 Javascript
vue实现滑动到底部加载更多效果
2020/10/27 Javascript
vue制作toast组件npm包示例代码
2020/10/29 Javascript
[01:02:09]Liquid vs TNC 2019国际邀请赛淘汰赛 胜者组 BO3 第二场 8.21
2020/07/19 DOTA
python常规方法实现数组的全排列
2015/03/17 Python
Python3 pip3 list 出现 DEPRECATION 警告的解决方法
2019/02/16 Python
python把1变成01的步骤总结
2019/02/27 Python
python 调试冷知识(小结)
2019/11/11 Python
Python %r和%s区别代码实例解析
2020/04/03 Python
Python 获取异常(Exception)信息的几种方法
2020/12/29 Python
美国家居装饰和豪华家具购物网站:One Kings Lane
2018/12/24 全球购物
Gretna Green中文官网:苏格兰格林小镇
2019/10/16 全球购物
JBL加拿大官方商店:扬声器、耳机等
2020/10/23 全球购物
护士优质服务演讲稿
2014/08/26 职场文书
单位个人查摆问题及整改措施
2014/10/28 职场文书
会计求职自荐信
2015/03/26 职场文书