TensorFlow实现Logistic回归


Posted in Python onSeptember 07, 2018

本文实例为大家分享了TensorFlow实现Logistic回归的具体代码,供大家参考,具体内容如下

1.导入模块

import numpy as np
import pandas as pd
from pandas import Series,DataFrame

from matplotlib import pyplot as plt
%matplotlib inline

#导入tensorflow
import tensorflow as tf

#导入MNIST(手写数字数据集)
from tensorflow.examples.tutorials.mnist import input_data

2.获取训练数据和测试数据

import ssl 
ssl._create_default_https_context = ssl._create_unverified_context

mnist = input_data.read_data_sets('./TensorFlow',one_hot=True)

test = mnist.test
test_images = test.images

train = mnist.train
images = train.images

3.模拟线性方程

#创建占矩阵位符X,Y
X = tf.placeholder(tf.float32,shape=[None,784])
Y = tf.placeholder(tf.float32,shape=[None,10])

#随机生成斜率W和截距b
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#根据模拟线性方程得出预测值
y_pre = tf.matmul(X,W)+b

#将预测值结果概率化
y_pre_r = tf.nn.softmax(y_pre)

4.构造损失函数

# -y*tf.log(y_pre_r) --->-Pi*log(Pi)  信息熵公式

cost = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(y_pre_r),axis=1))

5.实现梯度下降,获取最小损失函数

#learning_rate:学习率,是进行训练时在最陡的梯度方向上所采取的「步」长;
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

6.TensorFlow初始化,并进行训练

#定义相关参数

#训练循环次数
training_epochs = 25
#batch 一批,每次训练给算法10个数据
batch_size = 10
#每隔5次,打印输出运算的结果
display_step = 5


#预定义初始化
init = tf.global_variables_initializer()

#开始训练
with tf.Session() as sess:
  #初始化
  sess.run(init)
  #循环训练次数
  for epoch in range(training_epochs):
    avg_cost = 0.
    #总训练批次total_batch =训练总样本量/每批次样本数量
    total_batch = int(train.num_examples/batch_size)
    for i in range(total_batch):
      #每次取出100个数据作为训练数据
      batch_xs,batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer,cost],feed_dict={X:batch_xs,Y:batch_ys})
      avg_cost +=c/total_batch
    if(epoch+1)%display_step == 0:
      print(batch_xs.shape,batch_ys.shape)
      print('epoch:','%04d'%(epoch+1),'cost=','{:.9f}'.format(avg_cost))
  print('Optimization Finished!')

  #7.评估效果
  # Test model
  correct_prediction = tf.equal(tf.argmax(y_pre_r,1),tf.argmax(Y,1))
  # Calculate accuracy for 3000 examples
  # tf.cast类型转换
  accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  print("Accuracy:",accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python快速查找算法应用实例
Sep 26 Python
python类装饰器用法实例
Jun 04 Python
python虚拟环境virualenv的安装与使用
Dec 18 Python
Python 2.x如何设置命令执行的超时时间实例
Oct 19 Python
django使用xlwt导出excel文件实例代码
Feb 06 Python
win10下tensorflow和matplotlib安装教程
Sep 19 Python
详解python 3.6 安装json 模块(simplejson)
Apr 02 Python
python实现函数极小值
Jul 10 Python
python程序 线程队列queue使用方法解析
Sep 23 Python
django执行数据库查询之后实现返回的结果集转json
Mar 31 Python
解决pycharm导入本地py文件时,模块下方出现红色波浪线的问题
Jun 01 Python
Python爬虫中Selenium实现文件上传
Dec 04 Python
tensorflow实现简单逻辑回归
Sep 07 #Python
Tensorflow使用支持向量机拟合线性回归
Sep 07 #Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
You might like
BBS(php & mysql)完整版(一)
2006/10/09 PHP
PHP5 面向对象(学习记录)
2009/12/02 PHP
php实现概率性随机抽奖代码
2016/01/02 PHP
php版微信公众平台回复中文出现乱码问题的解决方法
2016/09/22 PHP
CodeIgniter框架常见用法工作总结
2017/03/16 PHP
Laravel6.18.19如何优雅的切换发件账户
2020/06/14 PHP
laravel7学习之无限级分类的最新实现方法
2020/09/30 PHP
PNG背景在不同浏览器下的应用
2009/06/22 Javascript
javascript 面向对象编程 聊聊对象的事
2009/09/17 Javascript
Javascript匿名函数的一种应用 代码封装
2010/06/27 Javascript
JS获取几种URL地址的方法小结
2014/02/26 Javascript
JS中数组重排序方法
2016/11/11 Javascript
seajs模块压缩问题与解决方法实例分析
2017/10/10 Javascript
vue项目在安卓低版本机显示空白的原因分析(两种)
2018/09/04 Javascript
vscode下vue项目中eslint的使用方法
2019/01/13 Javascript
基于node+vue实现简单的WebSocket聊天功能
2020/02/01 Javascript
[02:05]DOTA2完美大师赛趣味视频之看我表演
2017/11/18 DOTA
[52:07]完美世界DOTA2联赛PWL S3 LBZS vs access 第二场 12.10
2020/12/13 DOTA
python使用chardet判断字符串编码的方法
2015/03/13 Python
python 使用pandas计算累积求和的方法
2019/02/08 Python
Python爬虫之Spider类用法简单介绍
2020/08/04 Python
python如何遍历指定路径下所有文件(按按照时间区间检索)
2020/09/14 Python
世界上最全面的草药补充剂和顶级品牌维生素网站:HerbsPro
2019/01/20 全球购物
巴西最大的珠宝连锁店:Vivara
2019/04/18 全球购物
红旗方阵解说词
2014/02/12 职场文书
煤矿安全生产责任书
2014/04/15 职场文书
环境卫生标语
2014/06/09 职场文书
珠宝的促销活动方案
2014/08/31 职场文书
公司股东合作协议书
2014/09/14 职场文书
户籍证明模板
2014/09/28 职场文书
会计简历自我评价
2015/03/10 职场文书
社区国庆节活动总结
2015/03/23 职场文书
人与自然的观后感
2015/06/18 职场文书
Python爬虫之爬取哔哩哔哩热门视频排行榜
2021/04/28 Python
django学习之ajax post传参的2种格式实例
2021/05/14 Python
MySQL的prepare使用以及遇到的bug
2022/05/11 MySQL