PyTorch CNN: recognizing handwritten digits with a self-built image dataset


Posted in Python on May 20, 2018

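This article shows how to train a PyTorch CNN to recognize handwritten digits from a self-built image dataset, where the image files and their labels are listed in txt files. The code is as follows: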

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # number of passes over the training data; only 1 epoch here to save time
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    # each line of the txt file is "<image path> <integer label>"
    imgs = []
    with open(txt, 'r') as fh:   # file is closed automatically, even on errors
      for line in fh:
        words = line.split()      # split() also drops the trailing newline
        if not words:             # skip blank lines
          continue
        imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn).convert('L')   # force single-channel grayscale ("L" mode)
    if self.transform is not None:
      img = self.transform(img)
    if self.target_transform is not None:
      label = self.target_transform(label)
    return img, label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input channels (1 for grayscale images)
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # to keep the same width and height after conv2d, use padding=(kernel_size-1)/2 when stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (b_x, b_y) in enumerate(train_loader):  # batch data; ToTensor() has already scaled pixels to [0, 1]
    output = cnn(b_x)[0]           # cnn output (class logits)
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()          # clear gradients for this training step
    loss.backward()                # backpropagation, compute gradients
    optimizer.step()               # apply gradients
 
    if step % 50 == 0:             # evaluate on the whole test set every 50 steps
      cnn.eval()
      eval_loss = 0.
      eval_acc = 0.
      with torch.no_grad():        # no gradients are needed for evaluation
        for t_x, t_y in test_loader:
          output = cnn(t_x)[0]
          loss = loss_func(output, t_y)
          eval_loss += loss.item() * t_x.size(0)  # loss is a batch mean, so weight it by the batch size
          pred = torch.max(output, 1)[1]          # index of the largest logit = predicted digit
          eval_acc += (pred == t_y).sum().item()
      cnn.train()                  # switch back to training mode
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / len(test_data), acc_rate))
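The shape comments in the CNN class (28 → 28 → 14 after conv1, 14 → 14 → 7 after conv2, and the 32 * 7 * 7 inputs of self.out) follow from the output-size formula that the PyTorch documentation gives for nn.Conv2d and nn.MaxPool2d. The following is a minimal pure-Python sketch that reproduces that chain. The helper names conv2d_out and cnn_feature_size are hypothetical and are not part of PyTorch or of the code above.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output size along one spatial dimension for nn.Conv2d, and for
    nn.MaxPool2d with its default ceil_mode=False."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def cnn_feature_size(size=28):
    # conv1: Conv2d(1, 16, kernel_size=5, stride=1, padding=2) keeps 28x28
    size = conv2d_out(size, kernel_size=5, stride=1, padding=2)
    # MaxPool2d(kernel_size=2): stride defaults to kernel_size, so 28 -> 14
    size = conv2d_out(size, kernel_size=2, stride=2)
    # conv2: Conv2d(16, 32, 5, 1, 2) keeps 14x14
    size = conv2d_out(size, kernel_size=5, stride=1, padding=2)
    # MaxPool2d(2): 14 -> 7
    size = conv2d_out(size, kernel_size=2, stride=2)
    return size

side = cnn_feature_size(28)
flat = 32 * side * side  # conv2 outputs 32 channels, flattened by x.view(x.size(0), -1)
print(side, flat)  # 7 1568
```

The same formula explains the padding comment. With stride=1, padding=(kernel_size-1)/2 makes the subtracted and added terms cancel, so a 5x5 kernel needs padding=2.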

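For how the images and label txt files were produced, see the previous article, "PyTorch: converting the MNIST dataset to images and txt files".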
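MyDataset reads each line of train.txt and test.txt as an image path followed by an integer label, separated by whitespace. The sketch below writes a small file in that layout and parses it the same way. The file names in it are hypothetical placeholders, not necessarily the names produced by the conversion step. read_label_file is also an illustrative helper, not part of the post's code.

```python
import os
import tempfile

def read_label_file(txt_path):
    """Parse '<image path> <integer label>' lines, the layout MyDataset expects."""
    samples = []
    with open(txt_path, 'r') as fh:
        for line_no, line in enumerate(fh, start=1):
            words = line.split()
            if not words:  # tolerate blank lines
                continue
            if len(words) != 2:
                raise ValueError('line {}: expected "<path> <label>", got {!r}'.format(line_no, line))
            samples.append((words[0], int(words[1])))
    return samples

# Hypothetical contents; real paths and labels come from the conversion step.
with tempfile.TemporaryDirectory() as tmp:
    txt = os.path.join(tmp, 'train.txt')
    with open(txt, 'w') as fh:
        fh.write('./mnist/raw/train/0.png 5\n')
        fh.write('./mnist/raw/train/1.png 0\n')
        fh.write('\n')
    samples = read_label_file(txt)

print(samples)
```

Because paths are split on whitespace, this layout assumes that image paths contain no spaces.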

The results are as follows:

[Figure: results screenshot not preserved]
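The printed Test Loss is a per-sample average only if each batch's loss is re-weighted by its batch size. nn.CrossEntropyLoss averages over the batch by default, and the last test batch may be smaller than the others. The printed Acc is the fraction of test samples whose largest logit matches the label. The following dependency-free sketch shows the same bookkeeping on toy logits rather than real model output. cross_entropy and evaluate are hypothetical helper names.

```python
import math

def cross_entropy(logits, target):
    """Per-sample cross-entropy, -log(softmax(logits)[target]), with an integer
    class index as the target (not one-hot), as nn.CrossEntropyLoss expects."""
    m = max(logits)  # subtract the max logit for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def evaluate(batches):
    """batches: iterable of (logits_batch, targets) pairs.
    Returns (mean loss per sample, accuracy)."""
    total_loss, correct, n = 0.0, 0, 0
    for logits_batch, targets in batches:
        losses = [cross_entropy(row, t) for row, t in zip(logits_batch, targets)]
        batch_mean = sum(losses) / len(losses)   # default 'mean' reduction
        total_loss += batch_mean * len(targets)  # re-weight by batch size before summing
        # argmax over classes, the role torch.max(output, 1)[1] plays in the training loop
        preds = [max(range(len(row)), key=row.__getitem__) for row in logits_batch]
        correct += sum(p == t for p, t in zip(preds, targets))
        n += len(targets)
    return total_loss / n, correct / n

# Toy data: 3 classes, a batch of 2 followed by a smaller last batch of 1.
batches = [
    ([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]], [0, 2]),
    ([[0.0, 0.0, 1.0]], [2]),
]
mean_loss, acc = evaluate(batches)
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(mean_loss, acc))
```

Two of the three toy predictions are correct, so the accuracy here is 2/3. Summing the two batch means and dividing by 3, without re-weighting, would give a different loss value.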
