pytorch cnn 识别手写的字实现自建图片数据


Posted in Python onMay 20, 2018

本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下:

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    fh.close()
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    img = Image.fromarray(np.array(img), mode='L')
    if self.transform is not None:
      img = self.transform(img)
    return img,label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input height
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (x, y) in enumerate(train_loader):  # gives batch data, normalize x when iterate train_loader
    b_x = Variable(x)  # batch x
    b_y = Variable(y)  # batch y
 
    output = cnn(b_x)[0]        # cnn output
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
 
    if step % 50 == 0:
      cnn.eval()
      eval_loss = 0.
      eval_acc = 0.
      for i, (tx, ty) in enumerate(test_loader):
        t_x = Variable(tx)
        t_y = Variable(ty)
        output = cnn(t_x)[0]
        loss = loss_func(output, t_y)
        eval_loss += loss.data[0]
        pred = torch.max(output, 1)[1]
        num_correct = (pred == t_y).sum()
        eval_acc += float(num_correct.data[0])
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))

图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》

结果如下:

pytorch cnn 识别手写的字实现自建图片数据

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python连接数据库的方法
Oct 19 Python
python中不能连接超时的问题及解决方法
Jun 10 Python
Python字典中的键映射多个值的方法(列表或者集合)
Oct 17 Python
Python操作json的方法实例分析
Dec 06 Python
Python实现的序列化和反序列化二叉树算法示例
Mar 02 Python
Python中捕获键盘的方式详解
Mar 28 Python
Django Python 获取请求头信息Content-Range的方法
Aug 06 Python
Python 线程池用法简单示例
Oct 02 Python
详解python itertools功能
Feb 07 Python
在脚本中单独使用django的ORM模型详解
Apr 01 Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 Python
python 实现两个变量值进行交换的n种操作
Jun 02 Python
pytorch 把MNIST数据集转换成图片和txt的方法
May 20 #Python
Python安装lz4-0.10.1遇到的坑
May 20 #Python
Python requests发送post请求的一些疑点
May 20 #Python
python中virtualenvwrapper安装与使用
May 20 #Python
django静态文件加载的方法
May 20 #Python
django中静态文件配置static的方法
May 20 #Python
Python中跳台阶、变态跳台阶与矩形覆盖问题的解决方法
May 19 #Python
You might like
基于php常用函数总结(数组,字符串,时间,文件操作)
2013/06/27 PHP
php替换字符串中间字符为省略号的方法
2015/05/04 PHP
PHP反射API示例分享
2016/10/08 PHP
PHP实现对xml的增删改查操作案例分析
2017/05/19 PHP
JavaScript 获取事件对象的注意点
2009/07/29 Javascript
很好用的js日历算法详细代码
2013/03/07 Javascript
究竟什么是Node.js?Node.js有什么好处?
2015/05/29 Javascript
jQuery基于ajax实现星星评论代码
2015/08/07 Javascript
jQuery+jsp实现省市县三级联动效果(附源码)
2015/12/03 Javascript
jQuery常用知识点总结以及平时封装常用函数
2016/02/23 Javascript
jQuery实现下拉加载功能实例代码
2016/04/01 Javascript
js添加千分位的实现代码(超简单)
2016/08/01 Javascript
AngularJS表单验证中级篇(3)
2016/09/28 Javascript
浅谈mint-ui loadmore组件注意的问题
2017/11/08 Javascript
Angular 4.x+Ionic3踩坑之Ionic 3.x界面传值详解
2018/03/13 Javascript
jQuery 导航自动跟随滚动的实现代码
2018/05/30 jQuery
解决echarts的多个折现数据出现坐标和值对不上的问题
2018/12/28 Javascript
Vue-CLI 项目在pycharm中配置方法
2019/08/30 Javascript
layui 实现表格某一列显示图标
2019/09/19 Javascript
Vue实现页面添加水印功能
2019/11/09 Javascript
JS apply用法总结和使用场景实例分析
2020/03/14 Javascript
python 测试实现方法
2008/12/24 Python
python 将字符串中的数字相加求和的实现
2019/07/18 Python
Django 后台带有字典的列表数据与页面js交互实例
2020/04/03 Python
在pycharm中使用matplotlib.pyplot 绘图时报错的解决
2020/06/01 Python
CSS3绘制六边形的简单实现
2016/08/25 HTML / CSS
半年思想汇报
2013/12/30 职场文书
祖国在我心中演讲稿400字
2014/05/04 职场文书
师德演讲稿范文
2014/05/06 职场文书
校庆口号
2014/06/20 职场文书
音乐幼师求职信
2014/07/09 职场文书
学生吸烟检讨书
2014/09/14 职场文书
幼儿园教师个人工作总结2015
2015/05/12 职场文书
实习单位鉴定意见
2015/06/04 职场文书
小学班级管理心得体会
2016/01/07 职场文书
解决Mysql中的innoDB幻读问题
2022/04/29 MySQL