python 使用Tensorflow训练BP神经网络实现鸢尾花分类


Posted in Python onMay 12, 2021

Hello,兄弟们,开始搞深度学习了,今天出第一篇博客,小白一枚,如果发现错误请及时指正,万分感谢。

使用软件

Python 3.8,Tensorflow2.0

问题描述

鸢尾花主要分为狗尾草鸢尾(0)、杂色鸢尾(1)、弗吉尼亚鸢尾(2)。
人们发现通过计算鸢尾花的花萼长、花萼宽、花瓣长、花瓣宽可以将鸢尾花分类。
所以只要给出足够多的鸢尾花花萼、花瓣数据,以及对应种类,使用合适的神经网络训练,就可以实现鸢尾花分类。

搭建神经网络

输入数据是花萼长、花萼宽、花瓣长、花瓣宽,是n行四列的矩阵。
而输出的是每个种类的概率,是n行三列的矩阵。
我们采用BP神经网络,设X为输入数据,Y为输出数据,W为权重,B偏置。有

y=x∗w+b

因为x为n行四列的矩阵,y为n行三列的矩阵,所以w必须为四行三列的矩阵,每个神经元对应一个b,所以b为一行三列的的矩阵。
神经网络如下图。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

所以,只要找到合适的w和b,就能准确判断鸢尾花的种类。
下面就开始对这两个参数进行训练。

训练参数

损失函数

损失函数表达的是预测值(y*)和真实值(y)的差距,我们采用均方误差公式作为损失函数。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

损失函数值越小,说明预测值和真实值越接近,w和b就越合适。
如果人来一组一组试,那肯定是不行的。所以我们采用梯度下降算法来找到损失函数最小值。
梯度:对函数求偏导的向量。梯度下降的方向就是函数减少的方向。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

其中a为学习率,即梯度下降的步长,如果a太大,就可能错过最优值,如果a太小,则就需要更多步才能找到最优值。所以选择合适的学习率很关键。

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

参数优化

通过反向传播来优化参数。
反向传播:从后向前,逐层求损失函数对每层神经元参数的偏导数,迭代更新所有参数。
比如

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

可以看到w会逐渐趋向于loss的最小值0。
以上就是我们训练的全部关键点。

代码

数据集

我们使用sklearn包提供的鸢尾花数据集。共150组数据。
打乱保证数据的随机性,取前120个为训练集,后30个为测试集。

# 导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data ## 存花萼、花瓣特征数据
y_data = datasets.load_iris().target # 存对应种类
# 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
# seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)
np.random.seed(116)  # 使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
# 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
# from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

参数

# 生成神经网络的参数,4个输入特征故,输入层为4个输入节点;因为3分类,故输出层为3个神经元
# 用tf.Variable()标记参数可训练
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1)) # 四行三列,方差为0.1
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1)) # 一行三列,方差为0.1

训练

a = 0.1  # 学习率为0.1
epoch = 500  # 循环500轮
# 训练部分
for epoch in range(epoch):  # 数据集级别的循环,每个epoch循环一次数据集
    for step, (x_train, y_train) in enumerate(train_db):  # batch级别的循环 ,每个step循环一个batch
        with tf.GradientTape() as tape:  # with结构记录梯度信息
            y = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算
            y = tf.nn.softmax(y)  # 使输出y符合概率分布
            y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss
            loss = tf.reduce_mean(tf.square(y_ - y))  # 采用均方误差损失函数mse = mean(sum(y-y*)^2)
        # 计算loss对w, b的梯度
        grads = tape.gradient(loss, [w1, b1])
        # 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_grad
        w1.assign_sub(a * grads[0])  # 参数w1自更新
        b1.assign_sub(a * grads[1])  # 参数b自更新

测试

# 测试部分
total_correct, total_number = 0, 0
for x_test, y_test in test_db:
    # 前向传播求概率
    y = tf.matmul(x_test, w1) + b1
    y = tf.nn.softmax(y)
    predict = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类
    # 将predict转换为y_test的数据类型
    predict = tf.cast(predict, dtype=y_test.dtype)
    # 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型
    correct = tf.cast(tf.equal(predict, y_test), dtype=tf.int32)
    # 将每个batch的correct数加起来
    correct = tf.reduce_sum(correct)
    # 将所有batch中的correct数加起来
    total_correct += int(correct)
    # total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数
    total_number += x_test.shape[0]
# 总的准确率等于total_correct/total_number
acc = total_correct / total_number
print("测试准确率 = %.2f %%" % (acc * 100.0))
my_test = np.array([[5.9, 3.0, 5.1, 1.8]])
print("输入 5.9  3.0  5.1  1.8")
my_test = tf.convert_to_tensor(my_test)
my_test = tf.cast(my_test, tf.float32)
y = tf.matmul(my_test, w1) + b1
y = tf.nn.softmax(y)
species = {0: "狗尾鸢尾", 1: "杂色鸢尾", 2: "弗吉尼亚鸢尾"}
predict = np.array(tf.argmax(y, axis=1))[0]  # 返回y中最大值的索引,即预测的分类
print("该鸢尾花为:" + species.get(predict))

结果:

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

结语

以上就是全部内容,鸢尾花分类作为经典案例,应该重点掌握理解。有一起学习的伙伴可以把想法打在评论区,大家多多交流,我也会及时回复的!

以上就是python 使用Tensorflow训练BP神经网络实现鸢尾花分类的详细内容,更多关于python 训练BP神经网络实现鸢尾花分类的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python中cPickle用法例子分享
Jan 03 Python
python实现巡检系统(solaris)示例
Apr 02 Python
python获取文件版本信息、公司名和产品名的方法
Oct 05 Python
Python去除列表中重复元素的方法
Mar 20 Python
python反编译学习之字节码详解
May 19 Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 Python
python @propert装饰器使用方法原理解析
Dec 25 Python
python已协程方式处理任务实现过程
Dec 27 Python
python实现的批量分析xml标签中各个类别个数功能示例
Dec 30 Python
Python + selenium + crontab实现每日定时自动打卡功能
Mar 31 Python
教你用Python写一个植物大战僵尸小游戏
Apr 25 Python
4种方法python批量修改替换列表中元素
Apr 07 Python
PyTorch 如何设置随机数种子使结果可复现
May 12 #Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
You might like
PHP实现ftp上传文件示例
2014/08/21 PHP
使用PHPExcel操作Excel用法实例分析
2015/03/26 PHP
php表单处理操作
2017/11/16 PHP
php多进程应用场景实例详解
2019/07/22 PHP
thinkphp5实现微信扫码支付
2019/12/23 PHP
由浅到深了解JavaScript类
2006/09/08 Javascript
腾讯与新浪的通过IP地址获取当前地理位置(省份)的接口
2010/07/26 Javascript
JS获取浏览器版本及名称实现函数
2013/04/02 Javascript
js图片延迟加载的实现方法及思路
2013/07/22 Javascript
JavaScript控制浏览器全屏及各种浏览器全屏模式的方法、属性和事件
2015/12/20 Javascript
jquery的ajax提交form表单的两种方法小结(推荐)
2016/05/25 Javascript
jQuery插件WebUploader实现文件上传
2016/11/07 Javascript
angularjs中ng-attr的用法详解
2016/12/31 Javascript
vue router下的html5 history在iis服务器上的设置方法
2017/10/18 Javascript
Vue2 添加数据可视化支持的方法步骤
2019/01/02 Javascript
JavaScript实现多个物体同时运动
2020/03/12 Javascript
vue实现多个echarts根据屏幕大小变化而变化实例
2020/07/19 Javascript
JavaScript编写开发动态时钟
2020/07/29 Javascript
JavaScript中Object、map、weakmap的区别分析
2020/12/15 Javascript
[00:58]他们到底在电话里听到了什么?
2017/11/21 DOTA
[02:38]2018年度DOTA2最佳劣单位选手-完美盛典
2018/12/17 DOTA
Python yield 使用浅析
2015/05/28 Python
python访问mysql数据库的实现方法(2则示例)
2016/01/06 Python
使用Python向DataFrame中指定位置添加一列或多列的方法
2019/01/29 Python
Python3并发写文件与Python对比
2019/11/20 Python
Python面向对象封装操作案例详解 II
2020/01/02 Python
python代码xml转txt实例
2020/03/10 Python
python把一个字符串切开的实例方法
2020/09/27 Python
matplotlib绘制正余弦曲线图的实现
2021/02/22 Python
可打印的优惠券、杂货和优惠券代码:Coupons.com
2018/06/12 全球购物
教师自荐书
2013/10/08 职场文书
毕业生实习鉴定
2013/12/11 职场文书
小学门卫岗位职责
2013/12/17 职场文书
MySQL 时间类型的选择
2021/06/05 MySQL
springboot实现string转json json里面带数组
2022/06/16 Java/Android
Windows11 Insider Preview Build 25206今日发布 更新内容汇总
2022/09/23 数码科技