PyTorch CNN: recognizing handwritten digits with a self-built image dataset


Posted in Python on May 20, 2018

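This article shows how to train a PyTorch CNN to recognize handwritten digits using a self-built image dataset (image files on disk, listed together with their labels in txt files). The code is as follows: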

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  """Reads 'image_path label' lines from a txt file and loads each image on demand."""
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    imgs = []
    with open(txt, 'r') as fh:
      for line in fh:
        words = line.split()          # split() also drops the trailing newline
        if not words:                 # skip blank lines
          continue
        imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn).convert('L')  # make sure the image is single-channel grayscale
    if self.transform is not None:
      img = self.transform(img)
    if self.target_transform is not None:
      label = self.target_transform(label)
    return img, label

  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input channels (1 for grayscale)
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # to keep the same width and height after conv2d with stride=1, use padding=(kernel_size-1)/2
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (b_x, b_y) in enumerate(train_loader):  # b_x: (batch, 1, 28, 28) floats in [0, 1] from ToTensor
    output = cnn(b_x)[0]        # cnn output (logits)
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients

    if step % 50 == 0:
      cnn.eval()                 # switch to evaluation mode
      eval_loss = 0.
      eval_acc = 0.
      with torch.no_grad():      # no gradients are needed for evaluation
        for tx, ty in test_loader:
          output = cnn(tx)[0]
          loss = loss_func(output, ty)
          eval_loss += loss.item() * ty.size(0)  # loss is a batch mean, so weight it by the batch size
          pred = torch.max(output, 1)[1]
          eval_acc += (pred == ty).sum().item()
      acc_rate = eval_acc / float(len(test_data))
      print('Epoch: {} | Step: {} | Test Loss: {:.6f}, Acc: {:.6f}'.format(
        epoch, step, eval_loss / len(test_data), acc_rate))
      cnn.train()                # back to training mode
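The shape comments in the CNN class follow the usual output-size rule for Conv2d and MaxPool2d with dilation 1: out = floor((in + 2 * padding - kernel) / stride) + 1. The plain-Python sketch below uses hypothetical helper names (`conv_out`, `trace_cnn_shapes`) to trace a 28x28 input through the network, which is where the `32 * 7 * 7` input size of the final linear layer comes from.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of Conv2d/MaxPool2d (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_cnn_shapes(size=28):
    """Follow one square input through the CNN above; returns (channels, h, w) per stage."""
    shapes = []
    size = conv_out(size, kernel=5, stride=1, padding=2)  # conv1: 5x5 kernel, padding 2
    shapes.append((16, size, size))
    size = conv_out(size, kernel=2, stride=2)             # MaxPool2d(2): stride defaults to kernel_size
    shapes.append((16, size, size))
    size = conv_out(size, kernel=5, stride=1, padding=2)  # conv2: same kernel and padding
    shapes.append((32, size, size))
    size = conv_out(size, kernel=2, stride=2)             # second MaxPool2d(2)
    shapes.append((32, size, size))
    return shapes


if __name__ == "__main__":
    shapes = trace_cnn_shapes()
    for s in shapes:
        print(s)
    c, h, w = shapes[-1]
    print("flattened features:", c * h * w)  # should equal the nn.Linear input size
```

If you change a kernel size or padding, rerunning the trace shows whether the Linear layer's input size has to change as well.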

The images and label files are produced in the previous article, "PyTorch: converting the MNIST dataset into images and txt files".
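MyDataset expects each line of train.txt and test.txt to hold an image path and an integer label separated by whitespace. The sketch below writes and re-reads such a listing using only the standard library; the helper names and file paths are hypothetical, and the real listings come from the conversion script in that article.

```python
import os
import tempfile


def write_listing(pairs, txt_path):
    """Write one 'image_path label' line per sample (hypothetical helper)."""
    with open(txt_path, "w") as fh:
        for path, label in pairs:
            fh.write("{} {}\n".format(path, int(label)))


def read_listing(txt_path):
    """Parse the listing the way MyDataset does: split each line on whitespace."""
    samples = []
    with open(txt_path) as fh:
        for line in fh:
            words = line.split()
            if not words:  # tolerate blank lines
                continue
            samples.append((words[0], int(words[1])))
    return samples


if __name__ == "__main__":
    # Hypothetical paths and labels, for illustration only.
    pairs = [("./mnist/raw/train/0.jpg", 5), ("./mnist/raw/train/1.jpg", 0)]
    with tempfile.TemporaryDirectory() as tmp:
        txt = os.path.join(tmp, "train.txt")
        write_listing(pairs, txt)
        print(read_listing(txt))
```

Because each line is split on whitespace, image paths in the listing must not contain spaces.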

The results are as follows:

[Image: output of the training run (screenshot not preserved)]
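The printed test loss and accuracy are accumulated batch by batch. `nn.CrossEntropyLoss()` returns the mean loss of each batch by default, so a dataset-wide average should weight every batch mean by its batch size; summing the batch means and dividing by the number of test images instead gives a value roughly BATCH_SIZE times too small. A minimal plain-Python sketch of both aggregations, with hypothetical helper names and made-up numbers:

```python
def dataset_average(batch_means, batch_sizes):
    """Per-sample average over the whole dataset, given per-batch mean losses."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


def accuracy(correct_counts, n_samples):
    """Fraction of correctly classified samples across all batches."""
    return sum(correct_counts) / float(n_samples)


if __name__ == "__main__":
    # Made-up numbers: two full batches of 50 and a final partial batch of 10.
    means = [0.30, 0.20, 0.50]
    sizes = [50, 50, 10]
    print(dataset_average(means, sizes))  # (0.30*50 + 0.20*50 + 0.50*10) / 110
    print(accuracy([48, 49, 9], sum(sizes)))
```

Weighting by batch size also keeps the average correct when the last batch is smaller than BATCH_SIZE.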

That is all for this article. I hope it helps with your studies.
