Pytorch实现神经网络的分类方式


Posted in Python onJanuary 08, 2020

本文用于利用Pytorch实现神经网络的分类!!!

1.训练神经网络分类模型

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as Data
torch.manual_seed(1)#设置随机种子,使得每次生成的随机数是确定的
BATCH_SIZE = 5#设置batch size
 
#1.制作两类数据
n_data = torch.ones( 1000,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 1000 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 1000 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#当不使用batch size训练数据时,将Tensor放入Variable中
# x,y = Variable(x), Variable(y)
#绘制训练数据
# plt.scatter( x.data.numpy()[:,0], x.data.numpy()[:,1], c=y.data.numpy())
# plt.show()
 
#当使用batch size训练数据时,首先将tensor转化为Dataset格式
torch_dataset = Data.TensorDataset(x, y)
 
#将dataset放入DataLoader中
loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size = BATCH_SIZE,#设置batch size
 shuffle=True,#打乱数据
 num_workers=2#多线程读取数据
)
 
#2.前向传播过程
class Net(torch.nn.Module):#继承基类Module的属性和方法
 def __init__(self, input, hidden, output):
  super(Net, self).__init__()#继承__init__功能
  self.hidden = torch.nn.Linear(input, hidden)#隐层的线性输出
  self.out = torch.nn.Linear(hidden, output)#输出层线性输出
 def forward(self, x):
  x = F.relu(self.hidden(x))
  x = self.out(x)
  return x
 
# 训练模型的同时保存网络模型参数
def save():
 #3.利用自定义的前向传播过程设计网络,设置各层神经元数量
 # net = Net(input=2, hidden=10, output=2)
 # print("神经网络结构:",net)
 
 #3.快速搭建神经网络模型
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 #4.设置优化算法、学习率
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2 )
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2, momentum=0.8 )
 # optimizer = torch.optim.RMSprop( net.parameters(), lr=0.2, alpha=0.9 )
 optimizer = torch.optim.Adam( net.parameters(), lr=0.2, betas=(0.9,0.99) )
 
 #5.设置损失函数
 loss_func = torch.nn.CrossEntropyLoss()
 
 plt.ion()#打开画布,可视化更新过程
 #6.迭代训练
 for epoch in range(2):
  for step, (batch_x, batch_y) in enumerate(loader):
   out = net(batch_x)#输入训练集,获得当前迭代输出值
   loss = loss_func(out, batch_y)#获得当前迭代的损失
 
   optimizer.zero_grad()#清除上次迭代的更新梯度
   loss.backward()#反向传播
   optimizer.step()#更新权重
 
   if step%200==0:
    plt.cla()#清空之前画布上的内容
    entire_out = net(x)#测试整个训练集
    #获得当前softmax层最大概率对应的索引值
    pred = torch.max(F.softmax(entire_out), 1)[1]
    #将二维压缩为一维
    pred_y = pred.data.numpy().squeeze()
    label_y = y.data.numpy()
    plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
    accuracy = sum(pred_y == label_y)/y.size()
    print("第 %d 个epoch,第 %d 次迭代,准确率为 %.2f"%(epoch+1, step/200+1, accuracy))
    #在指定位置添加文本
    plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 15, 'color': 'red'})
    plt.pause(2)#图像显示时间
 
 #7.保存模型结构和参数
 torch.save(net, 'net.pkl')
 #7.只保存模型参数
 # torch.save(net.state_dict(), 'net_param.pkl')
 
 plt.ioff()#关闭画布
 plt.show()
 
if __name__ == '__main__':
 save()

2. 读取已训练好的模型测试数据

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
#制作数据
n_data = torch.ones( 100,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 100 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 100 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#将Tensor放入Variable中
x,y = Variable(x), Variable(y)
 
#载入模型和参数
def restore_net():
 net = torch.load('net.pkl')
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
#仅载入模型参数,需要先创建网络模型
def restore_param():
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 net.load_state_dict( torch.load('net_param.pkl') )
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
 
if __name__ =='__main__':
 # restore_net()
 restore_param()

以上这篇Pytorch实现神经网络的分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中解析json格式文件的方法示例
May 03 Python
python贪婪匹配以及多行匹配的实例讲解
Apr 19 Python
python操作excel的包(openpyxl、xlsxwriter)
Jun 11 Python
python 自动重连wifi windows的方法
Dec 18 Python
python字符串和常用数据结构知识总结
May 21 Python
python挖矿算力测试程序详解
Jul 03 Python
Django调用百度AI接口实现人脸注册登录代码实例
Apr 23 Python
python logging.info在终端没输出的解决
May 12 Python
新手常见Python错误及异常解决处理方案
Jun 18 Python
利用Vscode进行Python开发环境配置的步骤
Jun 22 Python
python实现学生信息管理系统(精简版)
Nov 27 Python
Python 数据结构之十大经典排序算法一文通关
Oct 16 Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 #Python
基于python3抓取pinpoint应用信息入库
Jan 08 #Python
Python PyInstaller安装和使用教程详解
Jan 08 #Python
关于Pytorch的MLP模块实现方式
Jan 07 #Python
PyTorch 普通卷积和空洞卷积实例
Jan 07 #Python
Pytorch中膨胀卷积的用法详解
Jan 07 #Python
Python urlopen()和urlretrieve()用法解析
Jan 07 #Python
You might like
几款免费开源的不用数据库的php的cms
2010/12/19 PHP
PHP实现的交通银行网银在线支付接口ECSHOP插件和使用例子
2014/05/10 PHP
网页常用特效代码整理
2006/06/23 Javascript
IE中直接运行显示当前网页中的图片 推荐
2006/08/31 Javascript
批量实现面向对象的实例代码
2013/07/01 Javascript
在Iframe中获取父窗口中表单的值(示例代码)
2013/11/22 Javascript
Javascript实现带关闭按钮的网页漂浮广告代码
2014/01/12 Javascript
基于NodeJS的前后端分离的思考与实践(四)安全问题解决方案
2014/09/26 NodeJs
jQuery中:focus选择器用法实例
2014/12/30 Javascript
JavaScript日期类型的一些用法介绍
2015/03/02 Javascript
JavaScript控制按钮可用或不可用的方法
2015/04/03 Javascript
jQuery实现选中弹出窗口选择框内容后赋值给文本框的方法
2015/11/23 Javascript
利用iscroll4实现轮播图效果实例代码
2017/01/11 Javascript
微信小程序实现动态获取元素宽高的方法分析
2018/12/10 Javascript
微信小程序开发问题之wx.previewImage
2018/12/25 Javascript
node.js微信小程序配置消息推送的实现
2019/02/13 Javascript
express框架下使用session的方法
2019/07/31 Javascript
react quill中图片上传由默认转成base64改成上传到服务器的方法
2019/10/30 Javascript
Nodejs + Websocket 指定发送及群聊的实现
2020/01/09 NodeJs
vue列表数据发生变化指令没有更新问题及解决方法
2020/01/16 Javascript
原生js实现瀑布流效果
2020/03/09 Javascript
微信小程序scroll-view隐藏滚动条的方法详解
2020/03/25 Javascript
vscode 插件开发 + vue的操作方法
2020/06/05 Javascript
ant design vue嵌套表格及表格内部编辑的用法说明
2020/10/28 Javascript
Python求离散序列导数的示例
2019/07/10 Python
详解python中index()、find()方法
2019/08/29 Python
python打包成so文件过程解析
2019/09/28 Python
python如何爬取动态网站
2020/09/09 Python
Sentry错误日志监控使用方法解析
2020/11/12 Python
Python3.8.2安装包及安装教程图文详解(附安装包)
2020/11/28 Python
香港中原电器网上商店:Chung Yuen
2019/06/26 全球购物
平面设计师工作职责范文
2013/12/03 职场文书
人力资源本科毕业生求职信
2014/06/04 职场文书
Redis5之后版本的高可用集群搭建的实现
2021/04/27 Redis
如何用Navicat操作MySQL
2021/05/12 MySQL
nginx搭建NFS网络文件系统
2022/04/14 Servers