tensorflow训练中出现nan问题的解决


Posted in Python onFebruary 10, 2018

深度学习中对于网络的训练是参数更新的过程,需要注意一种情况就是输入数据未做归一化时,如果前向传播结果已经是[0,0,0,1,0,0,0,0]这种形式,而真实结果是[1,0,0,0,0,0,0,0,0],此时由于得出的结论不惧有概率性,而是错误的估计值,此时反向传播会使得权重和偏置值变的无穷大,导致数据溢出,也就出现了nan的问题。

解决办法:

1、对输入数据进行归一化处理,如将输入的图片数据除以255将其转化成0-1之间的数据;

2、对于层数较多的情况,各层都做batch_nomorlization;

3、对设置Weights权重使用tf.truncated_normal(0, 0.01, [3,3,1,64])生成,同时值的均值为0,方差要小一些;

4、激活函数可以使用tanh;

5、减小学习率lr。

实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data',one_hot = True)

def add_layer(input_data,in_size, out_size,activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size,out_size]))
  Biases = tf.Variable(tf.zeros([1, out_size])+0.1)
  Wx_plus_b = tf.add(tf.matmul(input_data, Weights), Biases)
  if activation_function==None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  #return outputs#, Weights
  return {'outdata':outputs, 'w':Weights}

def get_accuracy(t_y):
#  global l1
#  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(l1['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  global prediction
  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  return accu

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

#l1 = add_layer(X, 784, 10, tf.nn.softmax)
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(l1['outdata']), reduction_indices= [1]))
#l1 = add_layer(X, 784, 1024, tf.nn.relu)

l1 = add_layer(X, 784, 1024, None)
prediction = add_layer(l1['outdata'], 1024, 10, tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(prediction['outdata']), reduction_indices= [1]))

optimizer = tf.train.GradientDescentOptimizer(0.000001)
train = optimizer.minimize(cross_entropy)


newW = tf.Variable(tf.random_normal([1024,10]))
newOut = tf.matmul(l1['outdata'],newW)
newSoftMax = tf.nn.softmax(newOut)

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  #print(sess.run(l1_Weights))
  for i in range(2):
    X_train, y_train = mnist.train.next_batch(1)
    X_train = X_train/255  #需要进行归一化处理
    #print(sess.run(l1['w'],feed_dict={X:X_train}))
    #print(sess.run(prediction['w'],feed_dict={X:X_train, Y:y_train}))
    #print(sess.run(l1['outdata'],feed_dict={X:X_train, Y:y_train}).shape)
    print(sess.run(prediction['outdata'],feed_dict={X:X_train, Y:y_train}))
    print(sess.run(newOut, feed_dict={X:X_train}))
    print(sess.run(newSoftMax, feed_dict={X:X_train}))
    print(y_train)
    #print(sess.run(l1['outdata'], feed_dict={X:X_train}))
    sess.run(train, feed_dict={X:X_train, Y:y_train})
    if i%100 == 0:
      #print(sess.run(cross_entropy, feed_dict={X:X_train, Y:y_train}))
      accuracy = get_accuracy(mnist.test.labels)
      print(sess.run(accuracy,feed_dict={X:mnist.test.images}))
    
    #if i%100==0:
    #print(sess.run(prediction, feed_dict={X:X_train}))
    #print(sess.run(cross_entropy, feed_dict={X:X_train,Y:y_train}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
tornado框架blog模块分析与使用
Nov 21 Python
在Python上基于Markov链生成伪随机文本的教程
Apr 17 Python
Python实现简单的文件传输与MySQL备份的脚本分享
Jan 03 Python
教你用Type Hint提高Python程序开发效率
Aug 08 Python
Python探索之ModelForm代码详解
Oct 26 Python
Pandas GroupBy对象 索引与迭代方法
Nov 16 Python
python进行文件对比的方法
Dec 24 Python
TensorFlow卷积神经网络之使用训练好的模型识别猫狗图片
Mar 14 Python
django商品分类及商品数据建模实例详解
Jan 03 Python
学生如何注册Pycharm专业版以及pycharm的安装
Sep 24 Python
安装Anaconda3及使用Jupyter的方法
Oct 27 Python
用Python创建简易网站图文教程
Jun 11 Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
You might like
ajax 的post方法实例(带循环)
2011/07/04 PHP
PHP 面向对象详解
2012/09/13 PHP
PHP整合七牛实现上传文件
2015/07/03 PHP
js对数字的格式化使用说明
2011/01/12 Javascript
常用的jquery模板插件——jQuery Boilerplate介绍
2014/09/23 Javascript
jQuery中extend函数详解
2015/02/13 Javascript
jquery使用each方法遍历json格式数据实例
2015/05/18 Javascript
使用pcs api往免费的百度网盘上传下载文件的方法
2016/03/17 Javascript
JS正则表达式修饰符中multiline(/m)用法分析
2016/12/27 Javascript
jQuery元素选择器实例代码
2017/02/06 Javascript
微信小程序 标签传入数据
2017/05/08 Javascript
JavaScript面向对象精要(上部)
2017/09/12 Javascript
js + css实现标签内容切换功能(实例讲解)
2017/10/09 Javascript
vue.js实例对象+组件树的详细介绍
2017/10/20 Javascript
极简主义法编写JavaScript类
2017/11/02 Javascript
VueJs使用Amaze ui调整列表和内容页面
2017/11/30 Javascript
nodejs的路径问题的解决
2018/06/30 NodeJs
如何在微信小程序中存setStorage
2019/12/13 Javascript
微信小程序调用wx.getImageInfo遇到的坑解决
2020/05/31 Javascript
jQuery+ajax实现用户登录验证
2020/09/13 jQuery
js+canvas绘制图形验证码
2020/09/21 Javascript
Python 可爱的大小写
2008/09/06 Python
Python多线程编程简单介绍
2015/04/13 Python
实例讲解Python爬取网页数据
2018/07/08 Python
对pycharm 修改程序运行所需内存详解
2018/12/03 Python
python爬虫超时的处理的实例
2018/12/19 Python
python实现对象列表根据某个属性排序的方法详解
2019/06/11 Python
python字符串Intern机制详解
2019/07/01 Python
HTML5标签小集
2011/08/02 HTML / CSS
英国莱斯特松木橡木家具网上商店:Choice Furniture Superstore
2019/07/05 全球购物
zooplus意大利:在线宠物商店
2019/08/07 全球购物
银行先进个人总结
2015/02/15 职场文书
golang elasticsearch Client的使用详解
2021/05/05 Golang
Django与数据库交互的实现
2021/06/03 Python
python获取对象信息的实例详解
2021/07/07 Python
集英社今正式宣布 成立游戏公司“集英社Games”
2022/03/31 其他游戏