python 使用Tensorflow训练BP神经网络实现鸢尾花分类


Posted in Python onMay 12, 2021

Hello,兄弟们,开始搞深度学习了,今天出第一篇博客,小白一枚,如果发现错误请及时指正,万分感谢。

使用软件

Python 3.8,Tensorflow2.0

问题描述

鸢尾花主要分为狗尾草鸢尾(0)、杂色鸢尾(1)、弗吉尼亚鸢尾(2)。
人们发现通过计算鸢尾花的花萼长、花萼宽、花瓣长、花瓣宽可以将鸢尾花分类。
所以只要给出足够多的鸢尾花花萼、花瓣数据,以及对应种类,使用合适的神经网络训练,就可以实现鸢尾花分类。

搭建神经网络

输入数据是花萼长、花萼宽、花瓣长、花瓣宽,是n行四列的矩阵。
而输出的是每个种类的概率,是n行三列的矩阵。
我们采用BP神经网络,设X为输入数据,Y为输出数据,W为权重,B偏置。有

y=x∗w+b

因为x为n行四列的矩阵,y为n行三列的矩阵,所以w必须为四行三列的矩阵,每个神经元对应一个b,所以b为一行三列的的矩阵。
神经网络如下图。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

所以,只要找到合适的w和b,就能准确判断鸢尾花的种类。
下面就开始对这两个参数进行训练。

训练参数

损失函数

损失函数表达的是预测值(y*)和真实值(y)的差距,我们采用均方误差公式作为损失函数。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

损失函数值越小,说明预测值和真实值越接近,w和b就越合适。
如果人来一组一组试,那肯定是不行的。所以我们采用梯度下降算法来找到损失函数最小值。
梯度:对函数求偏导的向量。梯度下降的方向就是函数减少的方向。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

其中a为学习率,即梯度下降的步长,如果a太大,就可能错过最优值,如果a太小,则就需要更多步才能找到最优值。所以选择合适的学习率很关键。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

参数优化

通过反向传播来优化参数。
反向传播:从后向前,逐层求损失函数对每层神经元参数的偏导数,迭代更新所有参数。
比如

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

可以看到w会逐渐趋向于loss的最小值0。
以上就是我们训练的全部关键点。

代码

数据集

我们使用sklearn包提供的鸢尾花数据集。共150组数据。
打乱保证数据的随机性,取前120个为训练集,后30个为测试集。

# 导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data ## 存花萼、花瓣特征数据
y_data = datasets.load_iris().target # 存对应种类
# 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
# seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)
np.random.seed(116)  # 使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
# 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
# from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

参数

# 生成神经网络的参数,4个输入特征故,输入层为4个输入节点;因为3分类,故输出层为3个神经元
# 用tf.Variable()标记参数可训练
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1)) # 四行三列,方差为0.1
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1)) # 一行三列,方差为0.1

训练

a = 0.1  # 学习率为0.1
epoch = 500  # 循环500轮
# 训练部分
for epoch in range(epoch):  # 数据集级别的循环,每个epoch循环一次数据集
    for step, (x_train, y_train) in enumerate(train_db):  # batch级别的循环 ,每个step循环一个batch
        with tf.GradientTape() as tape:  # with结构记录梯度信息
            y = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算
            y = tf.nn.softmax(y)  # 使输出y符合概率分布
            y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss
            loss = tf.reduce_mean(tf.square(y_ - y))  # 采用均方误差损失函数mse = mean(sum(y-y*)^2)
        # 计算loss对w, b的梯度
        grads = tape.gradient(loss, [w1, b1])
        # 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_grad
        w1.assign_sub(a * grads[0])  # 参数w1自更新
        b1.assign_sub(a * grads[1])  # 参数b自更新

测试

# 测试部分
total_correct, total_number = 0, 0
for x_test, y_test in test_db:
    # 前向传播求概率
    y = tf.matmul(x_test, w1) + b1
    y = tf.nn.softmax(y)
    predict = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类
    # 将predict转换为y_test的数据类型
    predict = tf.cast(predict, dtype=y_test.dtype)
    # 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型
    correct = tf.cast(tf.equal(predict, y_test), dtype=tf.int32)
    # 将每个batch的correct数加起来
    correct = tf.reduce_sum(correct)
    # 将所有batch中的correct数加起来
    total_correct += int(correct)
    # total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数
    total_number += x_test.shape[0]
# 总的准确率等于total_correct/total_number
acc = total_correct / total_number
print("测试准确率 = %.2f %%" % (acc * 100.0))
my_test = np.array([[5.9, 3.0, 5.1, 1.8]])
print("输入 5.9  3.0  5.1  1.8")
my_test = tf.convert_to_tensor(my_test)
my_test = tf.cast(my_test, tf.float32)
y = tf.matmul(my_test, w1) + b1
y = tf.nn.softmax(y)
species = {0: "狗尾鸢尾", 1: "杂色鸢尾", 2: "弗吉尼亚鸢尾"}
predict = np.array(tf.argmax(y, axis=1))[0]  # 返回y中最大值的索引,即预测的分类
print("该鸢尾花为:" + species.get(predict))

结果:

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

结语

以上就是全部内容,鸢尾花分类作为经典案例,应该重点掌握理解。有一起学习的伙伴可以把想法打在评论区,大家多多交流,我也会及时回复的!

以上就是python 使用Tensorflow训练BP神经网络实现鸢尾花分类的详细内容,更多关于python 训练BP神经网络实现鸢尾花分类的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
windows下python模拟鼠标点击和键盘输示例
Feb 28 Python
用python读写excel的方法
Nov 18 Python
Python的Django中django-userena组件的简单使用教程
May 30 Python
Python连接PostgreSQL数据库的方法
Nov 28 Python
Python基于ThreadingTCPServer创建多线程代理的方法示例
Jan 11 Python
用python 批量更改图像尺寸到统一大小的方法
Mar 31 Python
python操作excel的方法(xlsxwriter包的使用)
Jun 11 Python
python skimage 连通性区域检测方法
Jun 21 Python
python 多进程和协程配合使用写入数据
Oct 30 Python
Python和Bash结合在一起的方法
Nov 13 Python
详解Python生成器和基于生成器的协程
Jun 03 Python
Python3.8官网文档之类的基础语法阅读
Sep 04 Python
PyTorch 如何设置随机数种子使结果可复现
May 12 #Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
You might like
php抓取https的内容的代码
2010/04/06 PHP
eaglephp使用微信api接口开发微信框架
2014/01/09 PHP
PHP使用mkdir创建多级目录的方法
2015/12/22 PHP
PHP中的随机性 你觉得自己幸运吗?
2016/01/22 PHP
可兼容IE的获取及设置cookie的jquery.cookie函数方法
2013/09/02 Javascript
javascript上传图片前预览图片兼容大多数浏览器
2013/10/25 Javascript
教你如何在 Javascript 文件里使用 .Net MVC Razor 语法
2014/07/23 Javascript
深入理解JavaScript系列(39):设计模式之适配器模式详解
2015/03/04 Javascript
jQuery实现仿百度帖吧头部固定导航效果
2015/08/07 Javascript
JavaScript常用标签和方法总结
2015/09/01 Javascript
纯JavaScript代码实现移动设备绘图解锁
2015/10/16 Javascript
JS实现网页标题栏显示当前时间和日期的完整代码
2015/11/02 Javascript
详解nodejs微信公众号开发——2.自动回复
2017/04/10 NodeJs
js使用i18n实现页面国际化的方法
2017/05/09 Javascript
Vue组件实例间的直接访问实现代码
2017/08/20 Javascript
jQuery实现的事件绑定功能基本示例
2017/10/11 jQuery
javascript用rem来做响应式开发
2018/01/13 Javascript
Angular 4.x+Ionic3踩坑之Ionic 3.x界面传值详解
2018/03/13 Javascript
Angular封装搜索框组件操作示例
2019/04/25 Javascript
win7下python3.6安装配置方法图文教程
2018/07/31 Python
Python单向链表和双向链表原理与用法实例详解
2018/08/31 Python
OpenCV 之按位运算举例解析
2020/06/19 Python
详解python中的三种命令行模块(sys.argv,argparse,click)
2020/12/15 Python
h5页面背景图很长要有滚动条滑动效果的实现
2021/01/27 HTML / CSS
全球知名旅游社区法国站点:TripAdvisor法国
2016/08/03 全球购物
施华洛世奇加拿大官网:SWAROVSKI加拿大
2018/06/03 全球购物
Shopee新加坡:东南亚与台湾电商平台
2019/01/25 全球购物
东南亚排名第一的服务市场:kaodim
2019/03/28 全球购物
阿迪达斯希腊官方网上商店:adidas希腊
2019/04/06 全球购物
巴西Mr. Cat在线商店:购买包包和鞋子
2019/09/08 全球购物
移动通信行业实习自我鉴定
2013/09/28 职场文书
自我鉴定 电子商务专业
2014/01/30 职场文书
《学会合作》教学反思
2014/04/12 职场文书
运动会演讲稿100字
2014/08/25 职场文书
2014个人四风对照检查材料思想汇报
2014/09/18 职场文书
乡镇党员群众路线教育实践活动对照检查材料思想汇报
2014/10/05 职场文书