python 使用Tensorflow训练BP神经网络实现鸢尾花分类


Posted in Python onMay 12, 2021

Hello,兄弟们,开始搞深度学习了,今天出第一篇博客,小白一枚,如果发现错误请及时指正,万分感谢。

使用软件

Python 3.8,Tensorflow2.0

问题描述

鸢尾花主要分为狗尾草鸢尾(0)、杂色鸢尾(1)、弗吉尼亚鸢尾(2)。
人们发现通过计算鸢尾花的花萼长、花萼宽、花瓣长、花瓣宽可以将鸢尾花分类。
所以只要给出足够多的鸢尾花花萼、花瓣数据,以及对应种类,使用合适的神经网络训练,就可以实现鸢尾花分类。

搭建神经网络

输入数据是花萼长、花萼宽、花瓣长、花瓣宽,是n行四列的矩阵。
而输出的是每个种类的概率,是n行三列的矩阵。
我们采用BP神经网络,设X为输入数据,Y为输出数据,W为权重,B偏置。有

y=x∗w+b

因为x为n行四列的矩阵,y为n行三列的矩阵,所以w必须为四行三列的矩阵,每个神经元对应一个b,所以b为一行三列的的矩阵。
神经网络如下图。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

所以,只要找到合适的w和b,就能准确判断鸢尾花的种类。
下面就开始对这两个参数进行训练。

训练参数

损失函数

损失函数表达的是预测值(y*)和真实值(y)的差距,我们采用均方误差公式作为损失函数。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

损失函数值越小,说明预测值和真实值越接近,w和b就越合适。
如果人来一组一组试,那肯定是不行的。所以我们采用梯度下降算法来找到损失函数最小值。
梯度:对函数求偏导的向量。梯度下降的方向就是函数减少的方向。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

其中a为学习率,即梯度下降的步长,如果a太大,就可能错过最优值,如果a太小,则就需要更多步才能找到最优值。所以选择合适的学习率很关键。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

参数优化

通过反向传播来优化参数。
反向传播:从后向前,逐层求损失函数对每层神经元参数的偏导数,迭代更新所有参数。
比如

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

可以看到w会逐渐趋向于loss的最小值0。
以上就是我们训练的全部关键点。

代码

数据集

我们使用sklearn包提供的鸢尾花数据集。共150组数据。
打乱保证数据的随机性,取前120个为训练集,后30个为测试集。

# 导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data ## 存花萼、花瓣特征数据
y_data = datasets.load_iris().target # 存对应种类
# 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
# seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)
np.random.seed(116)  # 使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
# 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
# from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

参数

# 生成神经网络的参数,4个输入特征故,输入层为4个输入节点;因为3分类,故输出层为3个神经元
# 用tf.Variable()标记参数可训练
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1)) # 四行三列,方差为0.1
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1)) # 一行三列,方差为0.1

训练

a = 0.1  # 学习率为0.1
epoch = 500  # 循环500轮
# 训练部分
for epoch in range(epoch):  # 数据集级别的循环,每个epoch循环一次数据集
    for step, (x_train, y_train) in enumerate(train_db):  # batch级别的循环 ,每个step循环一个batch
        with tf.GradientTape() as tape:  # with结构记录梯度信息
            y = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算
            y = tf.nn.softmax(y)  # 使输出y符合概率分布
            y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss
            loss = tf.reduce_mean(tf.square(y_ - y))  # 采用均方误差损失函数mse = mean(sum(y-y*)^2)
        # 计算loss对w, b的梯度
        grads = tape.gradient(loss, [w1, b1])
        # 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_grad
        w1.assign_sub(a * grads[0])  # 参数w1自更新
        b1.assign_sub(a * grads[1])  # 参数b自更新

测试

# 测试部分
total_correct, total_number = 0, 0
for x_test, y_test in test_db:
    # 前向传播求概率
    y = tf.matmul(x_test, w1) + b1
    y = tf.nn.softmax(y)
    predict = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类
    # 将predict转换为y_test的数据类型
    predict = tf.cast(predict, dtype=y_test.dtype)
    # 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型
    correct = tf.cast(tf.equal(predict, y_test), dtype=tf.int32)
    # 将每个batch的correct数加起来
    correct = tf.reduce_sum(correct)
    # 将所有batch中的correct数加起来
    total_correct += int(correct)
    # total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数
    total_number += x_test.shape[0]
# 总的准确率等于total_correct/total_number
acc = total_correct / total_number
print("测试准确率 = %.2f %%" % (acc * 100.0))
my_test = np.array([[5.9, 3.0, 5.1, 1.8]])
print("输入 5.9  3.0  5.1  1.8")
my_test = tf.convert_to_tensor(my_test)
my_test = tf.cast(my_test, tf.float32)
y = tf.matmul(my_test, w1) + b1
y = tf.nn.softmax(y)
species = {0: "狗尾鸢尾", 1: "杂色鸢尾", 2: "弗吉尼亚鸢尾"}
predict = np.array(tf.argmax(y, axis=1))[0]  # 返回y中最大值的索引,即预测的分类
print("该鸢尾花为:" + species.get(predict))

结果:

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

结语

以上就是全部内容,鸢尾花分类作为经典案例,应该重点掌握理解。有一起学习的伙伴可以把想法打在评论区,大家多多交流,我也会及时回复的!

以上就是python 使用Tensorflow训练BP神经网络实现鸢尾花分类的详细内容,更多关于python 训练BP神经网络实现鸢尾花分类的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python实现将内容分行输出
Nov 05 Python
Python中对象的引用与复制代码示例
Dec 04 Python
在Python中分别打印列表中的每一个元素方法
Nov 07 Python
Python使用Selenium爬取淘宝异步加载的数据方法
Dec 17 Python
Python redis操作实例分析【连接、管道、发布和订阅等】
May 16 Python
Python实现带下标索引的遍历操作示例
May 30 Python
Django 查询数据库并返回页面的例子
Aug 12 Python
使用Python和OpenCV检测图像中的物体并将物体裁剪下来
Oct 30 Python
Python Pickle 实现在同一个文件中序列化多个对象
Dec 30 Python
python自动下载图片的方法示例
Mar 25 Python
Scrapy中如何向Spider传入参数的方法实现
Sep 28 Python
matlab xlabel位置的设置方式
May 21 Python
PyTorch 如何设置随机数种子使结果可复现
May 12 #Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
You might like
php实现cc攻击防御和防止快速刷新页面示例
2014/02/13 PHP
Yii2使用小技巧之通过 Composer 添加 FontAwesome 字体资源
2014/06/22 PHP
JS下拉缓冲菜单示例代码
2013/08/30 Javascript
js实现省市联动效果的简单实例
2014/02/10 Javascript
SyntaxHighlighter 3.0.83使用笔记
2015/01/26 Javascript
基于JavaScript实现拖动滑块效果
2017/02/16 Javascript
深究AngularJS之ui-router详解
2017/06/13 Javascript
JS基于正则实现数字千分位用逗号分隔的方法
2017/06/16 Javascript
实现两个文本框同时输入的实例
2017/09/25 Javascript
vue中SPA单页面应用程序详解
2017/11/07 Javascript
JavaScript中Object基础内部方法图
2018/02/05 Javascript
Vue2.0子同级组件之间数据交互方法
2018/02/28 Javascript
js经验分享 JavaScript反调试技巧
2018/03/10 Javascript
灵活使用console让js调试更简单的方法步骤
2019/04/23 Javascript
24行JavaScript代码实现Redux的方法实例
2019/11/17 Javascript
搭建vscode+vue环境的详细教程
2020/08/31 Javascript
[01:48]DOTA2 2015国际邀请赛中国区预选赛第二日战报
2015/05/27 DOTA
[01:18:36]LGD vs VP Supermajor 败者组决赛 BO3 第一场 6.10
2018/07/04 DOTA
[57:41]Secret vs Serenity 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
pyramid配置session的方法教程
2013/11/27 Python
Python 爬虫学习笔记之单线程爬虫
2016/09/21 Python
Python实现对象转换为xml的方法示例
2017/06/08 Python
Python中turtle作图示例
2017/11/15 Python
Python自然语言处理之词干,词形与最大匹配算法代码详解
2017/11/16 Python
python:print格式化输出到文件的实例
2018/05/14 Python
Python英文文章词频统计(14份剑桥真题词频统计)
2019/10/13 Python
Python笔记之观察者模式
2019/11/20 Python
pandas处理csv文件的方法步骤
2020/10/16 Python
CSS3系列教程:背景图片(背景大小和多背景图) 应用说明
2012/12/19 HTML / CSS
Mamaearth官方网站:印度母婴护理产品公司
2019/10/06 全球购物
幼儿教师思想汇报
2014/01/10 职场文书
少先队活动总结
2014/08/29 职场文书
乡镇遵守党的政治纪律情况对照检查材料
2014/09/26 职场文书
工作批评与自我批评范文
2014/10/16 职场文书
《夜莺的歌声》教学反思
2016/02/22 职场文书
MySQL之高可用集群部署及故障切换实现
2021/04/22 MySQL