Recognizing Handwritten Digits with a PyTorch CNN Using a Custom Image Dataset


Posted in Python on May 20, 2018

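This article shows how to recognize handwritten digits with a CNN in PyTorch, loading the training and test data from a custom image dataset instead of the built-in MNIST loader. The full code is as follows: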

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # number of passes over the training data; just 1 epoch to save time
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    # each non-empty line of the txt file is "<image path> <integer label>"
    imgs = []
    with open(txt, 'r') as fh:
      for line in fh:
        words = line.split()
        if not words:        # skip blank lines
          continue
        imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn).convert('L')  # force single-channel grayscale, matching in_channels=1
    if self.transform is not None:
      img = self.transform(img)
    if self.target_transform is not None:
      label = self.target_transform(label)
    return img, label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input channels (1 for grayscale)
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # to keep the same width and height after Conv2d, use padding=(kernel_size-1)/2 when stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # expects integer class labels, not one-hot vectors
 
# training and testing
for epoch in range(EPOCH):
  for step, (b_x, b_y) in enumerate(train_loader):  # ToTensor has already scaled pixel values to [0, 1]
    output = cnn(b_x)[0]        # cnn output (logits)
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
 
    if step % 50 == 0:
      cnn.eval()           # switch to evaluation mode
      eval_loss = 0.
      eval_acc = 0.
      with torch.no_grad():      # no gradients are needed while evaluating
        for t_x, t_y in test_loader:
          test_output = cnn(t_x)[0]
          test_loss = loss_func(test_output, t_y)
          # CrossEntropyLoss returns the batch mean, so weight it by the batch size
          eval_loss += test_loss.item() * t_y.size(0)
          pred = torch.max(test_output, 1)[1]
          eval_acc += (pred == t_y).sum().item()
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / len(test_data), acc_rate))
      cnn.train()          # switch back to training mode
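
The shape comments in the CNN class (28 to 14 to 7, and the 32 * 7 * 7 input to the final linear layer) can be checked with a little arithmetic. The sketch below is only an illustration: conv_out is a hypothetical helper, not part of PyTorch, and it assumes the usual output-size formula floor((size + 2 * padding - kernel) / stride) + 1 with no dilation, and a pooling stride equal to the pooling kernel size (the default for nn.MaxPool2d).

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one spatial dimension of a conv or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                             # input images are 28x28
size = conv_out(size, kernel=5, stride=1, padding=2)  # conv1 keeps 28
size = conv_out(size, kernel=2, stride=2)             # 2x2 max pool halves it to 14
size = conv_out(size, kernel=5, stride=1, padding=2)  # conv2 keeps 14
size = conv_out(size, kernel=2, stride=2)             # 2x2 max pool halves it to 7
flat_features = 32 * size * size                      # 32 channels come out of conv2
print(size, flat_features)
```

If the input images were a different size, the same calculation would give the value to use in place of 32 * 7 * 7 in nn.Linear.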

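For the images and label files, see the previous article, "Converting the MNIST Dataset to Images and txt with PyTorch" (pytorch 把MNIST数据集转换成图片和txt).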
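
As read by MyDataset, each line of train.txt and test.txt holds an image path followed by an integer label, separated by whitespace. The sketch below shows that parsing step on its own. The function name parse_label_file and the example paths are made up for illustration; the real paths depend on how the images were exported.

```python
import io

def parse_label_file(fh):
    """Return (image_path, label) pairs, skipping blank lines."""
    samples = []
    for line in fh:
        words = line.split()
        if not words:
            continue
        samples.append((words[0], int(words[1])))
    return samples

# Hypothetical file contents; these paths are placeholders, not real files.
example = io.StringIO("./mnist/raw/train/0.jpg 5\n./mnist/raw/train/1.jpg 0\n\n")
samples = parse_label_file(example)
print(samples)
```

Because the line is split on whitespace, this format only works when the image paths themselves contain no spaces.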

The results are as follows:

[Image: output of the PyTorch CNN handwritten digit recognition run on the custom image dataset]

That is all for this article. I hope it helps with your studies, and I hope you will continue to support 三水点靠木.
