PyTorch CNN: recognizing handwritten digits from a self-built image dataset


Posted in Python on May 20, 2018

This article shows how to use a PyTorch CNN to recognize handwritten digits from a self-built image dataset. The details are as follows:

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    imgs = []
    with open(txt, 'r') as fh:     # each line: "<image path> <integer label>"
      for line in fh:
        words = line.split()       # split() also drops the trailing newline
        if not words:              # skip blank lines
          continue
        imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    img = img.convert('L')         # make sure the image is single-channel grayscale
    if self.transform is not None:
      img = self.transform(img)
    if self.target_transform is not None:
      label = self.target_transform(label)
    return img, label

  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # number of input channels (1 for grayscale)
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # to keep the same width and height after Conv2d with stride=1, use padding=(kernel_size-1)/2
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (b_x, b_y) in enumerate(train_loader):  # ToTensor() has already scaled pixels to [0, 1]
    output = cnn(b_x)[0]           # cnn output (logits)
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()          # clear gradients for this training step
    loss.backward()                # backpropagation, compute gradients
    optimizer.step()               # apply gradients

    if step % 50 == 0:
      cnn.eval()                   # switch to evaluation mode
      eval_loss = 0.
      eval_acc = 0.
      with torch.no_grad():        # no gradients are needed for testing
        for t_x, t_y in test_loader:
          output = cnn(t_x)[0]
          loss = loss_func(output, t_y)
          eval_loss += loss.item() * t_x.size(0)   # loss is a per-batch mean, so weight it by batch size
          pred = torch.max(output, 1)[1]            # index of the largest logit = predicted class
          eval_acc += (pred == t_y).sum().item()
      cnn.train()                  # switch back to training mode
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / len(test_data), acc_rate))
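
The shape comments in the CNN class (28 to 14 to 7, and 32 * 7 * 7 inputs to the linear layer) follow from the usual output-size formula for convolution and pooling with dilation 1: out = floor((in + 2 * padding - kernel_size) / stride) + 1. Note that nn.MaxPool2d uses stride = kernel_size by default. The stdlib-only sketch below traces that arithmetic; the helper names conv_output_size and same_padding are hypothetical, not part of PyTorch.

```python
def conv_output_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output size along one spatial dimension for Conv2d/MaxPool2d (dilation=1, floor mode)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def same_padding(kernel_size: int) -> int:
    """Padding that keeps the size unchanged when stride=1 and kernel_size is odd."""
    return (kernel_size - 1) // 2


# Trace a 28x28 MNIST image through the CNN defined above.
h = 28
h = conv_output_size(h, kernel_size=5, stride=1, padding=same_padding(5))  # conv1 -> 28
h = conv_output_size(h, kernel_size=2, stride=2)                            # pool1 -> 14
h = conv_output_size(h, kernel_size=5, stride=1, padding=2)                 # conv2 -> 14
h = conv_output_size(h, kernel_size=2, stride=2)                            # pool2 -> 7
print(h, 32 * h * h)  # prints "7 1568", the in_features of self.out
```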

For how the images and labels were produced, see the previous article, "Converting the MNIST dataset into images and txt files with PyTorch".
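
MyDataset expects train.txt and test.txt to contain one sample per line: an image path and an integer label separated by whitespace. The sketch below writes such an index and reads it back with the same split-on-whitespace parsing as MyDataset.__init__; blank lines are skipped. The file paths are hypothetical placeholders, since the exact layout comes from the previous article. Because lines are split on whitespace, image paths must not contain spaces.

```python
import os
import tempfile


def write_index(entries, txt_path):
    """Write (image_path, label) pairs as 'path label' lines."""
    with open(txt_path, "w") as fh:
        for path, label in entries:
            fh.write(f"{path} {label}\n")


def read_index(txt_path):
    """Parse a 'path label' index file into a list of (path, int) tuples."""
    imgs = []
    with open(txt_path) as fh:
        for line in fh:
            words = line.split()
            if not words:  # skip blank lines
                continue
            imgs.append((words[0], int(words[1])))
    return imgs


# Hypothetical entries; real paths come from the MNIST-to-images conversion.
entries = [("./mnist/raw/train/0.png", 5), ("./mnist/raw/train/1.png", 0)]
with tempfile.TemporaryDirectory() as tmp:
    txt = os.path.join(tmp, "train.txt")
    write_index(entries, txt)
    print(read_index(txt))
    # [('./mnist/raw/train/0.png', 5), ('./mnist/raw/train/1.png', 0)]
```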

The results are as follows:

[Figure: screenshot of the results]
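
The Test Loss and Acc values should be averages over the whole test set. nn.CrossEntropyLoss uses reduction='mean' by default, so each batch loss is already averaged over that batch; summing those batch means and dividing by the number of test samples understates the loss by roughly a factor of BATCH_SIZE. The stdlib sketch below shows the weighted aggregation; the per-batch numbers are made up for illustration, and aggregate is a hypothetical helper.

```python
def aggregate(batches):
    """batches: iterable of (mean_loss, num_correct, batch_size) for each test batch."""
    total_loss = 0.0
    total_correct = 0
    total = 0
    for mean_loss, num_correct, batch_size in batches:
        total_loss += mean_loss * batch_size  # undo the per-batch mean
        total_correct += num_correct
        total += batch_size
    return total_loss / total, total_correct / total


# Made-up numbers: two full batches of 50 and a final partial batch of 20.
stats = [(0.20, 47, 50), (0.10, 49, 50), (0.40, 16, 20)]
avg_loss, acc = aggregate(stats)
print(round(avg_loss, 4), round(acc, 4))  # 0.1917 0.9333
```

Weighting by batch size also keeps the average correct when the last batch is smaller than BATCH_SIZE.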
