tensorflow训练中出现nan问题的解决


Posted in Python onFebruary 10, 2018

深度学习中对于网络的训练是参数更新的过程,需要注意一种情况就是输入数据未做归一化时,如果前向传播结果已经是[0,0,0,1,0,0,0,0]这种形式,而真实结果是[1,0,0,0,0,0,0,0,0],此时由于得出的结论不惧有概率性,而是错误的估计值,此时反向传播会使得权重和偏置值变的无穷大,导致数据溢出,也就出现了nan的问题。

解决办法:

1、对输入数据进行归一化处理,如将输入的图片数据除以255将其转化成0-1之间的数据;

2、对于层数较多的情况,各层都做batch_nomorlization;

3、对设置Weights权重使用tf.truncated_normal(0, 0.01, [3,3,1,64])生成,同时值的均值为0,方差要小一些;

4、激活函数可以使用tanh;

5、减小学习率lr。

实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data',one_hot = True)

def add_layer(input_data,in_size, out_size,activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size,out_size]))
  Biases = tf.Variable(tf.zeros([1, out_size])+0.1)
  Wx_plus_b = tf.add(tf.matmul(input_data, Weights), Biases)
  if activation_function==None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  #return outputs#, Weights
  return {'outdata':outputs, 'w':Weights}

def get_accuracy(t_y):
#  global l1
#  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(l1['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  global prediction
  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  return accu

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

#l1 = add_layer(X, 784, 10, tf.nn.softmax)
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(l1['outdata']), reduction_indices= [1]))
#l1 = add_layer(X, 784, 1024, tf.nn.relu)

l1 = add_layer(X, 784, 1024, None)
prediction = add_layer(l1['outdata'], 1024, 10, tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(prediction['outdata']), reduction_indices= [1]))

optimizer = tf.train.GradientDescentOptimizer(0.000001)
train = optimizer.minimize(cross_entropy)


newW = tf.Variable(tf.random_normal([1024,10]))
newOut = tf.matmul(l1['outdata'],newW)
newSoftMax = tf.nn.softmax(newOut)

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  #print(sess.run(l1_Weights))
  for i in range(2):
    X_train, y_train = mnist.train.next_batch(1)
    X_train = X_train/255  #需要进行归一化处理
    #print(sess.run(l1['w'],feed_dict={X:X_train}))
    #print(sess.run(prediction['w'],feed_dict={X:X_train, Y:y_train}))
    #print(sess.run(l1['outdata'],feed_dict={X:X_train, Y:y_train}).shape)
    print(sess.run(prediction['outdata'],feed_dict={X:X_train, Y:y_train}))
    print(sess.run(newOut, feed_dict={X:X_train}))
    print(sess.run(newSoftMax, feed_dict={X:X_train}))
    print(y_train)
    #print(sess.run(l1['outdata'], feed_dict={X:X_train}))
    sess.run(train, feed_dict={X:X_train, Y:y_train})
    if i%100 == 0:
      #print(sess.run(cross_entropy, feed_dict={X:X_train, Y:y_train}))
      accuracy = get_accuracy(mnist.test.labels)
      print(sess.run(accuracy,feed_dict={X:mnist.test.images}))
    
    #if i%100==0:
    #print(sess.run(prediction, feed_dict={X:X_train}))
    #print(sess.run(cross_entropy, feed_dict={X:X_train,Y:y_train}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python脚本监控docker容器
Apr 27 Python
pandas 实现将重复表格去重,并重新转换为表格的方法
Apr 18 Python
Python 实现删除某路径下文件及文件夹的实例讲解
Apr 24 Python
Python装饰器知识点补充
May 28 Python
Python自动化运维之Ansible定义主机与组规则操作详解
Jun 13 Python
Python Django2.0集成Celery4.1教程
Nov 19 Python
python软件都是免费的吗
Jun 18 Python
Python爬虫防封ip的一些技巧
Aug 06 Python
Python collections.deque双边队列原理详解
Oct 05 Python
Python中tkinter的用户登录管理的实现
Apr 22 Python
Python中的min及返回最小值索引的操作
May 10 Python
Python jiaba库的使用详解
Nov 23 Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
You might like
php foreach、while性能比较
2009/10/15 PHP
PHP 日,周,月点击排行统计
2012/01/11 PHP
深入file_get_contents函数抓取内容失败的原因分析
2013/06/25 PHP
php ci框架中加载css和js文件失败的原因及解决方法
2014/07/29 PHP
PHP实现远程下载文件到本地
2015/05/17 PHP
php中html_entity_decode实现HTML实体转义
2018/06/13 PHP
JQuery 学习笔记 选择器之五
2009/07/23 Javascript
用Javascript 获取页面元素的位置的代码
2009/09/25 Javascript
jQuery 借助插件Lavalamp实现导航条动态美化效果
2013/09/27 Javascript
谷歌浏览器不支持showModalDialog模态对话框的解决方法
2014/09/22 Javascript
jQuery及JS实现循环中暂停的方法
2015/02/02 Javascript
Jquery动态添加输入框的方法
2015/05/29 Javascript
jQuery匹配文档链接并添加class的方法
2015/06/26 Javascript
javascript弹性运动效果简单实现方法
2016/01/08 Javascript
JS实现复制内容到剪贴板功能兼容所有浏览器(推荐)
2016/06/17 Javascript
ES6正则表达式扩展笔记
2017/07/25 Javascript
写gulp遇到的ES6问题详解
2018/12/03 Javascript
Node.js爬虫如何获取天气和每日问候详解
2019/08/26 Javascript
js+canvas实现五子棋小游戏
2020/08/02 Javascript
[02:57]2014DOTA2国际邀请赛 选手辛苦解说更辛苦
2014/07/10 DOTA
在Debian下配置Python+Django+Nginx+uWSGI+MySQL的教程
2015/04/25 Python
详解django中自定义标签和过滤器
2017/07/03 Python
Python使用matplotlib绘图无法显示中文问题的解决方法
2018/03/14 Python
pandas数据框,统计某列数据对应的个数方法
2018/04/11 Python
基于python神经卷积网络的人脸识别
2018/05/24 Python
Python实现二叉树的常见遍历操作总结【7种方法】
2019/03/06 Python
python实现对输入的密文加密
2019/03/20 Python
python画微信表情符的实例代码
2019/10/09 Python
CSS3教程(2):网页边框半径和网页圆角
2009/04/02 HTML / CSS
Delphi工程师笔试题
2013/09/21 面试题
十岁生日家长答谢词
2014/01/17 职场文书
中医学专业自荐信范文
2014/04/01 职场文书
电影建国大业观后感
2015/06/01 职场文书
2015年文秘个人工作总结
2015/10/14 职场文书
Spring Data JPA使用JPQL与原生SQL进行查询的操作
2021/06/15 Java/Android
简单聊聊TypeScript只读修饰符
2022/04/06 Javascript