keras训练浅层卷积网络并保存和加载模型实例


Posted in Python onJuly 02, 2020

这里我们使用keras定义简单的神经网络全连接层训练MNIST数据集和cifar10数据集:

keras_mnist.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import SGD
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import argparse
# 命令行参数运行
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
args =vars(ap.parse_args())
# 加载数据MNIST,然后归一化到【0,1】,同时使用75%做训练,25%做测试
print("[INFO] loading MNIST (full) dataset")
dataset = datasets.fetch_mldata("MNIST Original", data_home="/home/king/test/python/train/pyimagesearch/nn/data/")
data = dataset.data.astype("float") / 255.0
(trainX, testX, trainY, testY) = train_test_split(data, dataset.target, test_size=0.25)
# 将label进行one-hot编码
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# keras定义网络结构784--256--128--10
model = Sequential()
model.add(Dense(256, input_shape=(784,), activation="relu"))
model.add(Dense(128, activation="relu"))
model.add(Dense(10, activation="softmax"))
# 开始训练
print("[INFO] training network...")
# 0.01的学习率
sgd = SGD(0.01)
# 交叉验证
model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=['accuracy'])
H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=100, batch_size=128)
# 测试模型和评估
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=128)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=[str(x) for x in lb.classes_]))
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 100), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 100), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 100), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 100), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

使用relu做激活函数:

keras训练浅层卷积网络并保存和加载模型实例

使用sigmoid做激活函数:

keras训练浅层卷积网络并保存和加载模型实例

接着我们自己定义一些modules去实现一个简单的卷基层去训练cifar10数据集:

imagetoarraypreprocessor.py

'''
该函数主要是实现keras的一个细节转换,因为训练的图像时RGB三颜色通道,读取进来的数据是有depth的,keras为了兼容一些后台,默认是按照(height, width, depth)读取,但有时候就要改变成(depth, height, width)
'''
from keras.preprocessing.image import img_to_array
class ImageToArrayPreprocessor:
	def __init__(self, dataFormat=None):
		self.dataFormat = dataFormat
 
	def preprocess(self, image):
		return img_to_array(image, data_format=self.dataFormat)

shallownet.py

'''
定义一个简单的卷基层:
input->conv->Relu->FC
'''
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation, Flatten, Dense
from keras import backend as K
 
class ShallowNet:
	@staticmethod
	def build(width, height, depth, classes):
		model = Sequential()
		inputShape = (height, width, depth)
 
		if K.image_data_format() == "channels_first":
			inputShape = (depth, height, width)
 
		model.add(Conv2D(32, (3, 3), padding="same", input_shape=inputShape))
		model.add(Activation("relu"))
 
		model.add(Flatten())
		model.add(Dense(classes))
		model.add(Activation("softmax"))
 
		return model

然后就是训练代码:

keras_cifar10.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
args = vars(ap.parse_args())
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] compiling model...")
opt = SGD(lr=0.0001)
model = ShallowNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
 
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY), batch_size=32, epochs=1000, verbose=1)
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=labelNames))
 
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 1000), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 1000), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 1000), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 1000), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

代码中可以对训练的learning rate进行微调,大概可以接近60%的准确率。

keras训练浅层卷积网络并保存和加载模型实例

keras训练浅层卷积网络并保存和加载模型实例

然后修改下代码可以保存训练模型:

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
ap.add_argument("-m", "--model", required=True, help="path to save train model")
args = vars(ap.parse_args())
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] compiling model...")
opt = SGD(lr=0.005)
model = ShallowNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
 
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY), batch_size=32, epochs=50, verbose=1)
 
model.save(args["model"])
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=labelNames))
 
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 5), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 5), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 5), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 5), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

命令行运行:

keras训练浅层卷积网络并保存和加载模型实例

我们使用另一个程序来加载上一次训练保存的模型,然后进行测试:

test.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
from keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True, help="path to save train model")
args = vars(ap.parse_args())
 
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
 
idxs = np.random.randint(0, len(testX), size=(10,))
testX = testX[idxs]
testY = testY[idxs]
 
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
 
print("[INFO] loading pre-trained network...")
model = load_model(args["model"])
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32).argmax(axis=1)
print("predictions\n", predictions)
for i in range(len(testY)):
	print("label:{}".format(labelNames[predictions[i]]))
 
trueLabel = []
for i in range(len(testY)):
	for j in range(len(testY[i])):
		if testY[i][j] != 0:
			trueLabel.append(j)
print(trueLabel)
 
print("ground truth testY:")
for i in range(len(trueLabel)):
	print("label:{}".format(labelNames[trueLabel[i]]))
 
print("TestY\n", testY)

keras训练浅层卷积网络并保存和加载模型实例

以上这篇keras训练浅层卷积网络并保存和加载模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python比较两个图片相似度的方法
Mar 13 Python
利用Python获取赶集网招聘信息前篇
Apr 18 Python
python模块之time模块(实例讲解)
Sep 13 Python
Python中pow()和math.pow()函数用法示例
Feb 11 Python
python处理csv中的空值方法
Jun 22 Python
Python基于SMTP协议实现发送邮件功能详解
Aug 14 Python
python实现对服务器脚本敏感信息的加密解密功能
Aug 13 Python
关于jupyter打开之后不能直接跳转到浏览器的解决方式
Apr 13 Python
Python 操作 PostgreSQL 数据库示例【连接、增删改查等】
Apr 21 Python
python中slice参数过长的处理方法及实例
Dec 15 Python
pytorch下的unsqueeze和squeeze的用法说明
Feb 06 Python
教你如何使用Python Tkinter库制作记事本
Jun 10 Python
Python RabbitMQ实现简单的进程间通信示例
Jul 02 #Python
利用scikitlearn画ROC曲线实例
Jul 02 #Python
Python使用文件操作实现一个XX信息管理系统的示例
Jul 02 #Python
keras用auc做metrics以及早停实例
Jul 02 #Python
keras 简单 lstm实例(基于one-hot编码)
Jul 02 #Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
You might like
站长助手-网站web在线管理程序 v1.0 下载
2007/05/12 PHP
PHP HTML代码串截取代码
2008/12/29 PHP
php截取中文字符串函数实例
2015/02/23 PHP
PHP+AjaxForm异步带进度条上传文件实例代码
2017/08/14 PHP
PHP7.1实现的AES与RSA加密操作示例
2018/06/15 PHP
javascript中方便增删改cookie的一个类
2012/10/11 Javascript
页面加载完毕后滚动条自动滚动一定位置
2014/02/20 Javascript
JS中字符串trim()使用示例
2015/05/26 Javascript
JavaScript中setFullYear()方法的使用详解
2015/06/11 Javascript
javascript中类的定义方式详解(四种方式)
2015/12/22 Javascript
自动化测试读写64位操作系统的注册表
2016/08/15 Javascript
JS+HTML5实现的前端购物车功能插件实例【附demo源码下载】
2016/10/17 Javascript
Javascript同时声明一连串(多个)变量的方法
2017/01/23 Javascript
js模拟微博发布消息
2017/02/23 Javascript
JS得到当前时间的方法示例
2017/03/24 Javascript
jQuery上传插件webupload使用方法
2017/08/01 jQuery
JS实现倒序输出的几种常用方法示例
2019/04/13 Javascript
python下如何让web元素的生成更简单的分析
2008/07/17 Python
python获取android设备的GPS信息脚本分享
2015/03/06 Python
在Python中使用全局日志时需要注意的问题
2015/05/06 Python
Python实现PS图像调整黑白效果示例
2018/01/25 Python
详解Python3网络爬虫(二):利用urllib.urlopen向有道翻译发送数据获得翻译结果
2019/05/07 Python
python 输出列表元素实例(以空格/逗号为分隔符)
2019/12/25 Python
Python实现在Windows平台修改文件属性
2020/03/05 Python
Python函数__new__及__init__作用及区别解析
2020/08/31 Python
Python内置函数及功能简介汇总
2020/10/13 Python
Python读取ini配置文件传参的简单示例
2021/01/05 Python
后勤园长自我鉴定
2013/10/17 职场文书
应付会计岗位职责
2013/12/12 职场文书
成品仓管员工作职责
2013/12/29 职场文书
珠宝店促销方案
2014/03/21 职场文书
淘宝文案策划岗位职责
2015/04/14 职场文书
关于召开会议的通知
2015/04/15 职场文书
2015年公司后勤管理工作总结
2015/05/13 职场文书
js实现上传图片到服务器
2021/04/11 Javascript
Redis实现订单自动过期功能的示例代码
2021/05/08 Redis