tensorflow训练中出现nan问题的解决


Posted in Python onFebruary 10, 2018

深度学习中对于网络的训练是参数更新的过程,需要注意一种情况就是输入数据未做归一化时,如果前向传播结果已经是[0,0,0,1,0,0,0,0]这种形式,而真实结果是[1,0,0,0,0,0,0,0,0],此时由于得出的结论不惧有概率性,而是错误的估计值,此时反向传播会使得权重和偏置值变的无穷大,导致数据溢出,也就出现了nan的问题。

解决办法:

1、对输入数据进行归一化处理,如将输入的图片数据除以255将其转化成0-1之间的数据;

2、对于层数较多的情况,各层都做batch_nomorlization;

3、对设置Weights权重使用tf.truncated_normal(0, 0.01, [3,3,1,64])生成,同时值的均值为0,方差要小一些;

4、激活函数可以使用tanh;

5、减小学习率lr。

实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data',one_hot = True)

def add_layer(input_data,in_size, out_size,activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size,out_size]))
  Biases = tf.Variable(tf.zeros([1, out_size])+0.1)
  Wx_plus_b = tf.add(tf.matmul(input_data, Weights), Biases)
  if activation_function==None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  #return outputs#, Weights
  return {'outdata':outputs, 'w':Weights}

def get_accuracy(t_y):
#  global l1
#  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(l1['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  global prediction
  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  return accu

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

#l1 = add_layer(X, 784, 10, tf.nn.softmax)
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(l1['outdata']), reduction_indices= [1]))
#l1 = add_layer(X, 784, 1024, tf.nn.relu)

l1 = add_layer(X, 784, 1024, None)
prediction = add_layer(l1['outdata'], 1024, 10, tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(prediction['outdata']), reduction_indices= [1]))

optimizer = tf.train.GradientDescentOptimizer(0.000001)
train = optimizer.minimize(cross_entropy)


newW = tf.Variable(tf.random_normal([1024,10]))
newOut = tf.matmul(l1['outdata'],newW)
newSoftMax = tf.nn.softmax(newOut)

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  #print(sess.run(l1_Weights))
  for i in range(2):
    X_train, y_train = mnist.train.next_batch(1)
    X_train = X_train/255  #需要进行归一化处理
    #print(sess.run(l1['w'],feed_dict={X:X_train}))
    #print(sess.run(prediction['w'],feed_dict={X:X_train, Y:y_train}))
    #print(sess.run(l1['outdata'],feed_dict={X:X_train, Y:y_train}).shape)
    print(sess.run(prediction['outdata'],feed_dict={X:X_train, Y:y_train}))
    print(sess.run(newOut, feed_dict={X:X_train}))
    print(sess.run(newSoftMax, feed_dict={X:X_train}))
    print(y_train)
    #print(sess.run(l1['outdata'], feed_dict={X:X_train}))
    sess.run(train, feed_dict={X:X_train, Y:y_train})
    if i%100 == 0:
      #print(sess.run(cross_entropy, feed_dict={X:X_train, Y:y_train}))
      accuracy = get_accuracy(mnist.test.labels)
      print(sess.run(accuracy,feed_dict={X:mnist.test.images}))
    
    #if i%100==0:
    #print(sess.run(prediction, feed_dict={X:X_train}))
    #print(sess.run(cross_entropy, feed_dict={X:X_train,Y:y_train}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中将字典转换为XML以及相关的命名空间解析
Oct 15 Python
Python的requests网络编程包使用教程
Jul 11 Python
python爬取淘宝商品销量信息
Nov 16 Python
Python 3.8中实现functools.cached_property功能
May 29 Python
Python 如何优雅的将数字转化为时间格式的方法
Sep 26 Python
python分别打包出32位和64位应用程序
Feb 18 Python
python修改linux中文件(文件夹)的权限属性操作
Mar 05 Python
DjangoWeb使用Datatable进行后端分页的实现
May 18 Python
详解用Python调用百度地图正/逆地理编码API
Jul 02 Python
Python numpy矩阵处理运算工具用法汇总
Jul 13 Python
python判断一个变量是否已经设置的方法
Aug 13 Python
Python日志打印里logging.getLogger源码分析详解
Jan 17 Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
You might like
150kHz到30Mhz完全冲浪手册
2020/03/20 无线电
PhpMyAdmin出现export.php Missing parameter: what /export_type错误解决方法
2012/08/09 PHP
微信支付开发教程(一)微信支付URL配置
2014/05/28 PHP
PHP采用XML-RPC构造Web Service实例教程
2014/07/16 PHP
php防止网站被攻击的应急代码
2015/10/21 PHP
Adnroid 微信内置浏览器清除缓存
2016/07/11 PHP
PHP连接MySQL数据库操作代码实例解析
2020/07/11 PHP
深入学习JavaScript中的Rest参数和参数默认值
2015/07/28 Javascript
JavaScript简单修改窗口大小的方法
2015/08/03 Javascript
深入学习JavaScript对象
2015/10/13 Javascript
JS实现根据文件字节数返回文件大小的方法
2016/08/02 Javascript
jQuery基于ajax方式实现用户名存在性检查功能示例
2017/02/10 Javascript
Angularjs处理页面闪烁的解决方法
2017/03/09 Javascript
原生JS实现网页手机音乐播放器 歌词同步播放的示例
2018/02/02 Javascript
vue+SSM实现验证码功能
2018/12/07 Javascript
JS简单数组排序操作示例【sort方法】
2019/05/17 Javascript
jquery向后台提交数组的代码分析
2020/02/20 jQuery
[54:51]Ti4 冒泡赛第二轮LGD vs C9 3
2014/07/14 DOTA
[01:00:14]DOTA2官方TI8总决赛纪录片 真视界True Sight
2019/01/16 DOTA
Python文件操作,open读写文件,追加文本内容实例
2016/12/14 Python
python中 logging的使用详解
2017/10/25 Python
python数据类型判断type与isinstance的区别实例解析
2017/10/31 Python
Python基于Flask框架配置依赖包信息的项目迁移部署
2018/03/02 Python
解决python 输出是省略号的问题
2018/04/19 Python
python使用mitmproxy抓取浏览器请求的方法
2019/07/02 Python
Python  Django 母版和继承解析
2019/08/09 Python
在pytorch中实现只让指定变量向后传播梯度
2020/02/29 Python
Python Pandas list列表数据列拆分成多行的方法实现
2020/12/14 Python
5分钟实现Canvas鼠标跟随动画背景
2019/11/18 HTML / CSS
彪马美国官网:PUMA美国
2017/03/09 全球购物
机电一体化专业毕业生自荐信
2014/06/19 职场文书
民主评议党员工作总结
2014/10/20 职场文书
公务员年度考核登记表个人总结
2015/02/12 职场文书
小学生交通安全寄语
2015/02/27 职场文书
少先队中队工作总结2015
2015/07/23 职场文书
2016廉洁从政心得体会
2016/01/19 职场文书