pytorch cnn 识别手写的字实现自建图片数据


Posted in Python onMay 20, 2018

本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下:

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    fh.close()
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    img = Image.fromarray(np.array(img), mode='L')
    if self.transform is not None:
      img = self.transform(img)
    return img,label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input height
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (x, y) in enumerate(train_loader):  # gives batch data, normalize x when iterate train_loader
    b_x = Variable(x)  # batch x
    b_y = Variable(y)  # batch y
 
    output = cnn(b_x)[0]        # cnn output
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
 
    if step % 50 == 0:
      cnn.eval()
      eval_loss = 0.
      eval_acc = 0.
      for i, (tx, ty) in enumerate(test_loader):
        t_x = Variable(tx)
        t_y = Variable(ty)
        output = cnn(t_x)[0]
        loss = loss_func(output, t_y)
        eval_loss += loss.data[0]
        pred = torch.max(output, 1)[1]
        num_correct = (pred == t_y).sum()
        eval_acc += float(num_correct.data[0])
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))

图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》

结果如下:

pytorch cnn 识别手写的字实现自建图片数据

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python遍历文件夹和读写文件的实现方法
May 10 Python
python机器学习理论与实战(六)支持向量机
Jan 19 Python
python开启摄像头以及深度学习实现目标检测方法
Aug 03 Python
Python3.5 处理文本txt,删除不需要的行方法
Dec 10 Python
python打包exe开机自动启动的实例(windows)
Jun 28 Python
kafka-python 获取topic lag值方式
Dec 23 Python
python matplotlib中的subplot函数使用详解
Jan 19 Python
Python图像处理库PIL的ImageGrab模块介绍详解
Feb 26 Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 Python
python+opencv实现车道线检测
Feb 19 Python
Python爬虫制作翻译程序的示例代码
Feb 22 Python
对Pytorch 中的contiguous理解说明
Mar 03 Python
pytorch 把MNIST数据集转换成图片和txt的方法
May 20 #Python
Python安装lz4-0.10.1遇到的坑
May 20 #Python
Python requests发送post请求的一些疑点
May 20 #Python
python中virtualenvwrapper安装与使用
May 20 #Python
django静态文件加载的方法
May 20 #Python
django中静态文件配置static的方法
May 20 #Python
Python中跳台阶、变态跳台阶与矩形覆盖问题的解决方法
May 19 #Python
You might like
在Windows中安装Apache2和PHP4的权威指南
2006/10/09 PHP
PHP+FLASH实现上传文件进度条相关文件 下载
2007/07/21 PHP
PHP获取windows登录用户名的方法
2014/06/24 PHP
PHP制作图形验证码代码分享
2014/10/23 PHP
Yii2框架视图(View)操作及Layout的使用方法分析
2019/05/27 PHP
PHP Primary script unknown 解决方法总结
2019/08/22 PHP
javascript 全角转换实现代码
2009/07/17 Javascript
jquery 获取标签名(tagName)示例代码
2013/07/11 Javascript
调试代码导致IE出错的避免方法
2014/04/04 Javascript
javascript中实现兼容JAVA的hashCode算法代码分享
2020/08/11 Javascript
jQuery插件开发的五种形态小结
2015/03/04 Javascript
js给selected添加options的方法
2015/05/06 Javascript
Three.js基础部分学习
2017/01/08 Javascript
jQuery动态生成表格及右键菜单功能示例
2017/01/13 Javascript
利用node.js搭建简单web服务器的方法教程
2017/02/20 Javascript
快速了解vue-cli 3.0 新特性
2018/02/28 Javascript
vue文件树组件使用详解
2018/03/29 Javascript
通过jquery获取上传文件名称、类型和大小的实现代码
2018/04/19 jQuery
AngularJS模态框模板ngDialog的使用详解
2018/05/11 Javascript
Postman动态获取返回值过程详解
2020/06/30 Javascript
[05:35]DOTA2英雄梦之声_第13期_拉比克
2014/06/21 DOTA
python计算圆周率pi的方法
2015/07/11 Python
PyCharm使用教程之搭建Python开发环境
2016/06/07 Python
Django 生成登陆验证码代码分享
2017/12/12 Python
使用python 的matplotlib 画轨道实例
2020/01/19 Python
浅谈HTML5中dialog元素尝鲜
2018/10/15 HTML / CSS
国际领先的在线时尚服装和配饰店:DressLily
2019/03/03 全球购物
2014端午节活动策划方案
2014/01/27 职场文书
2013年军训通讯稿
2014/02/05 职场文书
一位农村小子的自荐信
2014/04/07 职场文书
政府采购方案
2014/06/12 职场文书
小学安全汇报材料
2014/08/14 职场文书
大班下学期个人总结
2015/02/13 职场文书
Python自动化之批量处理工作簿和工作表
2021/06/03 Python
nginx的zabbix 5.0安装部署的方法步骤
2021/07/16 Servers
win11电脑关机鼠标灯还亮怎么解决? win11关机后鼠标灯还亮解决方法
2023/01/09 数码科技