pytorch三层全连接层实现手写字母识别方式


Posted in Python onJanuary 14, 2020

先用最简单的三层全连接神经网络,然后添加激活层查看实验结果,最后加上批标准化验证是否有效

首先根据已有的模板定义网络结构SimpleNet,命名为net.py

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义三层全连接神经网络
class simpleNet(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):#输入维度,第一层的神经元个数、第二层的神经元个数,以及第三层的神经元个数
  super(simpleNet,self).__init__()
  self.layer1=nn.Linear(in_dim,n_hidden_1)
  self.layer2=nn.Linear(n_hidden_1,n_hidden_2)
  self.layer3=nn.Linear(n_hidden_2,out_dim)
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
 
 
#添加激活函数
class Activation_Net(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(NeutalNetwork,self).__init__()
  self.layer1=nn.Sequential(#Sequential组合结构
  nn.Linear(in_dim,n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(
  nn.Linear(n_hidden_1,n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(
  nn.Linear(n_hidden_2,out_dim))
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
#添加批标准化处理模块,皮标准化放在全连接的后面,非线性的前面
class Batch_Net(nn.Module):
 def _init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(Batch_net,self).__init__()
  self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNormld(n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNormld(n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(nn.Linear(n_hidden_2,out_dim))
 def forword(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x

训练网络,

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义一些超参数
import net
batch_size=64
learning_rate=1e-2
num_epoches=20
#预处理
data_tf=transforms.Compose(
[transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#将图像转化成tensor,然后继续标准化,就是减均值,除以方差

#读取数据集
train_dataset=datasets.MNIST(root='./data',train=True,transform=data_tf,download=True)
test_dataset=datasets.MNIST(root='./data',train=False,transform=data_tf)
#使用内置的函数导入数据集
train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False)

#导入网络,定义损失函数和优化方法
model=net.simpleNet(28*28,300,100,10)
if torch.cuda.is_available():#是否使用cuda加速
 model=model.cuda()
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=learning_rate)
import net
n_epochs=5
for epoch in range(n_epochs):
 running_loss=0.0
 running_correct=0
 print("epoch {}/{}".format(epoch,n_epochs))
 print("-"*10)
 for data in train_loader:
  img,label=data
  img=img.view(img.size(0),-1)
  if torch.cuda.is_available():
   img=img.cuda()
   label=label.cuda()
  else:
   img=Variable(img)
   label=Variable(label)
  out=model(img)#得到前向传播的结果
  loss=criterion(out,label)#得到损失函数
  print_loss=loss.data.item()
  optimizer.zero_grad()#归0梯度
  loss.backward()#反向传播
  optimizer.step()#优化
  running_loss+=loss.item()
  epoch+=1
  if epoch%50==0:
   print('epoch:{},loss:{:.4f}'.format(epoch,loss.data.item()))

训练的结果截图如下:

pytorch三层全连接层实现手写字母识别方式

测试网络

#测试网络
model.eval()#将模型变成测试模式
eval_loss=0
eval_acc=0
for data in test_loader:
 img,label=data
 img=img.view(img.size(0),-1)#测试集不需要反向传播,所以可以在前项传播的时候释放内存,节约内存空间
 if torch.cuda.is_available():
  img=Variable(img,volatile=True).cuda()
  label=Variable(label,volatile=True).cuda()
 else:
  img=Variable(img,volatile=True)
  label=Variable(label,volatile=True)
 out=model(img)
 loss=criterion(out,label)
 eval_loss+=loss.item()*label.size(0)
 _,pred=torch.max(out,1)
 num_correct=(pred==label).sum()
 eval_acc+=num_correct.item()
print('test loss:{:.6f},ac:{:.6f}'.format(eval_loss/(len(test_dataset)),eval_acc/(len(test_dataset))))

pytorch三层全连接层实现手写字母识别方式

训练的时候,还可以加入一些dropout,正则化,修改隐藏层神经元的个数,增加隐藏层数,可以自己添加。

以上这篇pytorch三层全连接层实现手写字母识别方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python合并文本文件示例
Feb 07 Python
python opencv3实现人脸识别(windows)
May 25 Python
python 统计一个列表当中的每一个元素出现了多少次的方法
Nov 14 Python
python实现基于信息增益的决策树归纳
Dec 18 Python
python实现关闭第三方窗口的方法
Jun 28 Python
Python学习笔记之字符串和字符串方法实例详解
Aug 22 Python
在python中计算ssim的方法(与Matlab结果一致)
Dec 19 Python
Python Pillow.Image 图像保存和参数选择方式
Jan 09 Python
python中upper是做什么用的
Jul 20 Python
Django配置Bootstrap, js实现过程详解
Oct 13 Python
PyCharm安装PyQt5及其工具(Qt Designer、PyUIC、PyRcc)的步骤详解
Nov 02 Python
python多线程方法详解
Jan 18 Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
基于python监控程序是否关闭
Jan 14 #Python
关于pytorch中全连接神经网络搭建两种模式详解
Jan 14 #Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
You might like
ThinkPHP3.1新特性之多数据库操作更加完善
2014/06/19 PHP
支付宝接口开发集成支付环境小结
2015/03/17 PHP
JavaScript Perfection kill 测试及答案
2010/03/23 Javascript
JS下高效拼装字符串的几种方法比较与测试代码
2010/04/15 Javascript
Jquery常用技巧收集整理篇
2010/11/14 Javascript
js实现的仿新浪微博完美的时间组件升级版
2011/12/20 Javascript
如何制作浮动广告 JavaScript制作浮动广告代码
2012/12/30 Javascript
关于Javascript作用域链的八点总结
2013/12/06 Javascript
javascript计算用户打开网页的停留时间
2014/01/09 Javascript
table对象中的insertRow与deleteRow使用示例
2014/01/26 Javascript
纯javascript实现四方向文本无缝滚动效果
2015/06/16 Javascript
基于jQuery实现的QQ表情插件
2015/08/25 Javascript
基于javascript实现最简单的选项卡切换效果
2016/05/16 Javascript
详解handlebars+require基本使用方法
2016/12/21 Javascript
jQuery实现选中行变色效果(实例讲解)
2017/07/06 jQuery
基于BootStrap实现简洁注册界面
2017/07/20 Javascript
判断滚动条滑到底部触发事件(实例讲解)
2017/11/15 Javascript
js删除数组中的元素delete和splice的区别详解
2018/02/03 Javascript
基于Vue+elementUI实现动态表单的校验功能(根据条件动态切换校验格式)
2019/04/04 Javascript
Vuejs通过拖动改变元素宽度实现自适应
2020/09/02 Javascript
[01:00:06]加油DOTA_EP01_网络版
2014/08/09 DOTA
详解django三种文件下载方式
2018/04/06 Python
基于K.image_data_format() == 'channels_first' 的理解
2020/06/29 Python
django跳转页面传参的实现
2020/09/17 Python
Python应用自动化部署工具Fabric原理及使用解析
2020/11/30 Python
纯css3制作煽动翅膀的蝴蝶的示例
2018/04/23 HTML / CSS
美国照明、家居装饰和家具购物网站:Bellacor
2017/09/20 全球购物
eBay意大利购物网站:eBay.it
2019/09/04 全球购物
中医药大学毕业生自荐信
2013/11/08 职场文书
求职信的要素有哪些呢
2013/12/26 职场文书
小学绿色学校申报材料
2014/08/23 职场文书
2015年环境整治工作总结
2015/05/22 职场文书
三好学生主要事迹怎么写
2015/11/03 职场文书
小学科学课教学反思
2016/02/23 职场文书
Python机器学习算法之决策树算法的实现与优缺点
2021/05/13 Python
教你使用RustDesk 搭建一个自己的远程桌面中继服务器
2022/08/14 Servers