TensorFlow实现Logistic回归


Posted in Python onSeptember 07, 2018

本文实例为大家分享了TensorFlow实现Logistic回归的具体代码,供大家参考,具体内容如下

1.导入模块

import numpy as np
import pandas as pd
from pandas import Series,DataFrame

from matplotlib import pyplot as plt
%matplotlib inline

#导入tensorflow
import tensorflow as tf

#导入MNIST(手写数字数据集)
from tensorflow.examples.tutorials.mnist import input_data

2.获取训练数据和测试数据

import ssl 
ssl._create_default_https_context = ssl._create_unverified_context

mnist = input_data.read_data_sets('./TensorFlow',one_hot=True)

test = mnist.test
test_images = test.images

train = mnist.train
images = train.images

3.模拟线性方程

#创建占矩阵位符X,Y
X = tf.placeholder(tf.float32,shape=[None,784])
Y = tf.placeholder(tf.float32,shape=[None,10])

#随机生成斜率W和截距b
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#根据模拟线性方程得出预测值
y_pre = tf.matmul(X,W)+b

#将预测值结果概率化
y_pre_r = tf.nn.softmax(y_pre)

4.构造损失函数

# -y*tf.log(y_pre_r) --->-Pi*log(Pi)  信息熵公式

cost = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(y_pre_r),axis=1))

5.实现梯度下降,获取最小损失函数

#learning_rate:学习率,是进行训练时在最陡的梯度方向上所采取的「步」长;
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

6.TensorFlow初始化,并进行训练

#定义相关参数

#训练循环次数
training_epochs = 25
#batch 一批,每次训练给算法10个数据
batch_size = 10
#每隔5次,打印输出运算的结果
display_step = 5


#预定义初始化
init = tf.global_variables_initializer()

#开始训练
with tf.Session() as sess:
  #初始化
  sess.run(init)
  #循环训练次数
  for epoch in range(training_epochs):
    avg_cost = 0.
    #总训练批次total_batch =训练总样本量/每批次样本数量
    total_batch = int(train.num_examples/batch_size)
    for i in range(total_batch):
      #每次取出100个数据作为训练数据
      batch_xs,batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer,cost],feed_dict={X:batch_xs,Y:batch_ys})
      avg_cost +=c/total_batch
    if(epoch+1)%display_step == 0:
      print(batch_xs.shape,batch_ys.shape)
      print('epoch:','%04d'%(epoch+1),'cost=','{:.9f}'.format(avg_cost))
  print('Optimization Finished!')

  #7.评估效果
  # Test model
  correct_prediction = tf.equal(tf.argmax(y_pre_r,1),tf.argmax(Y,1))
  # Calculate accuracy for 3000 examples
  # tf.cast类型转换
  accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  print("Accuracy:",accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用wxPython开发的一个简易笔记本程序实例
Feb 08 Python
Python文件和目录操作详解
Feb 08 Python
python实现将文本转换成语音的方法
May 28 Python
Windows下Anaconda的安装和简单使用方法
Jan 04 Python
基于Python实现定时自动给微信好友发送天气预报
Oct 25 Python
Python的numpy库下的几个小函数的用法(小结)
Jul 12 Python
关于Python核心框架tornado的异步协程的2种方法详解
Aug 28 Python
pandas实现将日期转换成timestamp
Dec 07 Python
pytorch制作自己的LMDB数据操作示例
Dec 18 Python
关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)
Feb 20 Python
在tensorflow以及keras安装目录查询操作(windows下)
Jun 19 Python
Python调用腾讯API实现人脸身份证比对功能
Apr 04 Python
tensorflow实现简单逻辑回归
Sep 07 #Python
Tensorflow使用支持向量机拟合线性回归
Sep 07 #Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
You might like
php 移除数组重复元素的一点说明
2008/11/27 PHP
php 更新数据库中断的解决方法
2009/06/05 PHP
php的mssql数据库连接类实例
2014/11/28 PHP
PHP中使用正则表达式提取中文实现笔记
2015/01/20 PHP
php接口数据加密、解密、验证签名
2015/03/12 PHP
js判断鼠标同时离开两个div的思路及代码
2013/05/31 Javascript
IE中的File域无法清空使用jQuery重设File域
2014/04/24 Javascript
基于javascript的JSON格式页面展示美化方法
2014/07/02 Javascript
CSS或者JS实现鼠标悬停显示另一元素
2016/01/22 Javascript
jQuery中each()、find()和filter()等节点操作方法详解(推荐)
2016/05/25 Javascript
javascript 解决浏览器不支持的问题
2016/09/24 Javascript
JavaScript中数组Array方法详解
2017/02/27 Javascript
vue2 自定义动态组件所遇到的问题
2017/06/08 Javascript
在vue中实现点击选择框阻止弹出层消失的方法
2018/09/15 Javascript
详解vue-video-player使用心得(兼容m3u8)
2019/08/23 Javascript
Vertx基于EventBus发送接受自定义对象
2020/11/16 Javascript
[01:08:10]2014 DOTA2国际邀请赛中国区预选赛 SPD-GAMING VS LGD-CDEC
2014/05/22 DOTA
在Python的web框架中中编写日志列表的教程
2015/04/30 Python
简单了解python中的f.b.u.r函数
2019/11/02 Python
关于Numpy中的行向量和列向量详解
2019/11/30 Python
linux 下python多线程递归复制文件夹及文件夹中的文件
2020/01/02 Python
Django 批量插入数据的实现方法
2020/01/12 Python
python检查目录文件权限并修改目录文件权限的操作
2020/03/11 Python
python如何编写win程序
2020/06/08 Python
PyTorch预训练Bert模型的示例
2020/11/17 Python
美国著名首饰网站:BaubleBar
2016/08/29 全球购物
Sunglasses Shop丹麦:欧洲第一的太阳镜在线销售网站
2017/10/22 全球购物
老公给老婆的保证书
2014/04/28 职场文书
大专毕业生求职信
2014/07/05 职场文书
管理工程专业求职信
2014/08/10 职场文书
合同和协议有什么区别?
2014/10/08 职场文书
公司禁烟通知
2015/04/23 职场文书
会议主持词结束语
2015/07/03 职场文书
python本地文件服务器实例教程
2021/05/02 Python
Netty分布式客户端接入流程初始化源码分析
2022/03/25 Java/Android
Mysql表数据比较大情况下修改添加字段的方法实例
2022/06/28 MySQL