python神经网络学习 使用Keras进行简单分类


Posted in Python onMay 04, 2022

学习前言

上一步讲了如何构建回归算法,这一次将怎么进行简单分类。

Keras中分类的重要函数

1、np_utils.to_categorical

np_utils.to_categorical用于将标签转化为形如(nb_samples, nb_classes)的二值序列。

假设num_classes = 10。

如将[1,2,3,……4]转化成:

[[0,1,0,0,0,0,0,0]
[0,0,1,0,0,0,0,0]
[0,0,0,1,0,0,0,0]
……
[0,0,0,0,1,0,0,0]]

这样的形态。

如将Y_train转化为二值序列,可以用如下方式:

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)

2、Activation

Activation是激活函数,一般在每一层的输出使用。

当我们使用Sequential模型构建函数的时候,只需要在每一层Dense后面添加Activation就可以了。

Sequential函数也支持直接在参数中完成所有层的构建,使用方法如下。

model = Sequential([
    Dense(32,input_dim = 784),
    Activation("relu"),
    Dense(10),
    Activation("softmax")
    ]
)

其中两次Activation分别使用了relu函数和softmax函数。

3、metrics=[‘accuracy’]

在model.compile中添加metrics=[‘accuracy’]表示需要计算分类精确度,具体使用方式如下:

model.compile(
	loss = 'categorical_crossentropy',
	optimizer = rmsprop,
	metrics=['accuracy']
)

全部代码

这是一个简单的仅含有一个隐含层的神经网络,用于完成手写体识别。在本例中,使用的优化器是RMSprop,具体可以使用的优化器可以参照Keras中文文档

import numpy as np
from keras.models import Sequential
from keras.layers import Dense,Activation ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import RMSprop
# 获取训练集
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
# 首先进行标准化 
X_train = X_train.reshape(X_train.shape[0],-1)/255
X_test = X_test.reshape(X_test.shape[0],-1)/255
# 计算categorical_crossentropy需要对分类结果进行categorical
# 即需要将标签转化为形如(nb_samples, nb_classes)的二值序列
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
# 构建模型
model = Sequential([
    Dense(32,input_dim = 784),
    Activation("relu"),
    Dense(10),
    Activation("softmax")
    ]
)
rmsprop = RMSprop(lr = 0.001,rho = 0.9,epsilon = 1e-08,decay = 0)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = rmsprop,metrics=['accuracy'])
print("\ntraining")
cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 32)
print("\nTest")
cost,accuracy = model.evaluate(X_test,Y_test)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)

实验结果为:

Epoch 1/2
60000/60000 [==============================] - 12s 202us/step - loss: 0.3512 - acc: 0.9022
Epoch 2/2
60000/60000 [==============================] - 11s 183us/step - loss: 0.2037 - acc: 0.9419
Test
10000/10000 [==============================] - 1s 108us/step
accuracy: 0.9464

以上就是python神经网络学习使用Keras进行简单分类的详细内容!


Tags in this post...

Python 相关文章推荐
在Python中进行自动化单元测试的教程
Apr 15 Python
探究python中open函数的使用
Mar 01 Python
Python3中使用PyMongo的方法详解
Jul 28 Python
Python3编程实现获取阿里云ECS实例及监控的方法
Aug 18 Python
Python编程实现粒子群算法(PSO)详解
Nov 13 Python
Python爬虫基础之XPath语法与lxml库的用法详解
Sep 13 Python
pymysql 开启调试模式的实现
Sep 24 Python
Python利用 utf-8-sig 编码格式解决写入 csv 文件乱码问题
Feb 21 Python
哈工大自然语言处理工具箱之ltp在windows10下的安装使用教程
May 07 Python
python3实现将json对象存入Redis以及数据的导入导出
Jul 16 Python
如何使用Pytorch搭建模型
Oct 26 Python
python 对一幅灰度图像进行直方图均衡化
Oct 27 Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
python pygame 开发五子棋双人对弈
May 02 #Python
Python开发简易五子棋小游戏
May 02 #Python
You might like
php 目录与文件处理-郑阿奇(续)
2011/07/04 PHP
PHP函数nl2br()与自定义函数nl2p()换行用法分析
2016/04/02 PHP
PHP addcslashes()函数讲解
2019/02/03 PHP
解决laravel groupBy 对查询结果进行分组出现的问题
2019/10/09 PHP
Javascript技术技巧大全(五)
2007/01/22 Javascript
JavaScript 对象、函数和继承
2009/07/07 Javascript
Node.js插件的正确编写方式
2014/08/03 Javascript
Jquery设置attr的disabled属性控制某行显示或者隐藏
2014/09/25 Javascript
javascript中parseInt()函数的定义和用法分析
2014/12/20 Javascript
js代码实现无缝滚动(文字和图片)
2015/08/20 Javascript
Javascript类型转换的规则实例解析
2016/02/23 Javascript
jQuery中的deferred使用方法
2017/03/27 jQuery
JavaScript变量作用域_动力节点Java学院整理
2017/06/27 Javascript
VUE axios上传图片到七牛的实例代码
2017/07/28 Javascript
使用Vue制作图片轮播组件思路详解
2018/03/21 Javascript
详解Vue CLI3配置之filenameHashing使用和源码设计使用和源码设计
2018/08/31 Javascript
VeeValidate 的使用场景以及配置详解
2019/01/11 Javascript
解决vue跨域axios异步通信问题
2019/04/17 Javascript
Bootstrap table 实现树形表格联动选中联动取消功能
2019/09/30 Javascript
Vue-cli assets SubDirectory及PublicPath区别详解
2020/08/18 Javascript
[02:17]快乐加倍!DOTA2食人魔魔法师至宝+迎霜节活动上线
2019/12/22 DOTA
分析python切片原理和方法
2017/12/19 Python
Python爬虫常用库的安装及其环境配置
2018/09/19 Python
Python实现TCP探测目标服务路由轨迹的原理与方法详解
2019/09/04 Python
Algenist奥杰尼官网:微藻抗衰老护肤品牌
2017/07/15 全球购物
JD Sports德国官网:英国领先的运动鞋和运动服饰零售商
2018/02/26 全球购物
校园文化标语
2014/06/18 职场文书
普通党员对照检查材料
2014/09/24 职场文书
2014年优质护理服务工作总结
2014/11/14 职场文书
寻找最美乡村教师观后感
2015/06/18 职场文书
2016年大学生社会实践心得体会
2015/10/09 职场文书
2016感恩父亲节主题广播稿
2015/12/18 职场文书
2019年最新借条范本!
2019/07/08 职场文书
springboot拦截器无法注入redisTemplate的解决方法
2021/06/27 Java/Android
vue3.0 数字翻牌组件的使用方法详解
2022/04/20 Vue.js
Win11 Dev 预览版25174.1000发布 (附更新修复内容汇总)
2022/08/05 数码科技