Pytorch实现神经网络的分类方式


Posted in Python onJanuary 08, 2020

本文用于利用Pytorch实现神经网络的分类!!!

1.训练神经网络分类模型

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as Data
torch.manual_seed(1)#设置随机种子,使得每次生成的随机数是确定的
BATCH_SIZE = 5#设置batch size
 
#1.制作两类数据
n_data = torch.ones( 1000,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 1000 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 1000 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#当不使用batch size训练数据时,将Tensor放入Variable中
# x,y = Variable(x), Variable(y)
#绘制训练数据
# plt.scatter( x.data.numpy()[:,0], x.data.numpy()[:,1], c=y.data.numpy())
# plt.show()
 
#当使用batch size训练数据时,首先将tensor转化为Dataset格式
torch_dataset = Data.TensorDataset(x, y)
 
#将dataset放入DataLoader中
loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size = BATCH_SIZE,#设置batch size
 shuffle=True,#打乱数据
 num_workers=2#多线程读取数据
)
 
#2.前向传播过程
class Net(torch.nn.Module):#继承基类Module的属性和方法
 def __init__(self, input, hidden, output):
  super(Net, self).__init__()#继承__init__功能
  self.hidden = torch.nn.Linear(input, hidden)#隐层的线性输出
  self.out = torch.nn.Linear(hidden, output)#输出层线性输出
 def forward(self, x):
  x = F.relu(self.hidden(x))
  x = self.out(x)
  return x
 
# 训练模型的同时保存网络模型参数
def save():
 #3.利用自定义的前向传播过程设计网络,设置各层神经元数量
 # net = Net(input=2, hidden=10, output=2)
 # print("神经网络结构:",net)
 
 #3.快速搭建神经网络模型
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 #4.设置优化算法、学习率
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2 )
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2, momentum=0.8 )
 # optimizer = torch.optim.RMSprop( net.parameters(), lr=0.2, alpha=0.9 )
 optimizer = torch.optim.Adam( net.parameters(), lr=0.2, betas=(0.9,0.99) )
 
 #5.设置损失函数
 loss_func = torch.nn.CrossEntropyLoss()
 
 plt.ion()#打开画布,可视化更新过程
 #6.迭代训练
 for epoch in range(2):
  for step, (batch_x, batch_y) in enumerate(loader):
   out = net(batch_x)#输入训练集,获得当前迭代输出值
   loss = loss_func(out, batch_y)#获得当前迭代的损失
 
   optimizer.zero_grad()#清除上次迭代的更新梯度
   loss.backward()#反向传播
   optimizer.step()#更新权重
 
   if step%200==0:
    plt.cla()#清空之前画布上的内容
    entire_out = net(x)#测试整个训练集
    #获得当前softmax层最大概率对应的索引值
    pred = torch.max(F.softmax(entire_out), 1)[1]
    #将二维压缩为一维
    pred_y = pred.data.numpy().squeeze()
    label_y = y.data.numpy()
    plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
    accuracy = sum(pred_y == label_y)/y.size()
    print("第 %d 个epoch,第 %d 次迭代,准确率为 %.2f"%(epoch+1, step/200+1, accuracy))
    #在指定位置添加文本
    plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 15, 'color': 'red'})
    plt.pause(2)#图像显示时间
 
 #7.保存模型结构和参数
 torch.save(net, 'net.pkl')
 #7.只保存模型参数
 # torch.save(net.state_dict(), 'net_param.pkl')
 
 plt.ioff()#关闭画布
 plt.show()
 
if __name__ == '__main__':
 save()

2. 读取已训练好的模型测试数据

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
#制作数据
n_data = torch.ones( 100,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 100 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 100 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#将Tensor放入Variable中
x,y = Variable(x), Variable(y)
 
#载入模型和参数
def restore_net():
 net = torch.load('net.pkl')
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
#仅载入模型参数,需要先创建网络模型
def restore_param():
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 net.load_state_dict( torch.load('net_param.pkl') )
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
 
if __name__ =='__main__':
 # restore_net()
 restore_param()

以上这篇Pytorch实现神经网络的分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Windows和Linux下使用Python访问SqlServer的方法介绍
Mar 10 Python
python执行等待程序直到第二天零点的方法
Apr 23 Python
python2.7 mayavi 安装图文教程(推荐)
Jun 22 Python
windows下安装python的C扩展编译环境(解决Unable to find vcvarsall.bat)
Feb 21 Python
使用PyInstaller将python转成可执行文件exe笔记
May 26 Python
使用python模拟命令行终端的示例
Aug 13 Python
Python csv模块使用方法代码实例
Aug 29 Python
python利用openpyxl拆分多个工作表的工作簿的方法
Sep 27 Python
Python 内置函数globals()和locals()对比详解
Dec 23 Python
Python基于Hypothesis测试库生成测试数据
Apr 29 Python
python 对一幅灰度图像进行直方图均衡化
Oct 27 Python
OpenCV灰度化之后图片为绿色的解决
Dec 01 Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 #Python
基于python3抓取pinpoint应用信息入库
Jan 08 #Python
Python PyInstaller安装和使用教程详解
Jan 08 #Python
关于Pytorch的MLP模块实现方式
Jan 07 #Python
PyTorch 普通卷积和空洞卷积实例
Jan 07 #Python
Pytorch中膨胀卷积的用法详解
Jan 07 #Python
Python urlopen()和urlretrieve()用法解析
Jan 07 #Python
You might like
php 连接mssql数据库 初学php笔记
2010/03/01 PHP
PHP过滤★等特殊符号的正则
2014/01/27 PHP
php中随机函数mt_rand()与rand()性能对比分析
2014/12/01 PHP
PHP常用处理静态操作类
2015/04/03 PHP
javascript tips提示框组件实现代码
2010/11/19 Javascript
基于jquery的动态创建表格的插件
2011/04/05 Javascript
jQuery.query.js 取参数的两点问题分析
2012/08/06 Javascript
顶部缓冲下拉菜单导航特效的JS代码
2013/08/27 Javascript
javascript获取元素离文档各边距离的方法
2015/02/13 Javascript
jQuery实现数秒后自动提交form的方法
2015/03/05 Javascript
实现无刷新联动例子汇总
2015/05/20 Javascript
Javascript中Array用法实例分析
2015/06/13 Javascript
jquery插件jquery.LightBox.js实现点击放大图片并左右点击切换效果(附demo源码下载)
2016/02/25 Javascript
js中通过getElementsByName访问name集合对象的方法
2016/10/31 Javascript
js移动焦点到最后位置的简单方法
2016/11/25 Javascript
JS正则替换去空格的方法
2017/03/24 Javascript
实例解析ES6 Proxy使用场景介绍
2018/01/08 Javascript
详解使用angular框架离线你的应用(pwa指南)
2019/01/31 Javascript
JavaScript 俄罗斯方块游戏实现方法与代码解释
2020/04/08 Javascript
小程序实现图片移动缩放效果
2020/05/26 Javascript
Vue组件为什么data必须是一个函数
2020/06/11 Javascript
python静态方法实例
2015/01/14 Python
使用pyecharts无法import Bar的解决方案
2020/04/23 Python
Python中使用haystack实现django全文检索搜索引擎功能
2017/08/26 Python
调用其他python脚本文件里面的类和方法过程解析
2019/11/15 Python
Python实现检测文件的MD5值来查找重复文件案例
2020/03/12 Python
如何基于python实现不邻接植花
2020/05/01 Python
读取nii或nii.gz文件中的信息即输出图像操作
2020/07/01 Python
基于python图书馆管理系统设计实例详解
2020/08/05 Python
俄罗斯最大的隐形眼镜销售网站:Ochkov.Net
2021/02/07 全球购物
儿子婚宴答谢词
2014/01/09 职场文书
应届本科毕业生求职信
2014/07/23 职场文书
离退休人员聘用协议书
2014/11/24 职场文书
表扬稿格式范文
2015/01/16 职场文书
机关干部作风整顿心得体会
2016/01/22 职场文书
SpringBoot接入钉钉自定义机器人预警通知
2022/07/15 Java/Android