keras训练浅层卷积网络并保存和加载模型实例


Posted in Python onJuly 02, 2020

这里我们使用keras定义简单的神经网络全连接层训练MNIST数据集和cifar10数据集:

keras_mnist.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import SGD
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import argparse
# 命令行参数运行
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
args =vars(ap.parse_args())
# 加载数据MNIST,然后归一化到【0,1】,同时使用75%做训练,25%做测试
print("[INFO] loading MNIST (full) dataset")
dataset = datasets.fetch_mldata("MNIST Original", data_home="/home/king/test/python/train/pyimagesearch/nn/data/")
data = dataset.data.astype("float") / 255.0
(trainX, testX, trainY, testY) = train_test_split(data, dataset.target, test_size=0.25)
# 将label进行one-hot编码
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# keras定义网络结构784--256--128--10
model = Sequential()
model.add(Dense(256, input_shape=(784,), activation="relu"))
model.add(Dense(128, activation="relu"))
model.add(Dense(10, activation="softmax"))
# 开始训练
print("[INFO] training network...")
# 0.01的学习率
sgd = SGD(0.01)
# 交叉验证
model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=['accuracy'])
H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=100, batch_size=128)
# 测试模型和评估
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=128)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=[str(x) for x in lb.classes_]))
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 100), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 100), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 100), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 100), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

使用relu做激活函数:

keras训练浅层卷积网络并保存和加载模型实例

使用sigmoid做激活函数:

keras训练浅层卷积网络并保存和加载模型实例

接着我们自己定义一些modules去实现一个简单的卷基层去训练cifar10数据集:

imagetoarraypreprocessor.py

'''
该函数主要是实现keras的一个细节转换,因为训练的图像时RGB三颜色通道,读取进来的数据是有depth的,keras为了兼容一些后台,默认是按照(height, width, depth)读取,但有时候就要改变成(depth, height, width)
'''
from keras.preprocessing.image import img_to_array
class ImageToArrayPreprocessor:
	def __init__(self, dataFormat=None):
		self.dataFormat = dataFormat
 
	def preprocess(self, image):
		return img_to_array(image, data_format=self.dataFormat)

shallownet.py

'''
定义一个简单的卷基层:
input->conv->Relu->FC
'''
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation, Flatten, Dense
from keras import backend as K
 
class ShallowNet:
	@staticmethod
	def build(width, height, depth, classes):
		model = Sequential()
		inputShape = (height, width, depth)
 
		if K.image_data_format() == "channels_first":
			inputShape = (depth, height, width)
 
		model.add(Conv2D(32, (3, 3), padding="same", input_shape=inputShape))
		model.add(Activation("relu"))
 
		model.add(Flatten())
		model.add(Dense(classes))
		model.add(Activation("softmax"))
 
		return model

然后就是训练代码:

keras_cifar10.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
args = vars(ap.parse_args())
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] compiling model...")
opt = SGD(lr=0.0001)
model = ShallowNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
 
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY), batch_size=32, epochs=1000, verbose=1)
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=labelNames))
 
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 1000), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 1000), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 1000), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 1000), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

代码中可以对训练的learning rate进行微调,大概可以接近60%的准确率。

keras训练浅层卷积网络并保存和加载模型实例

keras训练浅层卷积网络并保存和加载模型实例

然后修改下代码可以保存训练模型:

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
ap.add_argument("-m", "--model", required=True, help="path to save train model")
args = vars(ap.parse_args())
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] compiling model...")
opt = SGD(lr=0.005)
model = ShallowNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
 
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY), batch_size=32, epochs=50, verbose=1)
 
model.save(args["model"])
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=labelNames))
 
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 5), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 5), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 5), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 5), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

命令行运行:

keras训练浅层卷积网络并保存和加载模型实例

我们使用另一个程序来加载上一次训练保存的模型,然后进行测试:

test.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
from keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True, help="path to save train model")
args = vars(ap.parse_args())
 
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
 
idxs = np.random.randint(0, len(testX), size=(10,))
testX = testX[idxs]
testY = testY[idxs]
 
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
 
print("[INFO] loading pre-trained network...")
model = load_model(args["model"])
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32).argmax(axis=1)
print("predictions\n", predictions)
for i in range(len(testY)):
	print("label:{}".format(labelNames[predictions[i]]))
 
trueLabel = []
for i in range(len(testY)):
	for j in range(len(testY[i])):
		if testY[i][j] != 0:
			trueLabel.append(j)
print(trueLabel)
 
print("ground truth testY:")
for i in range(len(trueLabel)):
	print("label:{}".format(labelNames[trueLabel[i]]))
 
print("TestY\n", testY)

keras训练浅层卷积网络并保存和加载模型实例

以上这篇keras训练浅层卷积网络并保存和加载模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Pycharm学习教程(1) 定制外观
May 02 Python
Python实现的文本对比报告生成工具示例
May 22 Python
对python list 遍历删除的正确方法详解
Jun 29 Python
python 读取目录下csv文件并绘制曲线v111的方法
Jul 06 Python
Python自动发送邮件的方法实例总结
Dec 08 Python
python广度优先搜索得到两点间最短路径
Jan 17 Python
解决Django layui {{}}冲突的问题
Aug 29 Python
python3实现弹弹球小游戏
Nov 25 Python
如何用OpenCV -python3实现视频物体追踪
Dec 04 Python
解决jupyter运行pyqt代码内核重启的问题
Apr 16 Python
python文件操作seek()偏移量,读取指正到指定位置操作
Jul 05 Python
python简单实现9宫格图片实例
Sep 03 Python
Python RabbitMQ实现简单的进程间通信示例
Jul 02 #Python
利用scikitlearn画ROC曲线实例
Jul 02 #Python
Python使用文件操作实现一个XX信息管理系统的示例
Jul 02 #Python
keras用auc做metrics以及早停实例
Jul 02 #Python
keras 简单 lstm实例(基于one-hot编码)
Jul 02 #Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
You might like
php file_put_contents()功能函数(集成了fopen、fwrite、fclose)
2011/05/24 PHP
PHP7.1新功能之Nullable Type用法分析
2016/09/26 PHP
php把字符串指定字符分割成数组的方法
2018/03/12 PHP
javascript对talbe进行动态添加、删除、验证实现代码
2012/03/29 Javascript
javascript实现json页面分页实例代码
2014/02/20 Javascript
详解JavaScript中的Unescape()和String() 函数
2015/11/09 Javascript
bootstrap实现图片自动轮播
2016/12/21 Javascript
Node.js与Sails redis组件的使用教程
2017/02/14 Javascript
js实现移动端微信页面禁止字体放大
2017/02/16 Javascript
常用的几个JQuery代码片段
2017/03/13 Javascript
十大 Node.js 的 Web 框架(快速提升工作效率)
2017/06/30 Javascript
React 子组件向父组件传值的方法
2017/07/24 Javascript
Angular4实现动态添加删除表单输入框功能
2017/08/11 Javascript
vue better-scroll插件使用详解
2018/01/25 Javascript
浅谈webpack 构建性能优化策略小结
2018/06/13 Javascript
一步步教你利用Docker设置Node.js
2018/11/20 Javascript
a标签调用js的方法总结
2019/09/05 Javascript
layui实现图片虚拟路径上传,预览和删除的例子
2019/09/25 Javascript
使用JavaScript获取扫码枪扫描得到的条形码的思路代码详解
2020/06/10 Javascript
Postman动态获取返回值过程详解
2020/06/30 Javascript
Python实现线程池代码分享
2015/06/21 Python
新手如何快速入门Python(菜鸟必看篇)
2017/06/10 Python
Python全局变量与局部变量区别及用法分析
2018/09/03 Python
Python3实现的旋转矩阵图像算法示例
2019/04/03 Python
浅谈Python3中strip()、lstrip()、rstrip()用法详解
2019/04/29 Python
python基于SMTP协议发送邮件
2019/05/31 Python
Django获取该数据的上一条和下一条方法
2019/08/12 Python
基于Python3读写INI配置文件过程解析
2020/07/23 Python
linux系统下pip升级报错的解决方法
2021/01/31 Python
英国排名第一的礼品体验公司:Red Letter Days
2018/08/16 全球购物
英国办公家具网站:Furniture At Work
2019/10/07 全球购物
公司出纳岗位职责
2015/03/31 职场文书
我在伊朗长大观后感
2015/06/16 职场文书
python如何正确使用yield
2021/05/21 Python
为什么MySQL选择Repeatable Read作为默认隔离级别
2021/07/26 MySQL
Python 第三方库 openpyxl 的安装过程
2022/12/24 Python