Gesture Recognition with TensorFlow 2.X and OpenCV


Posted in Python on April 08, 2020

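This post builds a convolutional neural network with TensorFlow, trains a gesture recognition model, and then loads the model with the OpenCV DNN module to recognize gestures in real time.
The result looks like this: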

[Figure: real-time gesture recognition demo]

First, a few images from the dataset (the hand signs for the digits 0 to 9, some of which look rather odd):

[Figure: sample images from the dataset]
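read_data in the training code below derives each label from the image's folder: the label is the index of the parent folder's name in the sorted list of folder names. The sketch below isolates that logic. The one-folder-per-digit layout is an assumption (the post does not describe the directory structure), and build_label_index and label_for are hypothetical helper names. Because sorting is lexicographic, the mapping only matches numeric order while every class name is a single digit (a folder named "10" would sort before "2").

```python
import pathlib
import tempfile


def build_label_index(dataset_root):
    """Map each class subfolder name to an integer label, as read_data does."""
    root = pathlib.Path(dataset_root)
    # sorted() gives a deterministic, lexicographic order of the folder names.
    label_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    return {name: index for index, name in enumerate(label_names)}


def label_for(image_path, label_index):
    # An image's class is the name of the folder that contains it.
    return label_index[pathlib.Path(image_path).parent.name]


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical layout: Dataset/0/..., Dataset/1/..., ..., Dataset/9/...
    for digit in range(10):
        (pathlib.Path(tmp) / str(digit)).mkdir()
    index = build_label_index(tmp)
    print(label_for(pathlib.Path(tmp) / "7" / "img_001.jpg", index))  # → 7
```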

Building and training the model

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'  # set before importing TensorFlow so it also applies to import-time logs
import pathlib
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import matplotlib.pyplot as plt
def read_data(path):
 path_root = pathlib.Path(path)
 # print(path_root)
 # for item in path_root.iterdir():
 #  print(item)
 image_paths = list(path_root.glob('*/*'))
 image_paths = [str(path) for path in image_paths]
 random.shuffle(image_paths)
 image_count = len(image_paths)
 # print(image_count)
 # print(image_paths[:10])
 label_names = sorted(item.name for item in path_root.glob('*/') if item.is_dir())
 # print(label_names)
 label_name_index = dict((name, index) for index, name in enumerate(label_names))
 # print(label_name_index)
 image_labels = [label_name_index[pathlib.Path(path).parent.name] for path in image_paths]
 # print("First 10 labels indices: ", image_labels[:10])
 return image_paths,image_labels,image_count
def preprocess_image(image):
 image = tf.image.decode_jpeg(image, channels=3)
 image = tf.image.resize(image, [100, 100])
 image /= 255.0 # normalize to [0,1] range
 # image = tf.reshape(image,[100*100*3])
 return image
def load_and_preprocess_image(path,label):
 image = tf.io.read_file(path)
 return preprocess_image(image),label
def creat_dataset(image_paths,image_labels,batch_size):
 db = tf.data.Dataset.from_tensor_slices((image_paths, image_labels))
 dataset = db.map(load_and_preprocess_image).batch(batch_size)
 return dataset
def train_model(train_data,test_data):
 # Build the model
 network = keras.Sequential([
   keras.layers.Conv2D(32,kernel_size=[5,5],padding="same",activation=tf.nn.relu),
   keras.layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),
   keras.layers.Conv2D(64,kernel_size=[3,3],padding="same",activation=tf.nn.relu),
   keras.layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),
   keras.layers.Conv2D(64,kernel_size=[3,3],padding="same",activation=tf.nn.relu),
   keras.layers.Flatten(),
   keras.layers.Dense(512,activation='relu'),
   keras.layers.Dropout(0.5),
   keras.layers.Dense(128,activation='relu'),
   keras.layers.Dense(10)])
 network.build(input_shape=(None,100,100,3))
 network.summary()
 network.compile(optimizer=optimizers.SGD(learning_rate=0.001),  # 'lr' is a deprecated alias of learning_rate
   loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
   metrics=['accuracy']
 )
 # Train the model
 network.fit(train_data, epochs = 100,validation_data=test_data,validation_freq=2) 
 network.evaluate(test_data)
 tf.saved_model.save(network,'D:\\code\\PYTHON\\gesture_recognition\\model\\')
 print("Model saved successfully")
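 # Convert Keras model to ConcreteFunction
 # The input spec (batch, 100, 100, 3, float32) is written out to match preprocess_image,
 # since network.inputs may be unset for a Sequential model built without an Input layer
 full_model = tf.function(lambda x: network(x))
 full_model = full_model.get_concrete_function(
  tf.TensorSpec([None, 100, 100, 3], tf.float32))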
 # Get frozen ConcreteFunction
 frozen_func = convert_variables_to_constants_v2(full_model)
 frozen_func.graph.as_graph_def()

 # op_names avoids shadowing the imported keras 'layers' module
 op_names = [op.name for op in frozen_func.graph.get_operations()]
 print("-" * 50)
 print("Frozen model layers: ")
 for op_name in op_names:
  print(op_name)

 print("-" * 50)
 print("Frozen model inputs: ")
 print(frozen_func.inputs)
 print("Frozen model outputs: ")
 print(frozen_func.outputs)

 # Save frozen graph from frozen ConcreteFunction to hard drive
 tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
   logdir="D:\\code\\PYTHON\\gesture_recognition\\model\\frozen_model\\",
   name="frozen_graph.pb",
   as_text=False)
 print("Model conversion finished, training complete")


if __name__ == "__main__":
 print(tf.__version__)
 train_path = 'D:\\code\\PYTHON\\gesture_recognition\\Dataset'
 test_path = 'D:\\code\\PYTHON\\gesture_recognition\\testdata' 
 image_paths,image_labels,_ = read_data(train_path)
 train_data = creat_dataset(image_paths,image_labels,16)
 image_paths,image_labels,_ = read_data(test_path)
 test_data = creat_dataset(image_paths,image_labels,16)
 train_model(train_data,test_data)
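network.summary() prints layer shapes and parameter counts when the script runs. As a hand check of what it should report, the sketch below redoes that arithmetic for this architecture: 'same' convolutions keep the spatial size, and each 2×2, stride-2 pooling with 'same' padding halves it, rounding up. It shows that the first Dense(512) layer holds roughly 99% of the network's ~20.6 million weights, which is worth keeping in mind given the small dataset. The helper names are illustrative only.

```python
import math


def same_pool(size, stride=2):
    # 'same' padding: output size = ceil(input / stride)
    return math.ceil(size / stride)


def conv_params(kernel, in_ch, out_ch):
    # kernel*kernel*in_ch*out_ch weights plus one bias per output channel
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(in_units, out_units):
    return in_units * out_units + out_units


side = 100
params = {}
params["conv1"] = conv_params(5, 3, 32)    # output 100x100x32
side = same_pool(side)                      # 50
params["conv2"] = conv_params(3, 32, 64)   # output 50x50x64
side = same_pool(side)                      # 25
params["conv3"] = conv_params(3, 64, 64)   # output 25x25x64
flat = side * side * 64                     # Flatten
params["dense512"] = dense_params(flat, 512)
params["dense128"] = dense_params(512, 128)
params["logits"] = dense_params(128, 10)
total = sum(params.values())
print(flat, params["dense512"], total)  # → 40000 20480512 20605322
```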

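Loading the model with OpenCV for real-time detection

To keep detection simple, a fixed region of interest (ROI) is used here: only the area inside the green box is classified.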
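In the detection loop below, the ROI is the box from (300, 100) to (600, 400) in (x, y) pixel coordinates, which in NumPy indexing is frame[100:400, 300:600] (rows first). The frame therefore has to be at least 600 pixels wide and 400 tall; a 640×480 webcam frame is assumed here. The sketch below mimics the crop with NumPy, together with what cv2.dnn.blobFromImage produces for this call (mean 0, swapRB=False, image already 100×100): pixels scaled by 1/255, the same normalization as image /= 255.0 during training, then reordered from HWC to an NCHW batch. crop_roi and to_blob are hypothetical names, and the [::3, ::3] subsampling only stands in for cv2.resize.

```python
import numpy as np

ROI_TOP, ROI_BOTTOM, ROI_LEFT, ROI_RIGHT = 100, 400, 300, 600  # same box as the loop


def crop_roi(frame):
    # Rows are y and columns are x, so the box (300,100)-(600,400) is frame[100:400, 300:600].
    return frame[ROI_TOP:ROI_BOTTOM, ROI_LEFT:ROI_RIGHT]


def to_blob(rgb_image, scale=1.0 / 255.0):
    """Approximates blobFromImage with mean=0, swapRB=False and no resizing:
    scale the pixels, then convert HWC to a 1-image NCHW batch."""
    chw = rgb_image.astype(np.float32).transpose(2, 0, 1) * scale
    return chw[np.newaxis, ...]


frame = np.full((480, 640, 3), 255, dtype=np.uint8)  # stand-in for a 640x480 camera frame
roi = crop_roi(frame)
print(roi.shape)  # (300, 300, 3)
small = roi[::3, ::3]  # crude stand-in for cv2.resize(roi, (100, 100))
blob = to_blob(small)
print(blob.shape)  # (1, 3, 100, 100)
```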

import cv2
from cv2 import dnn
import numpy as np
print(cv2.__version__)
class_name = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
net = dnn.readNetFromTensorflow('D:\\code\\PYTHON\\gesture_recognition\\model\\frozen_model\\frozen_graph.pb')
cap = cv2.VideoCapture(0)
while True:
 ret, frame = cap.read()
 if not ret:  # no frame from the camera
  break
 # Crop the ROI before drawing the green box, so the box outline is not part of the model input
 roi = frame[100:400, 300:600].copy()
 cv2.imshow("pic1", roi)
 pic = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)  # the training images were decoded as RGB
 pic = cv2.resize(pic, (100, 100))
 blob = cv2.dnn.blobFromImage(pic,
        scalefactor=1.0/255.,  # same normalization as image /= 255.0 during training
        size=(100, 100),
        mean=(0, 0, 0),
        swapRB=False,  # already RGB, so no extra channel swap
        crop=False)
 # blob = np.transpose(blob, (0,2,3,1))
 net.setInput(blob)
 out = net.forward()
 out = out.flatten()  # 10 raw logits (the last Dense layer has no softmax)

 classId = np.argmax(out)
 print("Prediction:", class_name[classId])
 cv2.rectangle(frame, (300, 100), (600, 400), (0, 255, 0), 1, 4)
 cv2.putText(frame, class_name[classId], (300, 100), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2, 4)
 # cv.putText(img, text, org, fontFace, fontScale, fontcolor, thickness, lineType)
 cv2.imshow("pic", frame)
 if cv2.waitKey(10) == ord('0'):  # press '0' to quit
  break
cap.release()
cv2.destroyAllWindows()
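The final Dense(10) layer has no activation (the model was trained with from_logits=True), so net.forward() returns raw logits. argmax over logits picks the same class as argmax over probabilities, so the prediction itself is fine, but logits alone do not say how confident the model is. A minimal sketch, assuming you want to suppress the label when the ROI holds no clear gesture, applies a softmax and a threshold; decode_prediction and the 0.6 threshold are hypothetical choices, not part of the original post.

```python
import numpy as np

CLASS_NAMES = [str(d) for d in range(10)]


def softmax(logits):
    # Subtracting the max first avoids overflow and does not change the result.
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / exp.sum()


def decode_prediction(logits, min_confidence=0.6):
    """Return (label, confidence); label is None when the top probability
    is below the hypothetical min_confidence threshold."""
    probs = softmax(np.asarray(logits, dtype=np.float64).ravel())
    class_id = int(np.argmax(probs))
    confidence = float(probs[class_id])
    if confidence < min_confidence:
        return None, confidence
    return CLASS_NAMES[class_id], confidence


# In the loop this would be called as: label, conf = decode_prediction(net.forward())
print(decode_prediction([0.1, 5.0, 0.2, 0, 0, 0, 0, 0, 0, 0]))
```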

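Summary

At its core this is still an image classification task, and the number of samples is small. To improve the model, add data augmentation and take measures against overfitting.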
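As a starting point for the data augmentation mentioned above, here is a NumPy sketch of two common augmentations: a random horizontal flip and a random brightness shift, clipped back to the [0, 1] range that preprocess_image produces. augment and max_brightness_delta are hypothetical names. Whether flipping preserves the label is an assumption: mirroring turns a right-hand sign into a left-hand one, so drop the flip if the dataset distinguishes hands. Inside the tf.data pipeline the same idea could use tf.image.random_flip_left_right and tf.image.random_brightness in a map applied to the training set only.

```python
import numpy as np


def augment(image, rng, max_brightness_delta=0.1):
    """Random horizontal flip plus brightness jitter for one (H, W, 3) image in [0, 1]."""
    if rng.random() < 0.5:
        image = image[:, ::-1, :]  # mirror left to right
    delta = rng.uniform(-max_brightness_delta, max_brightness_delta)
    # Keep values in the [0, 1] range that preprocess_image produces.
    return np.clip(image + delta, 0.0, 1.0)


rng = np.random.default_rng(seed=0)
image = np.full((100, 100, 3), 0.5, dtype=np.float32)  # stand-in for a preprocessed sample
augmented = augment(image, rng)
print(augmented.shape)  # (100, 100, 3)
```

Apply it to training samples only; the test set used for validation should stay unaugmented so the reported accuracy remains comparable.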
