pytorch cnn 识别手写的字实现自建图片数据


Posted in Python onMay 20, 2018

本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下:

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    fh.close()
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    img = Image.fromarray(np.array(img), mode='L')
    if self.transform is not None:
      img = self.transform(img)
    return img,label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input height
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (x, y) in enumerate(train_loader):  # gives batch data, normalize x when iterate train_loader
    b_x = Variable(x)  # batch x
    b_y = Variable(y)  # batch y
 
    output = cnn(b_x)[0]        # cnn output
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
 
    if step % 50 == 0:
      cnn.eval()
      eval_loss = 0.
      eval_acc = 0.
      for i, (tx, ty) in enumerate(test_loader):
        t_x = Variable(tx)
        t_y = Variable(ty)
        output = cnn(t_x)[0]
        loss = loss_func(output, t_y)
        eval_loss += loss.data[0]
        pred = torch.max(output, 1)[1]
        num_correct = (pred == t_y).sum()
        eval_acc += float(num_correct.data[0])
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))

图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》

结果如下:

pytorch cnn 识别手写的字实现自建图片数据

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 生成不重复的随机数的代码
May 15 Python
python操作CouchDB的方法
Oct 08 Python
Python实现SVN的目录周期性备份实例
Jul 17 Python
机器学习python实战之手写数字识别
Nov 01 Python
scrapy spider的几种爬取方式实例代码
Jan 25 Python
django传值给模板, 再用JS接收并进行操作的实例
May 28 Python
python中ASCII码字符与int之间的转换方法
Jul 09 Python
Python+Selenium使用Page Object实现页面自动化测试
Jul 14 Python
Python迭代器Iterable判断方法解析
Mar 16 Python
opencv-python的RGB与BGR互转方式
Jun 02 Python
Django CBV模型源码运行流程详解
Aug 17 Python
python3中确保枚举值代码分析
Dec 02 Python
pytorch 把MNIST数据集转换成图片和txt的方法
May 20 #Python
Python安装lz4-0.10.1遇到的坑
May 20 #Python
Python requests发送post请求的一些疑点
May 20 #Python
python中virtualenvwrapper安装与使用
May 20 #Python
django静态文件加载的方法
May 20 #Python
django中静态文件配置static的方法
May 20 #Python
Python中跳台阶、变态跳台阶与矩形覆盖问题的解决方法
May 19 #Python
You might like
PHP 柱状图实现代码
2009/12/04 PHP
php使用escapeshellarg时中文被过滤的解决方法
2016/07/10 PHP
php实现微信模板消息推送
2018/03/30 PHP
PHP预定义接口――Iterator用法示例
2020/06/05 PHP
框架页面高度自动刷新的Javascript脚本
2013/11/01 Javascript
使用js画图之饼图
2015/01/12 Javascript
jQuery的事件委托实例分析
2015/07/15 Javascript
js实现仿爱微网两级导航菜单效果代码
2015/08/31 Javascript
Bootstrap富文本组件wysiwyg数据保存到mysql的方法
2016/05/09 Javascript
使用jQuery给input标签设置默认值
2016/06/20 Javascript
Bootstrap页面标题Page Header的实现方法
2017/03/22 Javascript
Bootstrap 模态框多次显示后台提交多次BUG的解决方法
2017/12/26 Javascript
详解vue.js移动端配置flexible.js及注意事项
2019/04/10 Javascript
小程序如何定位所在城市及发起周边搜索
2020/02/11 Javascript
解决Vue @submit 提交后不刷新页面问题
2020/07/18 Javascript
Vue实现图书管理案例
2021/01/20 Vue.js
[01:53]2016完美“圣”典风云人物:Maybe专访
2016/12/05 DOTA
Python 可爱的大小写
2008/09/06 Python
Python编程中装饰器的使用示例解析
2016/06/20 Python
python实现的正则表达式功能入门教程【经典】
2017/06/05 Python
Python supervisor强大的进程管理工具的使用
2019/04/24 Python
Python+Kepler.gl实现时间轮播地图过程解析
2020/07/20 Python
英国领先的杂志订阅网站:Magazine.co.uk
2018/01/25 全球购物
美国健康和保健平台:healtop
2020/07/02 全球购物
期末自我鉴定
2014/02/02 职场文书
中学校庆方案
2014/03/17 职场文书
爱岗敬业演讲稿
2014/05/05 职场文书
保护环境建议书100字
2014/05/13 职场文书
党的群众路线教育实践活动对照检查材料(教师)
2014/09/24 职场文书
2014年重阳节老干部座谈会上的讲话稿
2014/09/25 职场文书
民事诉讼代理授权委托书范本
2014/10/08 职场文书
无子女夫妻离婚协议书(4篇)
2014/10/20 职场文书
2015年度对口支援工作总结
2015/07/22 职场文书
Spring整合Mybatis的全过程
2021/06/28 Java/Android
浅谈spring boot使用thymeleaf版本的问题
2021/08/04 Java/Android
Mysql将字符串按照指定字符分割的正确方法
2022/05/30 MySQL