python 使用Tensorflow训练BP神经网络实现鸢尾花分类


Posted in Python onMay 12, 2021

Hello,兄弟们,开始搞深度学习了,今天出第一篇博客,小白一枚,如果发现错误请及时指正,万分感谢。

使用软件

Python 3.8,Tensorflow2.0

问题描述

鸢尾花主要分为狗尾草鸢尾(0)、杂色鸢尾(1)、弗吉尼亚鸢尾(2)。
人们发现通过计算鸢尾花的花萼长、花萼宽、花瓣长、花瓣宽可以将鸢尾花分类。
所以只要给出足够多的鸢尾花花萼、花瓣数据,以及对应种类,使用合适的神经网络训练,就可以实现鸢尾花分类。

搭建神经网络

输入数据是花萼长、花萼宽、花瓣长、花瓣宽,是n行四列的矩阵。
而输出的是每个种类的概率,是n行三列的矩阵。
我们采用BP神经网络,设X为输入数据,Y为输出数据,W为权重,B偏置。有

y=x∗w+b

因为x为n行四列的矩阵,y为n行三列的矩阵,所以w必须为四行三列的矩阵,每个神经元对应一个b,所以b为一行三列的的矩阵。
神经网络如下图。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

所以,只要找到合适的w和b,就能准确判断鸢尾花的种类。
下面就开始对这两个参数进行训练。

训练参数

损失函数

损失函数表达的是预测值(y*)和真实值(y)的差距,我们采用均方误差公式作为损失函数。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

损失函数值越小,说明预测值和真实值越接近,w和b就越合适。
如果人来一组一组试,那肯定是不行的。所以我们采用梯度下降算法来找到损失函数最小值。
梯度:对函数求偏导的向量。梯度下降的方向就是函数减少的方向。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

其中a为学习率,即梯度下降的步长,如果a太大,就可能错过最优值,如果a太小,则就需要更多步才能找到最优值。所以选择合适的学习率很关键。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

参数优化

通过反向传播来优化参数。
反向传播:从后向前,逐层求损失函数对每层神经元参数的偏导数,迭代更新所有参数。
比如

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

可以看到w会逐渐趋向于loss的最小值0。
以上就是我们训练的全部关键点。

代码

数据集

我们使用sklearn包提供的鸢尾花数据集。共150组数据。
打乱保证数据的随机性,取前120个为训练集,后30个为测试集。

# 导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data ## 存花萼、花瓣特征数据
y_data = datasets.load_iris().target # 存对应种类
# 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
# seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)
np.random.seed(116)  # 使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
# 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
# from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

参数

# 生成神经网络的参数,4个输入特征故,输入层为4个输入节点;因为3分类,故输出层为3个神经元
# 用tf.Variable()标记参数可训练
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1)) # 四行三列,方差为0.1
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1)) # 一行三列,方差为0.1

训练

a = 0.1  # 学习率为0.1
epoch = 500  # 循环500轮
# 训练部分
for epoch in range(epoch):  # 数据集级别的循环,每个epoch循环一次数据集
    for step, (x_train, y_train) in enumerate(train_db):  # batch级别的循环 ,每个step循环一个batch
        with tf.GradientTape() as tape:  # with结构记录梯度信息
            y = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算
            y = tf.nn.softmax(y)  # 使输出y符合概率分布
            y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss
            loss = tf.reduce_mean(tf.square(y_ - y))  # 采用均方误差损失函数mse = mean(sum(y-y*)^2)
        # 计算loss对w, b的梯度
        grads = tape.gradient(loss, [w1, b1])
        # 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_grad
        w1.assign_sub(a * grads[0])  # 参数w1自更新
        b1.assign_sub(a * grads[1])  # 参数b自更新

测试

# 测试部分
total_correct, total_number = 0, 0
for x_test, y_test in test_db:
    # 前向传播求概率
    y = tf.matmul(x_test, w1) + b1
    y = tf.nn.softmax(y)
    predict = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类
    # 将predict转换为y_test的数据类型
    predict = tf.cast(predict, dtype=y_test.dtype)
    # 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型
    correct = tf.cast(tf.equal(predict, y_test), dtype=tf.int32)
    # 将每个batch的correct数加起来
    correct = tf.reduce_sum(correct)
    # 将所有batch中的correct数加起来
    total_correct += int(correct)
    # total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数
    total_number += x_test.shape[0]
# 总的准确率等于total_correct/total_number
acc = total_correct / total_number
print("测试准确率 = %.2f %%" % (acc * 100.0))
my_test = np.array([[5.9, 3.0, 5.1, 1.8]])
print("输入 5.9  3.0  5.1  1.8")
my_test = tf.convert_to_tensor(my_test)
my_test = tf.cast(my_test, tf.float32)
y = tf.matmul(my_test, w1) + b1
y = tf.nn.softmax(y)
species = {0: "狗尾鸢尾", 1: "杂色鸢尾", 2: "弗吉尼亚鸢尾"}
predict = np.array(tf.argmax(y, axis=1))[0]  # 返回y中最大值的索引,即预测的分类
print("该鸢尾花为:" + species.get(predict))

结果:

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

结语

以上就是全部内容,鸢尾花分类作为经典案例,应该重点掌握理解。有一起学习的伙伴可以把想法打在评论区,大家多多交流,我也会及时回复的!

以上就是python 使用Tensorflow训练BP神经网络实现鸢尾花分类的详细内容,更多关于python 训练BP神经网络实现鸢尾花分类的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
分析在Python中何种情况下需要使用断言
Apr 01 Python
Python数据类型详解(四)字典:dict
May 12 Python
python 安装virtualenv和virtualenvwrapper的方法
Jan 13 Python
详解tensorflow载入数据的三种方式
Apr 24 Python
python re模块的高级用法详解
Jun 06 Python
Python 类方法和实例方法(@classmethod),静态方法(@staticmethod)原理与用法分析
Sep 20 Python
pytorch的梯度计算以及backward方法详解
Jan 10 Python
python实现删除列表中某个元素的3种方法
Jan 15 Python
使用django自带的user做外键的方法
Nov 30 Python
使用python实现学生信息管理系统
Feb 25 Python
bat批处理之字符串操作的实现
Mar 16 Python
Python代码实现双链表
May 25 Python
PyTorch 如何设置随机数种子使结果可复现
May 12 #Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
You might like
在任意字符集下正常显示网页的方法一
2007/04/01 PHP
PHP页面中文乱码分析
2013/10/29 PHP
Yii学习总结之数据访问对象 (DAO)
2015/02/22 PHP
作为程序员必知的16个最佳PHP库
2015/12/09 PHP
PHP实现的数独求解问题示例
2017/04/18 PHP
理清PHP在Linxu下执行时的文件权限方法
2017/06/07 PHP
JS 常用校验函数
2009/03/26 Javascript
Javascript实现带关闭按钮的网页漂浮广告代码
2014/01/12 Javascript
jQuery实现“扫码阅读”功能
2015/01/21 Javascript
如何实现JavaScript动态加载CSS和JS文件
2020/12/28 Javascript
jQuery 1.9.1源码分析系列(十五)之动画处理
2015/12/03 Javascript
使用jQuery mobile库检测url绝对地址和相对地址的方法
2015/12/04 Javascript
Javascript将字符串日期格式化为yyyy-mm-dd的方法
2016/10/27 Javascript
vue组件开发之用户无限添加自定义填写表单的方法
2018/08/28 Javascript
如何使用webpack打包一个库library的方法步骤
2019/12/18 Javascript
js实现简单的无缝轮播效果
2020/09/05 Javascript
Python环境下搭建属于自己的pip源的教程
2016/05/05 Python
浅析AST抽象语法树及Python代码实现
2016/06/06 Python
python中reload(module)的用法示例详解
2017/09/15 Python
python实现excel读写数据
2021/03/02 Python
解决pandas使用read_csv()读取文件遇到的问题
2018/06/15 Python
python实现图片彩色转化为素描
2019/01/15 Python
Python字典遍历操作实例小结
2019/03/05 Python
Python OpenCV实现视频分帧
2019/06/01 Python
详解Python中的测试工具
2019/06/09 Python
Python中__repr__和__str__区别详解
2019/11/07 Python
简单了解python调用其他脚本方法实例
2020/03/26 Python
数学专业推荐信范文
2013/11/21 职场文书
中学生获奖感言
2014/02/04 职场文书
经管应届生求职信范文
2014/05/18 职场文书
2014县委书记四风对照检查材料思想汇报
2014/09/21 职场文书
诉讼代理人授权委托书
2014/10/11 职场文书
课堂打架检讨书200字
2014/11/21 职场文书
毕业生个人自荐书
2015/03/05 职场文书
工作自我评价范文
2019/03/21 职场文书
vue实现Toast组件轻提示
2022/04/10 Vue.js