pytorch三层全连接层实现手写字母识别方式


Posted in Python onJanuary 14, 2020

先用最简单的三层全连接神经网络,然后添加激活层查看实验结果,最后加上批标准化验证是否有效

首先根据已有的模板定义网络结构SimpleNet,命名为net.py

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义三层全连接神经网络
class simpleNet(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):#输入维度,第一层的神经元个数、第二层的神经元个数,以及第三层的神经元个数
  super(simpleNet,self).__init__()
  self.layer1=nn.Linear(in_dim,n_hidden_1)
  self.layer2=nn.Linear(n_hidden_1,n_hidden_2)
  self.layer3=nn.Linear(n_hidden_2,out_dim)
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
 
 
#添加激活函数
class Activation_Net(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(NeutalNetwork,self).__init__()
  self.layer1=nn.Sequential(#Sequential组合结构
  nn.Linear(in_dim,n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(
  nn.Linear(n_hidden_1,n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(
  nn.Linear(n_hidden_2,out_dim))
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
#添加批标准化处理模块,皮标准化放在全连接的后面,非线性的前面
class Batch_Net(nn.Module):
 def _init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(Batch_net,self).__init__()
  self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNormld(n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNormld(n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(nn.Linear(n_hidden_2,out_dim))
 def forword(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x

训练网络,

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义一些超参数
import net
batch_size=64
learning_rate=1e-2
num_epoches=20
#预处理
data_tf=transforms.Compose(
[transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#将图像转化成tensor,然后继续标准化,就是减均值,除以方差

#读取数据集
train_dataset=datasets.MNIST(root='./data',train=True,transform=data_tf,download=True)
test_dataset=datasets.MNIST(root='./data',train=False,transform=data_tf)
#使用内置的函数导入数据集
train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False)

#导入网络,定义损失函数和优化方法
model=net.simpleNet(28*28,300,100,10)
if torch.cuda.is_available():#是否使用cuda加速
 model=model.cuda()
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=learning_rate)
import net
n_epochs=5
for epoch in range(n_epochs):
 running_loss=0.0
 running_correct=0
 print("epoch {}/{}".format(epoch,n_epochs))
 print("-"*10)
 for data in train_loader:
  img,label=data
  img=img.view(img.size(0),-1)
  if torch.cuda.is_available():
   img=img.cuda()
   label=label.cuda()
  else:
   img=Variable(img)
   label=Variable(label)
  out=model(img)#得到前向传播的结果
  loss=criterion(out,label)#得到损失函数
  print_loss=loss.data.item()
  optimizer.zero_grad()#归0梯度
  loss.backward()#反向传播
  optimizer.step()#优化
  running_loss+=loss.item()
  epoch+=1
  if epoch%50==0:
   print('epoch:{},loss:{:.4f}'.format(epoch,loss.data.item()))

训练的结果截图如下:

pytorch三层全连接层实现手写字母识别方式

测试网络

#测试网络
model.eval()#将模型变成测试模式
eval_loss=0
eval_acc=0
for data in test_loader:
 img,label=data
 img=img.view(img.size(0),-1)#测试集不需要反向传播,所以可以在前项传播的时候释放内存,节约内存空间
 if torch.cuda.is_available():
  img=Variable(img,volatile=True).cuda()
  label=Variable(label,volatile=True).cuda()
 else:
  img=Variable(img,volatile=True)
  label=Variable(label,volatile=True)
 out=model(img)
 loss=criterion(out,label)
 eval_loss+=loss.item()*label.size(0)
 _,pred=torch.max(out,1)
 num_correct=(pred==label).sum()
 eval_acc+=num_correct.item()
print('test loss:{:.6f},ac:{:.6f}'.format(eval_loss/(len(test_dataset)),eval_acc/(len(test_dataset))))

pytorch三层全连接层实现手写字母识别方式

训练的时候,还可以加入一些dropout,正则化,修改隐藏层神经元的个数,增加隐藏层数,可以自己添加。

以上这篇pytorch三层全连接层实现手写字母识别方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用PyGame绘制图像并保存为图片文件的方法
Apr 24 Python
Python中用于计算对数的log()方法
May 15 Python
Python的Django中将文件上传至七牛云存储的代码分享
Jun 03 Python
python字符串常用方法
Jun 14 Python
python3基于OpenCV实现证件照背景替换
Jul 18 Python
在python中对变量判断是否为None的三种方法总结
Jan 23 Python
Django用户认证系统 User对象解析
Aug 02 Python
Python Tensor FLow简单使用方法实例详解
Jan 14 Python
Python使用requests模块爬取百度翻译
Aug 25 Python
python 利用toapi库自动生成api
Oct 19 Python
tensorflow+k-means聚类简单实现猫狗图像分类的方法
Apr 28 Python
python scrapy简单模拟登录的代码分析
Jul 21 Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
基于python监控程序是否关闭
Jan 14 #Python
关于pytorch中全连接神经网络搭建两种模式详解
Jan 14 #Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
You might like
php写一个函数,实现扫描并打印出自定目录下(含子目录)所有jpg文件名
2017/05/26 PHP
dojo随手记 gird组件引用
2011/02/24 Javascript
深入理解JavaScript系列(41):设计模式之模板方法详解
2015/03/04 Javascript
JS实现窗口加载时模拟鼠标移动的方法
2015/06/03 Javascript
在JavaScript中使用开平方根的sqrt()方法
2015/06/15 Javascript
JavaScript操作XML/HTML比较常用的对象属性集锦
2015/10/30 Javascript
jQuery unbind 删除绑定事件详解
2016/05/24 Javascript
jQuery实现CheckBox全选、全不选功能
2017/01/11 Javascript
AngulaJS路由 ui-router 传参实例
2017/04/28 Javascript
详解微信小程序 通过控制CSS实现view隐藏与显示
2017/05/24 Javascript
Vue.js2.0中的变化小结
2017/10/24 Javascript
手动用webpack搭建第一个ReactApp的示例
2018/04/11 Javascript
微信小程序 Storage更新详解
2019/07/16 Javascript
Openlayers测量距离与面积的实现方法
2020/09/25 Javascript
[00:09]DOTA2新版本PA至宝特效动作展示
2014/11/19 DOTA
通过Python使用saltstack生成服务器资产清单
2016/03/01 Python
Python字符串拼接的几种方法整理
2017/08/02 Python
Django实现登录随机验证码的示例代码
2018/06/20 Python
Python从函数参数类型引出元组实例分析
2019/05/28 Python
Python迭代器iterator生成器generator使用解析
2019/10/24 Python
python next()和iter()函数原理解析
2020/02/07 Python
python字典key不能是可以是啥类型
2020/08/04 Python
windows+vscode安装paddleOCR运行环境的步骤
2020/11/11 Python
教你如何一步一步用Canvas写一个贪吃蛇
2018/10/22 HTML / CSS
荷兰照明、灯具和配件网上商店:dmlights
2019/08/25 全球购物
业务经理的岗位职责
2013/11/16 职场文书
高职教师岗位职责
2013/12/24 职场文书
监督检查工作方案
2014/05/28 职场文书
党委书记群众路线对照检查材料思想汇报
2014/10/04 职场文书
教师党员个人整改措施
2014/10/27 职场文书
淘宝好评语句大全
2014/12/31 职场文书
工程技术员岗位职责
2015/04/11 职场文书
工程技术负责人岗位职责
2015/04/13 职场文书
2019餐饮行业创业计划书!
2019/06/27 职场文书
Python爬虫基础之爬虫的分类知识总结
2021/05/13 Python
基于Python和openCV实现图像的全景拼接详细步骤
2021/10/05 Python