keras训练浅层卷积网络并保存和加载模型实例


Posted in Python onJuly 02, 2020

这里我们使用keras定义简单的神经网络全连接层训练MNIST数据集和cifar10数据集:

keras_mnist.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import SGD
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import argparse
# 命令行参数运行
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
args =vars(ap.parse_args())
# 加载数据MNIST,然后归一化到【0,1】,同时使用75%做训练,25%做测试
print("[INFO] loading MNIST (full) dataset")
dataset = datasets.fetch_mldata("MNIST Original", data_home="/home/king/test/python/train/pyimagesearch/nn/data/")
data = dataset.data.astype("float") / 255.0
(trainX, testX, trainY, testY) = train_test_split(data, dataset.target, test_size=0.25)
# 将label进行one-hot编码
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# keras定义网络结构784--256--128--10
model = Sequential()
model.add(Dense(256, input_shape=(784,), activation="relu"))
model.add(Dense(128, activation="relu"))
model.add(Dense(10, activation="softmax"))
# 开始训练
print("[INFO] training network...")
# 0.01的学习率
sgd = SGD(0.01)
# 交叉验证
model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=['accuracy'])
H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=100, batch_size=128)
# 测试模型和评估
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=128)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=[str(x) for x in lb.classes_]))
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 100), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 100), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 100), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 100), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

使用relu做激活函数:

keras训练浅层卷积网络并保存和加载模型实例

使用sigmoid做激活函数:

keras训练浅层卷积网络并保存和加载模型实例

接着我们自己定义一些modules去实现一个简单的卷基层去训练cifar10数据集:

imagetoarraypreprocessor.py

'''
该函数主要是实现keras的一个细节转换,因为训练的图像时RGB三颜色通道,读取进来的数据是有depth的,keras为了兼容一些后台,默认是按照(height, width, depth)读取,但有时候就要改变成(depth, height, width)
'''
from keras.preprocessing.image import img_to_array
class ImageToArrayPreprocessor:
	def __init__(self, dataFormat=None):
		self.dataFormat = dataFormat
 
	def preprocess(self, image):
		return img_to_array(image, data_format=self.dataFormat)

shallownet.py

'''
定义一个简单的卷基层:
input->conv->Relu->FC
'''
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation, Flatten, Dense
from keras import backend as K
 
class ShallowNet:
	@staticmethod
	def build(width, height, depth, classes):
		model = Sequential()
		inputShape = (height, width, depth)
 
		if K.image_data_format() == "channels_first":
			inputShape = (depth, height, width)
 
		model.add(Conv2D(32, (3, 3), padding="same", input_shape=inputShape))
		model.add(Activation("relu"))
 
		model.add(Flatten())
		model.add(Dense(classes))
		model.add(Activation("softmax"))
 
		return model

然后就是训练代码:

keras_cifar10.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
args = vars(ap.parse_args())
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] compiling model...")
opt = SGD(lr=0.0001)
model = ShallowNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
 
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY), batch_size=32, epochs=1000, verbose=1)
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=labelNames))
 
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 1000), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 1000), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 1000), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 1000), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

代码中可以对训练的learning rate进行微调,大概可以接近60%的准确率。

keras训练浅层卷积网络并保存和加载模型实例

keras训练浅层卷积网络并保存和加载模型实例

然后修改下代码可以保存训练模型:

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
ap.add_argument("-m", "--model", required=True, help="path to save train model")
args = vars(ap.parse_args())
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] compiling model...")
opt = SGD(lr=0.005)
model = ShallowNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
 
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY), batch_size=32, epochs=50, verbose=1)
 
model.save(args["model"])
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=labelNames))
 
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 5), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 5), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 5), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 5), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

命令行运行:

keras训练浅层卷积网络并保存和加载模型实例

我们使用另一个程序来加载上一次训练保存的模型,然后进行测试:

test.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
from keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True, help="path to save train model")
args = vars(ap.parse_args())
 
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
 
idxs = np.random.randint(0, len(testX), size=(10,))
testX = testX[idxs]
testY = testY[idxs]
 
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
 
print("[INFO] loading pre-trained network...")
model = load_model(args["model"])
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32).argmax(axis=1)
print("predictions\n", predictions)
for i in range(len(testY)):
	print("label:{}".format(labelNames[predictions[i]]))
 
trueLabel = []
for i in range(len(testY)):
	for j in range(len(testY[i])):
		if testY[i][j] != 0:
			trueLabel.append(j)
print(trueLabel)
 
print("ground truth testY:")
for i in range(len(trueLabel)):
	print("label:{}".format(labelNames[trueLabel[i]]))
 
print("TestY\n", testY)

keras训练浅层卷积网络并保存和加载模型实例

以上这篇keras训练浅层卷积网络并保存和加载模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python subprocess模块学习总结
Mar 13 Python
Python Numpy:找到list中的np.nan值方法
Oct 30 Python
python 使用poster模块进行http方式的文件传输到服务器的方法
Jan 15 Python
Python制作动态字符图的实例
Jan 27 Python
详解python 3.6 安装json 模块(simplejson)
Apr 02 Python
python程序快速缩进多行代码方法总结
Jun 23 Python
TensorFlow实现自定义Op方式
Feb 04 Python
Python实现的北京积分落户数据分析示例
Mar 27 Python
OpenCV 使用imread()函数读取图片的六种正确姿势
Jul 09 Python
Python迭代器协议及for循环工作机制详解
Jul 14 Python
Python Selenium破解滑块验证码最新版(GEETEST95%以上通过率)
Jan 29 Python
python本地文件服务器实例教程
May 02 Python
Python RabbitMQ实现简单的进程间通信示例
Jul 02 #Python
利用scikitlearn画ROC曲线实例
Jul 02 #Python
Python使用文件操作实现一个XX信息管理系统的示例
Jul 02 #Python
keras用auc做metrics以及早停实例
Jul 02 #Python
keras 简单 lstm实例(基于one-hot编码)
Jul 02 #Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
You might like
PHP编码规范-php coding standard
2007/03/16 PHP
php URL跳转代码 减少外链
2011/06/25 PHP
Laravel框架路由管理简单示例
2019/05/07 PHP
JavaScript 继承详解(三)
2009/07/13 Javascript
js字符串转成JSON
2013/11/07 Javascript
使用jquery菜单插件HoverTree仿京东无限级菜单
2014/12/18 Javascript
Javascript核心读书有感之表达式和运算符
2015/02/11 Javascript
jQuery实现气球弹出框式的侧边导航菜单效果
2015/09/22 Javascript
js父页面中使用子页面的方法
2016/01/09 Javascript
js仿百度登录页实现拖动窗口效果
2016/03/11 Javascript
node.js与C语言 实现遍历文件夹下最大的文件,并输出路径,大小
2017/01/20 Javascript
JS/jquery实现一个网页内同时调用多个倒计时的方法
2017/04/27 jQuery
jsonp跨域请求详解
2017/07/13 Javascript
vue父组件向子组件传递多个数据的实例
2018/03/01 Javascript
vue.js 2.*项目环境搭建、运行、打包发布的详细步骤
2019/05/01 Javascript
微信小程序数据统计和错误统计的实现方法
2019/06/26 Javascript
JavaScript实现好看的跟随彩色气泡效果
2020/02/06 Javascript
javascript将16进制的字符串转换为10进制整数hex
2020/03/05 Javascript
Vue使用路由钩子拦截器beforeEach和afterEach监听路由
2020/11/16 Javascript
Python 判断 有向图 是否有环的实例讲解
2018/02/01 Python
解决Django layui {{}}冲突的问题
2019/08/29 Python
python无序链表删除重复项的方法
2020/01/17 Python
解决PyCharm无法使用lxml库的问题(图解)
2020/12/22 Python
French Connection官网:女装、男装及家居用品
2019/03/18 全球购物
娇韵诗Clarins意大利官方网站:法国天然护肤品牌
2020/03/11 全球购物
自我评价的范文
2014/02/02 职场文书
机械专业毕业生自我鉴定2014
2014/10/04 职场文书
群众路线表态发言材料
2014/10/17 职场文书
工作证明英文模板
2014/10/21 职场文书
主持人开幕词
2015/01/29 职场文书
财务会计求职信范文
2015/03/20 职场文书
学校食堂食品安全承诺书
2015/04/29 职场文书
火烧圆明园的观后感
2015/06/03 职场文书
怎样写好工作计划
2019/04/10 职场文书
解决redis批量删除key值的问题
2022/03/23 Redis
GTX1650super好不好 gtx1650super显卡属于什么级别
2022/04/08 数码科技