Pytorch实现神经网络的分类方式


Posted in Python onJanuary 08, 2020

本文用于利用Pytorch实现神经网络的分类!!!

1.训练神经网络分类模型

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as Data
torch.manual_seed(1)#设置随机种子,使得每次生成的随机数是确定的
BATCH_SIZE = 5#设置batch size
 
#1.制作两类数据
n_data = torch.ones( 1000,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 1000 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 1000 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#当不使用batch size训练数据时,将Tensor放入Variable中
# x,y = Variable(x), Variable(y)
#绘制训练数据
# plt.scatter( x.data.numpy()[:,0], x.data.numpy()[:,1], c=y.data.numpy())
# plt.show()
 
#当使用batch size训练数据时,首先将tensor转化为Dataset格式
torch_dataset = Data.TensorDataset(x, y)
 
#将dataset放入DataLoader中
loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size = BATCH_SIZE,#设置batch size
 shuffle=True,#打乱数据
 num_workers=2#多线程读取数据
)
 
#2.前向传播过程
class Net(torch.nn.Module):#继承基类Module的属性和方法
 def __init__(self, input, hidden, output):
  super(Net, self).__init__()#继承__init__功能
  self.hidden = torch.nn.Linear(input, hidden)#隐层的线性输出
  self.out = torch.nn.Linear(hidden, output)#输出层线性输出
 def forward(self, x):
  x = F.relu(self.hidden(x))
  x = self.out(x)
  return x
 
# 训练模型的同时保存网络模型参数
def save():
 #3.利用自定义的前向传播过程设计网络,设置各层神经元数量
 # net = Net(input=2, hidden=10, output=2)
 # print("神经网络结构:",net)
 
 #3.快速搭建神经网络模型
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 #4.设置优化算法、学习率
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2 )
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2, momentum=0.8 )
 # optimizer = torch.optim.RMSprop( net.parameters(), lr=0.2, alpha=0.9 )
 optimizer = torch.optim.Adam( net.parameters(), lr=0.2, betas=(0.9,0.99) )
 
 #5.设置损失函数
 loss_func = torch.nn.CrossEntropyLoss()
 
 plt.ion()#打开画布,可视化更新过程
 #6.迭代训练
 for epoch in range(2):
  for step, (batch_x, batch_y) in enumerate(loader):
   out = net(batch_x)#输入训练集,获得当前迭代输出值
   loss = loss_func(out, batch_y)#获得当前迭代的损失
 
   optimizer.zero_grad()#清除上次迭代的更新梯度
   loss.backward()#反向传播
   optimizer.step()#更新权重
 
   if step%200==0:
    plt.cla()#清空之前画布上的内容
    entire_out = net(x)#测试整个训练集
    #获得当前softmax层最大概率对应的索引值
    pred = torch.max(F.softmax(entire_out), 1)[1]
    #将二维压缩为一维
    pred_y = pred.data.numpy().squeeze()
    label_y = y.data.numpy()
    plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
    accuracy = sum(pred_y == label_y)/y.size()
    print("第 %d 个epoch,第 %d 次迭代,准确率为 %.2f"%(epoch+1, step/200+1, accuracy))
    #在指定位置添加文本
    plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 15, 'color': 'red'})
    plt.pause(2)#图像显示时间
 
 #7.保存模型结构和参数
 torch.save(net, 'net.pkl')
 #7.只保存模型参数
 # torch.save(net.state_dict(), 'net_param.pkl')
 
 plt.ioff()#关闭画布
 plt.show()
 
if __name__ == '__main__':
 save()

2. 读取已训练好的模型测试数据

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
#制作数据
n_data = torch.ones( 100,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 100 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 100 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#将Tensor放入Variable中
x,y = Variable(x), Variable(y)
 
#载入模型和参数
def restore_net():
 net = torch.load('net.pkl')
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
#仅载入模型参数,需要先创建网络模型
def restore_param():
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 net.load_state_dict( torch.load('net_param.pkl') )
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
 
if __name__ =='__main__':
 # restore_net()
 restore_param()

以上这篇Pytorch实现神经网络的分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3 模拟登录v2ex实例讲解
Jul 13 Python
代码分析Python地图坐标转换
Feb 08 Python
python实现Adapter模式实例代码
Feb 09 Python
基于numpy.random.randn()与rand()的区别详解
Apr 17 Python
python实现人工智能Ai抠图功能
Sep 05 Python
python中必要的名词解释
Nov 20 Python
Python 类的魔法属性用法实例分析
Nov 21 Python
TensorFlow2.X结合OpenCV 实现手势识别功能
Apr 08 Python
Python基于Tkinter编写crc校验工具
May 06 Python
Python 如何对文件目录操作
Jul 10 Python
python爬虫用mongodb的理由
Jul 28 Python
解决python和pycharm安装gmpy2 出现ERROR的问题
Aug 28 Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 #Python
基于python3抓取pinpoint应用信息入库
Jan 08 #Python
Python PyInstaller安装和使用教程详解
Jan 08 #Python
关于Pytorch的MLP模块实现方式
Jan 07 #Python
PyTorch 普通卷积和空洞卷积实例
Jan 07 #Python
Pytorch中膨胀卷积的用法详解
Jan 07 #Python
Python urlopen()和urlretrieve()用法解析
Jan 07 #Python
You might like
php过滤敏感词的示例
2014/03/31 PHP
php设计模式之简单工厂模式详解
2014/09/04 PHP
在TP5数据库中四个字段实现无限分类的示例
2019/10/18 PHP
laravel excel 上传文件保存到本地服务器功能
2019/11/14 PHP
Javascript的IE和Firefox兼容性汇编(zz)
2007/02/02 Javascript
Ajax局部更新导致JS事件重复触发问题的解决方法
2014/10/14 Javascript
Jquery修改页面标题title其它JS失效的解决方法
2014/10/31 Javascript
jQuery中[attribute*=value]选择器用法实例
2014/12/31 Javascript
JavaScript删除数组元素的方法
2015/03/20 Javascript
javascript中sort排序实例详解
2016/07/24 Javascript
window.open打开窗口被拦截的快速解决方法
2016/08/04 Javascript
AngularJS开发教程之控制器之间的通信方法分析
2016/12/25 Javascript
vue组件实例解析
2017/01/10 Javascript
js实现鼠标左右移动,图片也跟着移动效果
2017/01/25 Javascript
vue地区选择组件教程详解
2018/05/04 Javascript
详解vue中组件参数
2018/07/09 Javascript
Node.js中Koa2在控制台输出请求日志的方法示例
2019/05/02 Javascript
vue 中使用 watch 出现了如下的报错的原因分析
2019/05/21 Javascript
[03:14]DOTA2斧王 英雄基础教程
2013/11/26 DOTA
python实现的二叉树算法和kmp算法实例
2014/04/25 Python
Python返回真假值(True or False)小技巧
2015/04/10 Python
python简单实现旋转图片的方法
2015/05/30 Python
Python的Flask框架应用调用Redis队列数据的方法
2016/06/06 Python
结合Python的SimpleHTTPServer源码来解析socket通信
2016/06/27 Python
python设置环境变量的作用和实例
2019/07/09 Python
Python利用FFT进行简单滤波的实现
2020/02/26 Python
运动鞋中的劳斯莱斯:索康尼(SAUCONY)
2017/08/09 全球购物
台湾屈臣氏网路商店:Watsons台湾
2020/12/29 全球购物
党员教师四风自我剖析材料
2014/09/30 职场文书
政风行风建设整改方案
2014/10/27 职场文书
党员反腐倡廉学习心得体会
2015/08/15 职场文书
求职自我评价参考范文
2019/05/16 职场文书
家长必看:义务教育,不得以面试 评测等名义选拔学生
2019/07/09 职场文书
Python实现byte转integer
2021/06/03 Python
SQL SERVER中的流程控制语句
2022/05/25 SQL Server
Python如何利用pandas读取csv数据并绘图
2022/07/07 Python