tensorflow训练中出现nan问题的解决


Posted in Python onFebruary 10, 2018

深度学习中对于网络的训练是参数更新的过程,需要注意一种情况就是输入数据未做归一化时,如果前向传播结果已经是[0,0,0,1,0,0,0,0]这种形式,而真实结果是[1,0,0,0,0,0,0,0,0],此时由于得出的结论不惧有概率性,而是错误的估计值,此时反向传播会使得权重和偏置值变的无穷大,导致数据溢出,也就出现了nan的问题。

解决办法:

1、对输入数据进行归一化处理,如将输入的图片数据除以255将其转化成0-1之间的数据;

2、对于层数较多的情况,各层都做batch_nomorlization;

3、对设置Weights权重使用tf.truncated_normal(0, 0.01, [3,3,1,64])生成,同时值的均值为0,方差要小一些;

4、激活函数可以使用tanh;

5、减小学习率lr。

实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data',one_hot = True)

def add_layer(input_data,in_size, out_size,activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size,out_size]))
  Biases = tf.Variable(tf.zeros([1, out_size])+0.1)
  Wx_plus_b = tf.add(tf.matmul(input_data, Weights), Biases)
  if activation_function==None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  #return outputs#, Weights
  return {'outdata':outputs, 'w':Weights}

def get_accuracy(t_y):
#  global l1
#  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(l1['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  global prediction
  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  return accu

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

#l1 = add_layer(X, 784, 10, tf.nn.softmax)
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(l1['outdata']), reduction_indices= [1]))
#l1 = add_layer(X, 784, 1024, tf.nn.relu)

l1 = add_layer(X, 784, 1024, None)
prediction = add_layer(l1['outdata'], 1024, 10, tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(prediction['outdata']), reduction_indices= [1]))

optimizer = tf.train.GradientDescentOptimizer(0.000001)
train = optimizer.minimize(cross_entropy)


newW = tf.Variable(tf.random_normal([1024,10]))
newOut = tf.matmul(l1['outdata'],newW)
newSoftMax = tf.nn.softmax(newOut)

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  #print(sess.run(l1_Weights))
  for i in range(2):
    X_train, y_train = mnist.train.next_batch(1)
    X_train = X_train/255  #需要进行归一化处理
    #print(sess.run(l1['w'],feed_dict={X:X_train}))
    #print(sess.run(prediction['w'],feed_dict={X:X_train, Y:y_train}))
    #print(sess.run(l1['outdata'],feed_dict={X:X_train, Y:y_train}).shape)
    print(sess.run(prediction['outdata'],feed_dict={X:X_train, Y:y_train}))
    print(sess.run(newOut, feed_dict={X:X_train}))
    print(sess.run(newSoftMax, feed_dict={X:X_train}))
    print(y_train)
    #print(sess.run(l1['outdata'], feed_dict={X:X_train}))
    sess.run(train, feed_dict={X:X_train, Y:y_train})
    if i%100 == 0:
      #print(sess.run(cross_entropy, feed_dict={X:X_train, Y:y_train}))
      accuracy = get_accuracy(mnist.test.labels)
      print(sess.run(accuracy,feed_dict={X:mnist.test.images}))
    
    #if i%100==0:
    #print(sess.run(prediction, feed_dict={X:X_train}))
    #print(sess.run(cross_entropy, feed_dict={X:X_train,Y:y_train}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python ljust rjust center输出
Sep 06 Python
python使用wxpython开发简单记事本的方法
May 20 Python
Python中使用items()方法返回字典元素对的教程
May 21 Python
使用Python从有道词典网页获取单词翻译
Jul 03 Python
python 上下文管理器使用方法小结
Oct 10 Python
解决python3中解压zip文件是文件名乱码的问题
Mar 22 Python
Window环境下Scrapy开发环境搭建
Nov 18 Python
python 实现的发送邮件模板【普通邮件、带附件、带图片邮件】
Jul 06 Python
利用Python模拟登录pastebin.com的实现方法
Jul 12 Python
Django CBV与FBV原理及实例详解
Aug 12 Python
Python中for后接else的语法使用
May 18 Python
使用python生成大量数据写入es数据库并查询操作(2)
Sep 23 Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
You might like
PHP5 的对象赋值机制介绍
2011/08/02 PHP
基于ubuntu下nginx+php+mysql安装配置的具体操作步骤
2013/04/28 PHP
从PHP的源码中深入了解stdClass类
2014/04/18 PHP
PHP 面向对象程序设计(oop)学习笔记 (四) - 异常处理类Exception
2014/06/12 PHP
php操纵mysqli数据库的实现方法
2016/09/18 PHP
Yii2框架视图(View)操作及Layout的使用方法分析
2019/05/27 PHP
php设计模式之工厂方法模式分析【星际争霸游戏案例】
2020/01/23 PHP
php + ajax 实现的写入数据库操作简单示例
2020/05/16 PHP
网页中可关闭的漂浮窗口实现可自行调节
2013/08/20 Javascript
extjs中form与grid交互数据(record)的方法
2013/08/29 Javascript
用Jquery.load载入页面实现局部刷新
2014/01/22 Javascript
jquery validate 自定义验证方法介绍 日期验证
2014/02/27 Javascript
jQuery scrollFix滚动定位插件
2015/04/01 Javascript
jQuery实现html元素拖拽
2015/07/21 Javascript
非常漂亮的相册集 使用jquery制作相册集
2016/04/28 Javascript
微信公众号 客服接口的开发实例详解
2016/09/28 Javascript
JS中this上下文对象使用方式
2016/10/09 Javascript
使用jQuery实现简单的tab框实例
2017/08/22 jQuery
vue路由嵌套的SPA实现步骤
2017/11/06 Javascript
jQuery创建及操作xml格式数据示例
2018/05/26 jQuery
微信小程序入门之广告条实现方法示例
2018/12/05 Javascript
详解nodejs 开发企业微信第三方应用入门教程
2019/03/12 NodeJs
基于aotu.js实现微信自动添加通讯录中的联系人功能
2020/05/28 Javascript
详解Howler.js Web音频播放终极解决方案
2020/08/23 Javascript
12步教你理解Python装饰器
2016/02/25 Python
在Python web中实现验证码图片代码分享
2017/11/09 Python
Python使用functools实现注解同步方法
2018/02/06 Python
Python3.5基础之变量、数据结构、条件和循环语句、break与continue语句实例详解
2019/04/26 Python
python代码 FTP备份交换机配置脚本实例解析
2019/08/01 Python
Python 可变类型和不可变类型及引用过程解析
2019/09/27 Python
python3 pathlib库Path类方法总结
2019/12/26 Python
详解CSS3原生支持div铺满浏览器的方法
2018/08/30 HTML / CSS
竞聘自述材料
2014/08/25 职场文书
改革共识倡议书
2014/08/29 职场文书
Python+tkinter实现高清图片保存
2022/03/13 Python
Mysql的Table doesn't exist问题及解决
2022/12/24 MySQL