pytorch三层全连接层实现手写字母识别方式


Posted in Python onJanuary 14, 2020

先用最简单的三层全连接神经网络,然后添加激活层查看实验结果,最后加上批标准化验证是否有效

首先根据已有的模板定义网络结构SimpleNet,命名为net.py

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义三层全连接神经网络
class simpleNet(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):#输入维度,第一层的神经元个数、第二层的神经元个数,以及第三层的神经元个数
  super(simpleNet,self).__init__()
  self.layer1=nn.Linear(in_dim,n_hidden_1)
  self.layer2=nn.Linear(n_hidden_1,n_hidden_2)
  self.layer3=nn.Linear(n_hidden_2,out_dim)
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
 
 
#添加激活函数
class Activation_Net(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(NeutalNetwork,self).__init__()
  self.layer1=nn.Sequential(#Sequential组合结构
  nn.Linear(in_dim,n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(
  nn.Linear(n_hidden_1,n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(
  nn.Linear(n_hidden_2,out_dim))
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
#添加批标准化处理模块,皮标准化放在全连接的后面,非线性的前面
class Batch_Net(nn.Module):
 def _init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(Batch_net,self).__init__()
  self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNormld(n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNormld(n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(nn.Linear(n_hidden_2,out_dim))
 def forword(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x

训练网络,

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义一些超参数
import net
batch_size=64
learning_rate=1e-2
num_epoches=20
#预处理
data_tf=transforms.Compose(
[transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#将图像转化成tensor,然后继续标准化,就是减均值,除以方差

#读取数据集
train_dataset=datasets.MNIST(root='./data',train=True,transform=data_tf,download=True)
test_dataset=datasets.MNIST(root='./data',train=False,transform=data_tf)
#使用内置的函数导入数据集
train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False)

#导入网络,定义损失函数和优化方法
model=net.simpleNet(28*28,300,100,10)
if torch.cuda.is_available():#是否使用cuda加速
 model=model.cuda()
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=learning_rate)
import net
n_epochs=5
for epoch in range(n_epochs):
 running_loss=0.0
 running_correct=0
 print("epoch {}/{}".format(epoch,n_epochs))
 print("-"*10)
 for data in train_loader:
  img,label=data
  img=img.view(img.size(0),-1)
  if torch.cuda.is_available():
   img=img.cuda()
   label=label.cuda()
  else:
   img=Variable(img)
   label=Variable(label)
  out=model(img)#得到前向传播的结果
  loss=criterion(out,label)#得到损失函数
  print_loss=loss.data.item()
  optimizer.zero_grad()#归0梯度
  loss.backward()#反向传播
  optimizer.step()#优化
  running_loss+=loss.item()
  epoch+=1
  if epoch%50==0:
   print('epoch:{},loss:{:.4f}'.format(epoch,loss.data.item()))

训练的结果截图如下:

pytorch三层全连接层实现手写字母识别方式

测试网络

#测试网络
model.eval()#将模型变成测试模式
eval_loss=0
eval_acc=0
for data in test_loader:
 img,label=data
 img=img.view(img.size(0),-1)#测试集不需要反向传播,所以可以在前项传播的时候释放内存,节约内存空间
 if torch.cuda.is_available():
  img=Variable(img,volatile=True).cuda()
  label=Variable(label,volatile=True).cuda()
 else:
  img=Variable(img,volatile=True)
  label=Variable(label,volatile=True)
 out=model(img)
 loss=criterion(out,label)
 eval_loss+=loss.item()*label.size(0)
 _,pred=torch.max(out,1)
 num_correct=(pred==label).sum()
 eval_acc+=num_correct.item()
print('test loss:{:.6f},ac:{:.6f}'.format(eval_loss/(len(test_dataset)),eval_acc/(len(test_dataset))))

pytorch三层全连接层实现手写字母识别方式

训练的时候,还可以加入一些dropout,正则化,修改隐藏层神经元的个数,增加隐藏层数,可以自己添加。

以上这篇pytorch三层全连接层实现手写字母识别方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python下如何让web元素的生成更简单的分析
Jul 17 Python
python脚本实现xls(xlsx)转成csv
Apr 10 Python
Python实现一个Git日志统计分析的小工具
Dec 14 Python
Python构建网页爬虫原理分析
Dec 19 Python
详细解读tornado协程(coroutine)原理
Jan 15 Python
python模块之paramiko实例代码
Jan 31 Python
python使用KNN算法手写体识别
Feb 01 Python
python中logging模块的一些简单用法的使用
Feb 22 Python
pandas按行按列遍历Dataframe的几种方式
Oct 23 Python
使用IDLE的Python shell窗口实例详解
Nov 19 Python
Python3开发环境搭建详细教程
Jun 18 Python
Python如何快速找到多个字典中的公共键(key)
Apr 29 Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
基于python监控程序是否关闭
Jan 14 #Python
关于pytorch中全连接神经网络搭建两种模式详解
Jan 14 #Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
You might like
php cookie的操作实现代码(登录)
2010/12/29 PHP
PHP模板引擎Smarty中变量的使用方法示例
2016/04/11 PHP
php中访问修饰符的知识点总结
2019/01/27 PHP
Laravel5.0+框架邮件发送功能实现方法图文与实例详解
2019/04/23 PHP
AutoSave/自动存储功能实现
2007/03/24 Javascript
js 判断浏览器类型 去全角、半角空格 自动关闭当前窗口
2009/04/10 Javascript
js防止表单重复提交实现代码
2012/09/05 Javascript
如何让浏览器支持jquery ajax load 前进、后退功能
2014/06/12 Javascript
setTimeout()递归调用不加引号出错的解决方法
2014/09/05 Javascript
跟我学习javascript的this关键字
2020/05/28 Javascript
JavaScript提高网站性能优化的建议(二)
2016/07/24 Javascript
jQuery中$.ajax()方法参数解析
2016/10/22 Javascript
微信小程序 form组件详解及简单实例
2017/01/10 Javascript
bootstrap table表格使用方法详解
2017/04/26 Javascript
Angular4学习笔记之根模块与Ng模块
2017/09/09 Javascript
深入浅析Vue.js中 computed和methods不同机制
2018/03/22 Javascript
基于vue,vue-router, vuex及addRoutes进行权限控制问题
2018/05/02 Javascript
详解javascript中的变量提升和函数提升
2018/05/24 Javascript
解决Layui数据表格中checkbox位置不居中的方法
2018/08/15 Javascript
JS如何实现在弹出窗口中加载页面
2020/12/03 Javascript
Python+Wordpress制作小说站
2017/04/14 Python
Numpy数据类型转换astype,dtype的方法
2018/06/09 Python
centos 安装Python3 及对应的pip教程详解
2019/06/28 Python
Python 实现Numpy中找出array中最大值所对应的行和列
2019/11/26 Python
Django高并发负载均衡实现原理详解
2020/04/04 Python
HTML5播放实现rtmp流直播
2020/06/16 HTML / CSS
HTML5实现移动端点击翻牌功能
2020/10/23 HTML / CSS
印尼购物网站:iLOTTE
2019/10/16 全球购物
迟到检讨书大全
2014/01/25 职场文书
廉洁自律演讲稿
2014/05/22 职场文书
2015年采购员工作总结
2015/04/27 职场文书
党支部培养考察意见
2015/06/02 职场文书
春节慰问简报
2015/07/21 职场文书
公安忠诚教育心得体会
2016/01/23 职场文书
2019年感恩励志演讲稿(收藏备用)
2019/09/11 职场文书
导游词之天津古文化街
2019/11/09 职场文书