TensorFlow实现Logistic回归


Posted in Python onSeptember 07, 2018

本文实例为大家分享了TensorFlow实现Logistic回归的具体代码,供大家参考,具体内容如下

1.导入模块

import numpy as np
import pandas as pd
from pandas import Series,DataFrame

from matplotlib import pyplot as plt
%matplotlib inline

#导入tensorflow
import tensorflow as tf

#导入MNIST(手写数字数据集)
from tensorflow.examples.tutorials.mnist import input_data

2.获取训练数据和测试数据

import ssl 
ssl._create_default_https_context = ssl._create_unverified_context

mnist = input_data.read_data_sets('./TensorFlow',one_hot=True)

test = mnist.test
test_images = test.images

train = mnist.train
images = train.images

3.模拟线性方程

#创建占矩阵位符X,Y
X = tf.placeholder(tf.float32,shape=[None,784])
Y = tf.placeholder(tf.float32,shape=[None,10])

#随机生成斜率W和截距b
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#根据模拟线性方程得出预测值
y_pre = tf.matmul(X,W)+b

#将预测值结果概率化
y_pre_r = tf.nn.softmax(y_pre)

4.构造损失函数

# -y*tf.log(y_pre_r) --->-Pi*log(Pi)  信息熵公式

cost = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(y_pre_r),axis=1))

5.实现梯度下降,获取最小损失函数

#learning_rate:学习率,是进行训练时在最陡的梯度方向上所采取的「步」长;
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

6.TensorFlow初始化,并进行训练

#定义相关参数

#训练循环次数
training_epochs = 25
#batch 一批,每次训练给算法10个数据
batch_size = 10
#每隔5次,打印输出运算的结果
display_step = 5


#预定义初始化
init = tf.global_variables_initializer()

#开始训练
with tf.Session() as sess:
  #初始化
  sess.run(init)
  #循环训练次数
  for epoch in range(training_epochs):
    avg_cost = 0.
    #总训练批次total_batch =训练总样本量/每批次样本数量
    total_batch = int(train.num_examples/batch_size)
    for i in range(total_batch):
      #每次取出100个数据作为训练数据
      batch_xs,batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer,cost],feed_dict={X:batch_xs,Y:batch_ys})
      avg_cost +=c/total_batch
    if(epoch+1)%display_step == 0:
      print(batch_xs.shape,batch_ys.shape)
      print('epoch:','%04d'%(epoch+1),'cost=','{:.9f}'.format(avg_cost))
  print('Optimization Finished!')

  #7.评估效果
  # Test model
  correct_prediction = tf.equal(tf.argmax(y_pre_r,1),tf.argmax(Y,1))
  # Calculate accuracy for 3000 examples
  # tf.cast类型转换
  accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  print("Accuracy:",accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现获取客户机上指定文件并传输到服务器的方法
Mar 16 Python
Win7下搭建python开发环境图文教程(安装Python、pip、解释器)
May 17 Python
python开发简易版在线音乐播放器
Mar 03 Python
一篇文章弄懂Python中的可迭代对象、迭代器和生成器
Aug 12 Python
用Pelican搭建一个极简静态博客系统过程解析
Aug 22 Python
Python自动生成代码 使用tkinter图形化操作并生成代码框架
Sep 18 Python
Python操作注册表详细步骤介绍
Feb 05 Python
Python实现栈的方法详解【基于数组和单链表两种方法】
Feb 22 Python
信号生成及DFT的python实现方式
Feb 25 Python
python实现单张图像拼接与批量图片拼接
Mar 23 Python
pyecharts动态轨迹图的实现示例
Apr 17 Python
python中pickle模块浅析
Dec 29 Python
tensorflow实现简单逻辑回归
Sep 07 #Python
Tensorflow使用支持向量机拟合线性回归
Sep 07 #Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
You might like
一个多文件上传的例子(原创)
2006/10/09 PHP
php小偷相关截取函数备忘
2010/11/28 PHP
详解php中 === 的使用
2016/10/24 PHP
PHP实现链式操作的三种方法详解
2017/11/16 PHP
php实现的PDO异常处理操作分析
2018/12/27 PHP
Laravel 前端资源配置教程
2019/10/18 PHP
prototype 的说明 js类
2006/09/07 Javascript
js 获取浏览器高度和宽度值(多浏览器)
2009/09/02 Javascript
javascript 模拟点击广告
2010/01/02 Javascript
理解Javascript_14_函数形式参数与arguments
2010/10/20 Javascript
再论Javascript的类继承
2011/03/05 Javascript
js multiple全选与取消全选实现代码
2012/12/04 Javascript
jQuery根据元素值删除数组元素的方法
2015/06/24 Javascript
js 点击a标签 获取a的自定义属性方法
2016/11/21 Javascript
vue.js 1.x与2.0中js实时监听input值的变化
2017/03/15 Javascript
详解微信JS-SDK选择图片遇到的坑
2018/08/15 Javascript
[03:03]2014DOTA2西雅图国际邀请赛 Alliance战队巡礼
2014/07/07 DOTA
ptyhon实现sitemap生成示例
2014/03/30 Python
详解Django中间件的5种自定义方法
2018/07/26 Python
Python with关键字,上下文管理器,@contextmanager文件操作示例
2019/10/17 Python
在Python中使用filter去除列表中值为假及空字符串的例子
2019/11/18 Python
python如何提取英语pdf内容并翻译
2020/03/03 Python
Python2与Python3关于字符串编码处理的差别总结
2020/09/07 Python
python使用Windows的wmic命令监控文件运行状况,如有异常发送邮件报警
2021/01/30 Python
用html5的canvas画布绘制贝塞尔曲线完整代码
2013/08/14 HTML / CSS
英国儿童设计师服装的领先零售商:Base
2019/03/17 全球购物
计算机专业个人求职信范例
2013/09/23 职场文书
物理教育专业毕业生推荐信
2013/11/03 职场文书
高校毕业生登记表自我鉴定
2013/11/03 职场文书
遗嘱公证书标准样本
2014/04/08 职场文书
优秀应届生求职信
2014/06/16 职场文书
开会通知短信大全
2015/04/20 职场文书
学风建设主题班会
2015/08/17 职场文书
聊聊golang中多个defer的执行顺序
2021/05/08 Golang
pytorch DataLoader的num_workers参数与设置大小详解
2021/05/28 Python
SQL Server #{}可以防止SQL注入
2022/05/11 SQL Server