pytorch cnn 识别手写的字实现自建图片数据


Posted in Python onMay 20, 2018

本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下:

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    fh.close()
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    img = Image.fromarray(np.array(img), mode='L')
    if self.transform is not None:
      img = self.transform(img)
    return img,label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input height
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (x, y) in enumerate(train_loader):  # gives batch data, normalize x when iterate train_loader
    b_x = Variable(x)  # batch x
    b_y = Variable(y)  # batch y
 
    output = cnn(b_x)[0]        # cnn output
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
 
    if step % 50 == 0:
      cnn.eval()
      eval_loss = 0.
      eval_acc = 0.
      for i, (tx, ty) in enumerate(test_loader):
        t_x = Variable(tx)
        t_y = Variable(ty)
        output = cnn(t_x)[0]
        loss = loss_func(output, t_y)
        eval_loss += loss.data[0]
        pred = torch.max(output, 1)[1]
        num_correct = (pred == t_y).sum()
        eval_acc += float(num_correct.data[0])
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))

图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》

结果如下:

pytorch cnn 识别手写的字实现自建图片数据

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python编写脚本获取手机当前应用apk的信息
Jul 21 Python
Python中设置变量作为默认值时容易遇到的错误
Apr 03 Python
python中input()与raw_input()的区别分析
Feb 27 Python
详解使用 pyenv 管理多个版本 python 环境
Oct 19 Python
Python数据结构与算法之列表(链表,linked list)简单实现
Oct 30 Python
Pandas中把dataframe转成array的方法
Apr 13 Python
Python基于win32ui模块创建弹出式菜单示例
May 09 Python
Python利用ORM控制MongoDB(MongoEngine)的步骤全纪录
Sep 13 Python
TensorFlow基本的常量、变量和运算操作详解
Feb 03 Python
Python连接HDFS实现文件上传下载及Pandas转换文本文件到CSV操作
Jun 06 Python
Python 使用xlwt模块将多行多列数据循环写入excel文档的操作
Nov 10 Python
如何判断pytorch是否支持GPU加速
Jun 01 Python
pytorch 把MNIST数据集转换成图片和txt的方法
May 20 #Python
Python安装lz4-0.10.1遇到的坑
May 20 #Python
Python requests发送post请求的一些疑点
May 20 #Python
python中virtualenvwrapper安装与使用
May 20 #Python
django静态文件加载的方法
May 20 #Python
django中静态文件配置static的方法
May 20 #Python
Python中跳台阶、变态跳台阶与矩形覆盖问题的解决方法
May 19 #Python
You might like
自制汽车收音机天线:收听广播的技巧和方法
2021/03/02 无线电
PHP无限分类(树形类)
2013/09/28 PHP
php使用gzip压缩传输js和css文件的方法
2015/07/29 PHP
详解PHP使用日期时间处理器Carbon人性化显示时间
2017/08/10 PHP
createElement与createDocumentFragment的点点区别小结
2011/12/19 Javascript
jQuery中focus事件用法实例
2014/12/26 Javascript
AngularJs动态加载模块和依赖注入详解
2016/01/11 Javascript
基于JQuery的$.ajax方法进行异步请求导致页面闪烁的解决办法
2016/05/10 Javascript
一览画面点击复选框后获取多个id值的方法
2016/05/30 Javascript
实用又漂亮的BootstrapValidator表单验证插件
2016/05/30 Javascript
WEB前端实现裁剪上传图片功能
2016/10/17 Javascript
微信小程序-获得用户输入内容
2017/02/13 Javascript
vue.js中mint-ui框架的使用方法
2017/05/12 Javascript
angular 用拦截器统一处理http请求和响应的方法
2017/06/08 Javascript
详解vue的数据劫持以及操作数组的坑
2019/04/18 Javascript
js实现倒计时秒杀效果
2020/03/25 Javascript
Bootstrap实现前端登录页面带验证码功能完整示例
2020/03/26 Javascript
python创建只读属性对象的方法(ReadOnlyObject)
2013/02/10 Python
python查找指定具有相同内容文件的方法
2015/06/28 Python
详解Python装饰器由浅入深
2016/12/09 Python
Python中使用haystack实现django全文检索搜索引擎功能
2017/08/26 Python
使用anaconda的pip安装第三方python包的操作步骤
2018/06/11 Python
Pyinstaller 打包exe教程及问题解决
2019/08/16 Python
Matplotlib绘制雷达图和三维图的示例代码
2020/01/07 Python
如何在python中执行另一个py文件
2020/04/30 Python
Django数据模型中on_delete使用详解
2020/11/30 Python
详解CSS3的box-shadow属性制作边框阴影效果的方法
2016/05/10 HTML / CSS
一站式跨境收款解决方案:Payoneer(派安盈)
2018/09/06 全球购物
美国滑板店:Tactics
2020/11/08 全球购物
百年校庆节目主持词
2014/03/27 职场文书
大学生活动总结怎么写
2014/04/29 职场文书
化学专业自荐信
2014/05/28 职场文书
2014年社区计生工作总结
2014/11/18 职场文书
先进事迹材料怎么写
2014/12/30 职场文书
陪护人员误工证明
2015/06/24 职场文书
Java 常见的限流算法详细分析并实现
2022/04/07 Java/Android