TensorFlow2.X结合OpenCV 实现手势识别功能


Posted in Python onApril 08, 2020

使用Tensorflow 构建卷积神经网络,训练手势识别模型,使用opencv DNN 模块加载模型实时手势识别
效果如下:

TensorFlow2.X结合OpenCV 实现手势识别功能

先显示下部分数据集图片(0到9的表示,感觉很怪)

TensorFlow2.X结合OpenCV 实现手势识别功能

构建模型进行训练

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import os 
import pathlib
import random
import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
def read_data(path):
 path_root = pathlib.Path(path)
 # print(path_root)
 # for item in path_root.iterdir():
 #  print(item)
 image_paths = list(path_root.glob('*/*'))
 image_paths = [str(path) for path in image_paths]
 random.shuffle(image_paths)
 image_count = len(image_paths)
 # print(image_count)
 # print(image_paths[:10])
 label_names = sorted(item.name for item in path_root.glob('*/') if item.is_dir())
 # print(label_names)
 label_name_index = dict((name, index) for index, name in enumerate(label_names))
 # print(label_name_index)
 image_labels = [label_name_index[pathlib.Path(path).parent.name] for path in image_paths]
 # print("First 10 labels indices: ", image_labels[:10])
 return image_paths,image_labels,image_count
def preprocess_image(image):
 image = tf.image.decode_jpeg(image, channels=3)
 image = tf.image.resize(image, [100, 100])
 image /= 255.0 # normalize to [0,1] range
 # image = tf.reshape(image,[100*100*3])
 return image
def load_and_preprocess_image(path,label):
 image = tf.io.read_file(path)
 return preprocess_image(image),label
def creat_dataset(image_paths,image_labels,bitch_size):
 db = tf.data.Dataset.from_tensor_slices((image_paths, image_labels))
 dataset = db.map(load_and_preprocess_image).batch(bitch_size) 
 return dataset
def train_model(train_data,test_data):
 #构建模型
 network = keras.Sequential([
   keras.layers.Conv2D(32,kernel_size=[5,5],padding="same",activation=tf.nn.relu),
   keras.layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),
   keras.layers.Conv2D(64,kernel_size=[3,3],padding="same",activation=tf.nn.relu),
   keras.layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),
   keras.layers.Conv2D(64,kernel_size=[3,3],padding="same",activation=tf.nn.relu),
   keras.layers.Flatten(),
   keras.layers.Dense(512,activation='relu'),
   keras.layers.Dropout(0.5),
   keras.layers.Dense(128,activation='relu'),
   keras.layers.Dense(10)])
 network.build(input_shape=(None,100,100,3))
 network.summary()
 network.compile(optimizer=optimizers.SGD(lr=0.001),
   loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
   metrics=['accuracy']
 )
 #模型训练
 network.fit(train_data, epochs = 100,validation_data=test_data,validation_freq=2) 
 network.evaluate(test_data)
 tf.saved_model.save(network,'D:\\code\\PYTHON\\gesture_recognition\\model\\')
 print("保存模型成功")
 # Convert Keras model to ConcreteFunction
 full_model = tf.function(lambda x: network(x))
 full_model = full_model.get_concrete_function(
 tf.TensorSpec(network.inputs[0].shape, network.inputs[0].dtype))
 # Get frozen ConcreteFunction
 frozen_func = convert_variables_to_constants_v2(full_model)
 frozen_func.graph.as_graph_def()

 layers = [op.name for op in frozen_func.graph.get_operations()]
 print("-" * 50)
 print("Frozen model layers: ")
 for layer in layers:
  print(layer)

 print("-" * 50)
 print("Frozen model inputs: ")
 print(frozen_func.inputs)
 print("Frozen model outputs: ")
 print(frozen_func.outputs)

 # Save frozen graph from frozen ConcreteFunction to hard drive
 tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
   logdir="D:\\code\\PYTHON\\gesture_recognition\\model\\frozen_model\\",
   name="frozen_graph.pb",
   as_text=False)
 print("模型转换完成,训练结束")


if __name__ == "__main__":
 print(tf.__version__)
 train_path = 'D:\\code\\PYTHON\\gesture_recognition\\Dataset'
 test_path = 'D:\\code\\PYTHON\\gesture_recognition\\testdata' 
 image_paths,image_labels,_ = read_data(train_path)
 train_data = creat_dataset(image_paths,image_labels,16)
 image_paths,image_labels,_ = read_data(test_path)
 test_data = creat_dataset(image_paths,image_labels,16)
 train_model(train_data,test_data)

OpenCV加载模型,实时检测

这里为了简化检测使用了ROI。

import cv2
from cv2 import dnn
import numpy as np
print(cv2.__version__)
class_name = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
net = dnn.readNetFromTensorflow('D:\\code\\PYTHON\\gesture_recognition\\model\\frozen_model\\frozen_graph.pb')
cap = cv2.VideoCapture(0)
i = 0
while True:
 _,frame= cap.read() 
 src_image = frame
 cv2.rectangle(src_image, (300, 100),(600, 400), (0, 255, 0), 1, 4)
 frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
 pic = frame[100:400,300:600]
 cv2.imshow("pic1", pic)
 # print(pic.shape)
 pic = cv2.resize(pic,(100,100))
 blob = cv2.dnn.blobFromImage(pic,  
        scalefactor=1.0/225.,
        size=(100, 100),
        mean=(0, 0, 0),
        swapRB=False,
        crop=False)
 # blob = np.transpose(blob, (0,2,3,1))       
 net.setInput(blob)
 out = net.forward()
 out = out.flatten()

 classId = np.argmax(out)
 # print("classId",classId)
 print("预测结果为:",class_name[classId])
 src_image = cv2.putText(src_image,str(classId),(300,100), cv2.FONT_HERSHEY_SIMPLEX, 2,(0,0,255),2,4)
 # cv.putText(img, text, org, fontFace, fontScale, fontcolor, thickness, lineType)
 cv2.imshow("pic",src_image)
 if cv2.waitKey(10) == ord('0'):
  break

小结

这里本质上还是一个图像分类任务。而且,样本数量较少。优化的时候需要做数据增强,还需要防止过拟合。

到此这篇关于TensorFlow2.X结合OpenCV 实现手势识别功能的文章就介绍到这了,更多相关TensorFlow OpenCV 手势识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python 字符串操作方法大全
Mar 11 Python
Python发送email的3种方法
Apr 28 Python
python比较两个列表是否相等的方法
Jul 28 Python
python去除文件中空格、Tab及回车的方法
Apr 12 Python
老生常谈进程线程协程那些事儿
Jul 24 Python
Python 使用PIL numpy 实现拼接图片的示例
May 08 Python
Python操作MySQL数据库的方法
Jun 20 Python
Python内存管理实例分析
Jul 10 Python
python中利用matplotlib读取灰度图的例子
Dec 07 Python
Python实现遗传算法(二进制编码)求函数最优值方式
Feb 11 Python
Python字符串及文本模式方法详解
Sep 10 Python
Python-typing: 类型标注与支持 Any类型详解
May 10 Python
python 安装库几种方法之cmd,anaconda,pycharm详解
Apr 08 #Python
TensorFlow2.1.0最新版本安装详细教程
Apr 08 #Python
解决python多线程报错:AttributeError: Can't pickle local object问题
Apr 08 #Python
解决Python 异常TypeError: cannot concatenate 'str' and 'int' objects
Apr 08 #Python
TensorFlow2.1.0安装过程中setuptools、wrapt等相关错误指南
Apr 08 #Python
解决windows下python3使用multiprocessing.Pool出现的问题
Apr 08 #Python
python操作yaml说明
Apr 08 #Python
You might like
Terran兵种介绍
2020/03/14 星际争霸
php抓取页面与代码解析 推荐
2010/07/23 PHP
PHP弹出提示框并跳转到新页面即重定向到新页面
2014/01/24 PHP
ThinkPHP控制器详解
2015/07/27 PHP
通过PHP的Wrapper无缝迁移原有项目到新服务的实现方法
2020/04/02 PHP
中文字符串截取的js函数代码
2013/04/17 Javascript
JavaScript中读取和保存文件实例
2014/05/08 Javascript
jQuery的ready方法详解
2014/11/27 Javascript
js仿淘宝和百度文库的评分功能
2016/05/15 Javascript
jQuery实现摸拟alert提示框
2016/05/22 Javascript
Angular表格神器ui-grid应用详解
2017/09/29 Javascript
JS实现的抛物线运动效果示例
2018/01/30 Javascript
jquery实现进度条状态展示
2020/03/26 jQuery
python实现可将字符转换成大写的tcp服务器实例
2015/04/29 Python
深入解析Python中的list列表及其切片和迭代操作
2016/03/13 Python
Python科学计算之NumPy入门教程
2017/01/15 Python
Python单例模式实例详解
2017/03/01 Python
python实现朴素贝叶斯分类器
2018/03/28 Python
python写入并获取剪切板内容的实例
2018/05/31 Python
django 中QuerySet特性功能详解
2019/07/25 Python
学Python 3的理由和必要性
2019/11/19 Python
Django自带的用户验证系统实现
2020/12/18 Python
Shopee马来西亚:随拍即卖,最佳行动电商拍卖平台
2017/06/05 全球购物
StubHub意大利:购买和出售全球演唱会和体育赛事门票
2017/11/21 全球购物
澳大利亚最受欢迎的美发和美容在线商店:Catwalk
2018/12/12 全球购物
Subside Sports德国:足球球衣和球迷商品
2019/06/08 全球购物
英国床垫和床架购物网站:Bedman
2019/11/04 全球购物
应届生英语教师求职信
2013/11/05 职场文书
保密工作责任书
2014/04/16 职场文书
询价采购方案
2014/06/09 职场文书
国际商务专业毕业生自我鉴定2014
2014/09/27 职场文书
房屋所有权证明
2014/10/20 职场文书
2014年办公室文秘工作总结
2014/12/09 职场文书
2016年六一儿童节开幕词
2016/03/04 职场文书
2016年“我们的节日·重阳节”主题活动总结
2016/04/01 职场文书
2019最新版火锅店的创业计划书 !
2019/07/12 职场文书