python 使用Tensorflow训练BP神经网络实现鸢尾花分类


Posted in Python onMay 12, 2021

Hello,兄弟们,开始搞深度学习了,今天出第一篇博客,小白一枚,如果发现错误请及时指正,万分感谢。

使用软件

Python 3.8,Tensorflow2.0

问题描述

鸢尾花主要分为狗尾草鸢尾(0)、杂色鸢尾(1)、弗吉尼亚鸢尾(2)。
人们发现通过计算鸢尾花的花萼长、花萼宽、花瓣长、花瓣宽可以将鸢尾花分类。
所以只要给出足够多的鸢尾花花萼、花瓣数据,以及对应种类,使用合适的神经网络训练,就可以实现鸢尾花分类。

搭建神经网络

输入数据是花萼长、花萼宽、花瓣长、花瓣宽,是n行四列的矩阵。
而输出的是每个种类的概率,是n行三列的矩阵。
我们采用BP神经网络,设X为输入数据,Y为输出数据,W为权重,B偏置。有

y=x∗w+b

因为x为n行四列的矩阵,y为n行三列的矩阵,所以w必须为四行三列的矩阵,每个神经元对应一个b,所以b为一行三列的的矩阵。
神经网络如下图。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

所以,只要找到合适的w和b,就能准确判断鸢尾花的种类。
下面就开始对这两个参数进行训练。

训练参数

损失函数

损失函数表达的是预测值(y*)和真实值(y)的差距,我们采用均方误差公式作为损失函数。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

损失函数值越小,说明预测值和真实值越接近,w和b就越合适。
如果人来一组一组试,那肯定是不行的。所以我们采用梯度下降算法来找到损失函数最小值。
梯度:对函数求偏导的向量。梯度下降的方向就是函数减少的方向。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

其中a为学习率,即梯度下降的步长,如果a太大,就可能错过最优值,如果a太小,则就需要更多步才能找到最优值。所以选择合适的学习率很关键。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

参数优化

通过反向传播来优化参数。
反向传播:从后向前,逐层求损失函数对每层神经元参数的偏导数,迭代更新所有参数。
比如

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

可以看到w会逐渐趋向于loss的最小值0。
以上就是我们训练的全部关键点。

代码

数据集

我们使用sklearn包提供的鸢尾花数据集。共150组数据。
打乱保证数据的随机性,取前120个为训练集,后30个为测试集。

# 导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data ## 存花萼、花瓣特征数据
y_data = datasets.load_iris().target # 存对应种类
# 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
# seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)
np.random.seed(116)  # 使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
# 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
# from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

参数

# 生成神经网络的参数,4个输入特征故,输入层为4个输入节点;因为3分类,故输出层为3个神经元
# 用tf.Variable()标记参数可训练
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1)) # 四行三列,方差为0.1
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1)) # 一行三列,方差为0.1

训练

a = 0.1  # 学习率为0.1
epoch = 500  # 循环500轮
# 训练部分
for epoch in range(epoch):  # 数据集级别的循环,每个epoch循环一次数据集
    for step, (x_train, y_train) in enumerate(train_db):  # batch级别的循环 ,每个step循环一个batch
        with tf.GradientTape() as tape:  # with结构记录梯度信息
            y = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算
            y = tf.nn.softmax(y)  # 使输出y符合概率分布
            y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss
            loss = tf.reduce_mean(tf.square(y_ - y))  # 采用均方误差损失函数mse = mean(sum(y-y*)^2)
        # 计算loss对w, b的梯度
        grads = tape.gradient(loss, [w1, b1])
        # 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_grad
        w1.assign_sub(a * grads[0])  # 参数w1自更新
        b1.assign_sub(a * grads[1])  # 参数b自更新

测试

# 测试部分
total_correct, total_number = 0, 0
for x_test, y_test in test_db:
    # 前向传播求概率
    y = tf.matmul(x_test, w1) + b1
    y = tf.nn.softmax(y)
    predict = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类
    # 将predict转换为y_test的数据类型
    predict = tf.cast(predict, dtype=y_test.dtype)
    # 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型
    correct = tf.cast(tf.equal(predict, y_test), dtype=tf.int32)
    # 将每个batch的correct数加起来
    correct = tf.reduce_sum(correct)
    # 将所有batch中的correct数加起来
    total_correct += int(correct)
    # total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数
    total_number += x_test.shape[0]
# 总的准确率等于total_correct/total_number
acc = total_correct / total_number
print("测试准确率 = %.2f %%" % (acc * 100.0))
my_test = np.array([[5.9, 3.0, 5.1, 1.8]])
print("输入 5.9  3.0  5.1  1.8")
my_test = tf.convert_to_tensor(my_test)
my_test = tf.cast(my_test, tf.float32)
y = tf.matmul(my_test, w1) + b1
y = tf.nn.softmax(y)
species = {0: "狗尾鸢尾", 1: "杂色鸢尾", 2: "弗吉尼亚鸢尾"}
predict = np.array(tf.argmax(y, axis=1))[0]  # 返回y中最大值的索引,即预测的分类
print("该鸢尾花为:" + species.get(predict))

结果:

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

结语

以上就是全部内容,鸢尾花分类作为经典案例,应该重点掌握理解。有一起学习的伙伴可以把想法打在评论区,大家多多交流,我也会及时回复的!

以上就是python 使用Tensorflow训练BP神经网络实现鸢尾花分类的详细内容,更多关于python 训练BP神经网络实现鸢尾花分类的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
基于python的汉字转GBK码实现代码
Feb 19 Python
go和python调用其它程序并得到程序输出
Feb 10 Python
python读取TXT到数组及列表去重后按原来顺序排序的方法
Jun 26 Python
使用Python实现博客上进行自动翻页
Aug 23 Python
Python实现的根据文件名查找数据文件功能示例
May 02 Python
Python实现的括号匹配判断功能示例
Aug 25 Python
python实现停车管理系统
Nov 30 Python
Python K最近邻从原理到实现的方法
Aug 15 Python
pygame实现五子棋游戏
Oct 29 Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 Python
python 实现Harris角点检测算法
Dec 11 Python
详解PyTorch模型保存与加载
Apr 28 Python
PyTorch 如何设置随机数种子使结果可复现
May 12 #Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
You might like
php删除数组中重复元素的方法
2015/12/22 PHP
关于 Laravel Redis 多个进程同时取队列问题详解
2017/12/25 PHP
PHP中命名空间的使用例子
2019/03/22 PHP
在laravel5.2中实现点击用户头像更改头像的方法
2019/10/14 PHP
AngularJS入门教程(零):引导程序
2014/12/06 Javascript
EasyUI,点击开启编辑框,并且编辑框获得焦点的方法
2015/03/01 Javascript
JavaScript通过setTimeout实时显示当前时间的方法
2015/04/16 Javascript
JS中取二维数组中最大值的方法汇总
2016/04/17 Javascript
js原生实现FastClick事件的实例
2016/11/20 Javascript
windows下vue.js开发环境搭建教程
2017/03/20 Javascript
详解如何让Express支持async/await
2017/10/09 Javascript
基于input动态模糊查询的实现方法
2017/12/12 Javascript
JavaScript基础心法 数据类型
2018/03/05 Javascript
在 Vue.js中优雅地使用全局事件的方法
2019/02/01 Javascript
JavaScript键盘事件响应顺序详解
2019/09/30 Javascript
Javascript地址引用代码实例解析
2020/02/25 Javascript
vue如何在用户要关闭当前网页时弹出提示的实现
2020/05/31 Javascript
html中创建并调用vue组件的几种方法汇总
2020/11/17 Javascript
vue登录页实现使用cookie记住7天密码功能的方法
2021/02/18 Vue.js
python中lambda函数 list comprehension 和 zip函数使用指南
2014/09/28 Python
Python使用剪切板的方法
2017/06/06 Python
利用python库在局域网内传输文件的方法
2018/06/04 Python
基于Django框架利用Ajax实现点赞功能实例代码
2018/08/19 Python
Python matplotlib修改默认字体的操作
2020/03/05 Python
捷克体育用品购物网站:D-sport
2017/12/28 全球购物
预订旅游活动、景点和旅游:GetYourGuide
2019/09/29 全球购物
护理专业自荐信
2013/12/03 职场文书
应届实习生的自我评价范文
2014/01/05 职场文书
便利店投资创业计划书
2014/02/08 职场文书
2014植树节活动总结
2014/03/11 职场文书
学校师德师风整改措施
2014/10/27 职场文书
2015年汽车销售工作总结
2015/04/07 职场文书
运动员代表致辞
2015/07/29 职场文书
SQL实现LeetCode(196.删除重复邮箱)
2021/08/07 MySQL
JavaScript小技巧带你提升你的代码技能
2021/09/15 Javascript