pytorch三层全连接层实现手写字母识别方式


Posted in Python onJanuary 14, 2020

先用最简单的三层全连接神经网络,然后添加激活层查看实验结果,最后加上批标准化验证是否有效

首先根据已有的模板定义网络结构SimpleNet,命名为net.py

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义三层全连接神经网络
class simpleNet(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):#输入维度,第一层的神经元个数、第二层的神经元个数,以及第三层的神经元个数
  super(simpleNet,self).__init__()
  self.layer1=nn.Linear(in_dim,n_hidden_1)
  self.layer2=nn.Linear(n_hidden_1,n_hidden_2)
  self.layer3=nn.Linear(n_hidden_2,out_dim)
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
 
 
#添加激活函数
class Activation_Net(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(NeutalNetwork,self).__init__()
  self.layer1=nn.Sequential(#Sequential组合结构
  nn.Linear(in_dim,n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(
  nn.Linear(n_hidden_1,n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(
  nn.Linear(n_hidden_2,out_dim))
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
#添加批标准化处理模块,皮标准化放在全连接的后面,非线性的前面
class Batch_Net(nn.Module):
 def _init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(Batch_net,self).__init__()
  self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNormld(n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNormld(n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(nn.Linear(n_hidden_2,out_dim))
 def forword(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x

训练网络,

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义一些超参数
import net
batch_size=64
learning_rate=1e-2
num_epoches=20
#预处理
data_tf=transforms.Compose(
[transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#将图像转化成tensor,然后继续标准化,就是减均值,除以方差

#读取数据集
train_dataset=datasets.MNIST(root='./data',train=True,transform=data_tf,download=True)
test_dataset=datasets.MNIST(root='./data',train=False,transform=data_tf)
#使用内置的函数导入数据集
train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False)

#导入网络,定义损失函数和优化方法
model=net.simpleNet(28*28,300,100,10)
if torch.cuda.is_available():#是否使用cuda加速
 model=model.cuda()
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=learning_rate)
import net
n_epochs=5
for epoch in range(n_epochs):
 running_loss=0.0
 running_correct=0
 print("epoch {}/{}".format(epoch,n_epochs))
 print("-"*10)
 for data in train_loader:
  img,label=data
  img=img.view(img.size(0),-1)
  if torch.cuda.is_available():
   img=img.cuda()
   label=label.cuda()
  else:
   img=Variable(img)
   label=Variable(label)
  out=model(img)#得到前向传播的结果
  loss=criterion(out,label)#得到损失函数
  print_loss=loss.data.item()
  optimizer.zero_grad()#归0梯度
  loss.backward()#反向传播
  optimizer.step()#优化
  running_loss+=loss.item()
  epoch+=1
  if epoch%50==0:
   print('epoch:{},loss:{:.4f}'.format(epoch,loss.data.item()))

训练的结果截图如下:

pytorch三层全连接层实现手写字母识别方式

测试网络

#测试网络
model.eval()#将模型变成测试模式
eval_loss=0
eval_acc=0
for data in test_loader:
 img,label=data
 img=img.view(img.size(0),-1)#测试集不需要反向传播,所以可以在前项传播的时候释放内存,节约内存空间
 if torch.cuda.is_available():
  img=Variable(img,volatile=True).cuda()
  label=Variable(label,volatile=True).cuda()
 else:
  img=Variable(img,volatile=True)
  label=Variable(label,volatile=True)
 out=model(img)
 loss=criterion(out,label)
 eval_loss+=loss.item()*label.size(0)
 _,pred=torch.max(out,1)
 num_correct=(pred==label).sum()
 eval_acc+=num_correct.item()
print('test loss:{:.6f},ac:{:.6f}'.format(eval_loss/(len(test_dataset)),eval_acc/(len(test_dataset))))

pytorch三层全连接层实现手写字母识别方式

训练的时候,还可以加入一些dropout,正则化,修改隐藏层神经元的个数,增加隐藏层数,可以自己添加。

以上这篇pytorch三层全连接层实现手写字母识别方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的设计模式编程入门指南
Apr 02 Python
Python 多线程实例详解
Mar 25 Python
疯狂上涨的Python 开发者应从2.x还是3.x着手?
Nov 16 Python
python中的常量和变量代码详解
Jul 25 Python
python tkinter canvas 显示图片的示例
Jun 13 Python
Python imageio读取视频并进行编解码详解
Dec 10 Python
win10从零安装配置pytorch全过程图文详解
May 08 Python
Mysql数据库反向生成Django里面的models指令方式
May 18 Python
python 写函数在一定条件下需要调用自身时的写法说明
Jun 01 Python
Python configparser模块操作代码实例
Jun 08 Python
浅谈Python中的正则表达式
Jun 28 Python
7个关于Python的经典基础案例
Nov 07 Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
基于python监控程序是否关闭
Jan 14 #Python
关于pytorch中全连接神经网络搭建两种模式详解
Jan 14 #Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
You might like
php短域名转换为实际域名函数
2011/01/17 PHP
JavaScript中Object和Function的关系小结
2009/09/26 Javascript
传智播客学习之JavaScript基础篇
2009/11/13 Javascript
基于jQuery的消息提示插件之旅 DivAlert(三)
2010/04/01 Javascript
基于Jquery+Ajax+Json的高效分页实现代码
2011/10/29 Javascript
js获取class的所有元素
2013/03/28 Javascript
在JavaScript中处理时间之getHours()方法的使用
2015/06/10 Javascript
js实现的全国省市二级联动下拉选择菜单完整实例
2015/08/17 Javascript
js实现当前输入框高亮显示的方法
2015/08/19 Javascript
js addDqmForPP给标签内属性值加上双引号的函数
2016/12/24 Javascript
微信小程序实现图片上传功能
2018/05/28 Javascript
python实现定时同步本机与北京时间的方法
2015/03/24 Python
栈和队列数据结构的基本概念及其相关的Python实现
2015/08/24 Python
Python实现的科学计算器功能示例
2017/08/04 Python
基于Python List的赋值方法
2018/06/23 Python
利用Python如何制作好玩的GIF动图详解
2018/07/11 Python
python 地图经纬度转换、纠偏的实例代码
2018/08/06 Python
Django的CVB实例详解
2020/02/10 Python
python对XML文件的操作实现代码
2020/03/27 Python
python中scrapy处理项目数据的实例分析
2020/11/22 Python
python中使用.py配置文件的方法详解
2020/11/23 Python
HTML5中的nav标签学习笔记
2016/06/24 HTML / CSS
爱尔兰灯和灯具网上商店:Lights.ie
2018/03/26 全球购物
英国Boots旗下太阳镜网站:Boots Designer Sunglasses
2018/07/07 全球购物
Ticketmaster意大利:音乐会、节日、艺术和剧院的官方门票
2019/12/23 全球购物
信号量和自旋锁的区别?如何选择使用?
2015/09/08 面试题
教师校本培训方案
2014/02/26 职场文书
给学校的建议书
2014/03/12 职场文书
我为自己代言广告词
2014/03/18 职场文书
教师党员公开承诺事项
2014/05/28 职场文书
空气的环保标语
2014/06/12 职场文书
财产分割协议书范本
2014/11/03 职场文书
田径运动会广播稿
2015/08/19 职场文书
三好学生主要事迹怎么写
2015/11/03 职场文书
2016年父亲节寄语
2015/12/04 职场文书
2016春季田径运动会广播稿
2015/12/21 职场文书