tensorflow训练中出现nan问题的解决


Posted in Python onFebruary 10, 2018

深度学习中对于网络的训练是参数更新的过程,需要注意一种情况就是输入数据未做归一化时,如果前向传播结果已经是[0,0,0,1,0,0,0,0]这种形式,而真实结果是[1,0,0,0,0,0,0,0,0],此时由于得出的结论不惧有概率性,而是错误的估计值,此时反向传播会使得权重和偏置值变的无穷大,导致数据溢出,也就出现了nan的问题。

解决办法:

1、对输入数据进行归一化处理,如将输入的图片数据除以255将其转化成0-1之间的数据;

2、对于层数较多的情况,各层都做batch_nomorlization;

3、对设置Weights权重使用tf.truncated_normal(0, 0.01, [3,3,1,64])生成,同时值的均值为0,方差要小一些;

4、激活函数可以使用tanh;

5、减小学习率lr。

实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data',one_hot = True)

def add_layer(input_data,in_size, out_size,activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size,out_size]))
  Biases = tf.Variable(tf.zeros([1, out_size])+0.1)
  Wx_plus_b = tf.add(tf.matmul(input_data, Weights), Biases)
  if activation_function==None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  #return outputs#, Weights
  return {'outdata':outputs, 'w':Weights}

def get_accuracy(t_y):
#  global l1
#  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(l1['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  global prediction
  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  return accu

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

#l1 = add_layer(X, 784, 10, tf.nn.softmax)
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(l1['outdata']), reduction_indices= [1]))
#l1 = add_layer(X, 784, 1024, tf.nn.relu)

l1 = add_layer(X, 784, 1024, None)
prediction = add_layer(l1['outdata'], 1024, 10, tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(prediction['outdata']), reduction_indices= [1]))

optimizer = tf.train.GradientDescentOptimizer(0.000001)
train = optimizer.minimize(cross_entropy)


newW = tf.Variable(tf.random_normal([1024,10]))
newOut = tf.matmul(l1['outdata'],newW)
newSoftMax = tf.nn.softmax(newOut)

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  #print(sess.run(l1_Weights))
  for i in range(2):
    X_train, y_train = mnist.train.next_batch(1)
    X_train = X_train/255  #需要进行归一化处理
    #print(sess.run(l1['w'],feed_dict={X:X_train}))
    #print(sess.run(prediction['w'],feed_dict={X:X_train, Y:y_train}))
    #print(sess.run(l1['outdata'],feed_dict={X:X_train, Y:y_train}).shape)
    print(sess.run(prediction['outdata'],feed_dict={X:X_train, Y:y_train}))
    print(sess.run(newOut, feed_dict={X:X_train}))
    print(sess.run(newSoftMax, feed_dict={X:X_train}))
    print(y_train)
    #print(sess.run(l1['outdata'], feed_dict={X:X_train}))
    sess.run(train, feed_dict={X:X_train, Y:y_train})
    if i%100 == 0:
      #print(sess.run(cross_entropy, feed_dict={X:X_train, Y:y_train}))
      accuracy = get_accuracy(mnist.test.labels)
      print(sess.run(accuracy,feed_dict={X:mnist.test.images}))
    
    #if i%100==0:
    #print(sess.run(prediction, feed_dict={X:X_train}))
    #print(sess.run(cross_entropy, feed_dict={X:X_train,Y:y_train}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的Flask框架中web表单的教程
Apr 20 Python
Python的Django框架中的表单处理示例
Jul 17 Python
Python利用matplotlib.pyplot绘图时如何设置坐标轴刻度
Apr 09 Python
python脚本监控Tomcat服务器的方法
Jul 06 Python
python简易实现任意位数的水仙花实例
Nov 13 Python
python对视频画框标记后保存的方法
Dec 07 Python
详解js文件通过python访问数据库方法
Mar 03 Python
Python基础学习之类与实例基本用法与注意事项详解
Jun 17 Python
PyQtGraph在pyqt中的应用及安装过程
Aug 04 Python
centos+nginx+uwsgi+Django实现IP+port访问服务器
Nov 15 Python
matlab灰度图像调整及imadjust函数的用法详解
Feb 27 Python
python和php学习哪个更有发展
Jun 17 Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
You might like
PHP Document 代码注释规范
2009/04/13 PHP
PHP中空字符串介绍0、null、empty和false之间的关系
2012/09/25 PHP
phpmyadmin config.inc.php配置示例
2013/08/27 PHP
thinkphp中memcache的用法实例
2014/11/29 PHP
阿里云的WindowsServer2016上部署php+apache
2018/07/17 PHP
php实现有序数组旋转后寻找最小值方法
2018/09/27 PHP
laravel 实现用户登录注销并限制功能
2019/10/24 PHP
javascript据option的value值快速设定初始的selected选项
2007/08/13 Javascript
js类中的公有变量和私有变量
2008/07/24 Javascript
使用jQuery的ajax功能实现的RSS Reader 代码
2009/09/03 Javascript
Javascript下判断是否为闰年的Datetime包
2010/10/26 Javascript
js获得鼠标的坐标值的方法
2013/03/13 Javascript
侧栏跟随滚动的简单实现代码
2013/03/18 Javascript
js实现单击图片放大图片的方法
2015/02/17 Javascript
jquery注册文本框获取焦点清空,失去焦点赋值的简单实例
2016/09/08 Javascript
AngularJS入门示例之Hello World详解
2017/01/04 Javascript
jQuery弹出层插件popShow(改进版)用法示例
2017/01/23 Javascript
解析jquery easyui tree异步加载子节点问题
2017/03/08 Javascript
几种响应式文字详解
2017/05/19 Javascript
JS模拟超市简易收银台小程序代码解析
2017/08/18 Javascript
深入理解NodeJS 多进程和集群
2018/10/17 NodeJs
JavaScript使用闭包模仿块级作用域操作示例
2019/01/21 Javascript
[17:36]VG战队纪录片
2014/08/21 DOTA
[04:44]DOTA2 2017全国高校联赛视频回顾
2017/08/21 DOTA
Django实现自定义404,500页面教程
2017/03/26 Python
一个基于flask的web应用诞生 用户注册功能开发(5)
2017/04/11 Python
python中如何使用正则表达式的非贪婪模式示例
2017/10/09 Python
Flask模拟实现CSRF攻击的方法
2018/07/24 Python
python ftp 按目录结构上传下载的实现代码
2018/09/12 Python
python实现微信定时每天和女友发送消息
2019/04/29 Python
Python脚本实现监听服务器的思路代码详解
2020/05/28 Python
萌新的HTML5 入门指南
2020/11/06 HTML / CSS
斯洛伐克时尚服装网上商店:Cellbes
2016/10/20 全球购物
expedia比利时:预订航班+酒店并省钱
2018/07/13 全球购物
2014年教师节红领巾广播稿
2014/09/10 职场文书
导游词之绍兴柯岩古镇
2020/01/09 职场文书