Pytorch实现神经网络的分类方式


Posted in Python onJanuary 08, 2020

本文用于利用Pytorch实现神经网络的分类!!!

1.训练神经网络分类模型

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as Data
torch.manual_seed(1)#设置随机种子,使得每次生成的随机数是确定的
BATCH_SIZE = 5#设置batch size
 
#1.制作两类数据
n_data = torch.ones( 1000,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 1000 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 1000 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#当不使用batch size训练数据时,将Tensor放入Variable中
# x,y = Variable(x), Variable(y)
#绘制训练数据
# plt.scatter( x.data.numpy()[:,0], x.data.numpy()[:,1], c=y.data.numpy())
# plt.show()
 
#当使用batch size训练数据时,首先将tensor转化为Dataset格式
torch_dataset = Data.TensorDataset(x, y)
 
#将dataset放入DataLoader中
loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size = BATCH_SIZE,#设置batch size
 shuffle=True,#打乱数据
 num_workers=2#多线程读取数据
)
 
#2.前向传播过程
class Net(torch.nn.Module):#继承基类Module的属性和方法
 def __init__(self, input, hidden, output):
  super(Net, self).__init__()#继承__init__功能
  self.hidden = torch.nn.Linear(input, hidden)#隐层的线性输出
  self.out = torch.nn.Linear(hidden, output)#输出层线性输出
 def forward(self, x):
  x = F.relu(self.hidden(x))
  x = self.out(x)
  return x
 
# 训练模型的同时保存网络模型参数
def save():
 #3.利用自定义的前向传播过程设计网络,设置各层神经元数量
 # net = Net(input=2, hidden=10, output=2)
 # print("神经网络结构:",net)
 
 #3.快速搭建神经网络模型
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 #4.设置优化算法、学习率
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2 )
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2, momentum=0.8 )
 # optimizer = torch.optim.RMSprop( net.parameters(), lr=0.2, alpha=0.9 )
 optimizer = torch.optim.Adam( net.parameters(), lr=0.2, betas=(0.9,0.99) )
 
 #5.设置损失函数
 loss_func = torch.nn.CrossEntropyLoss()
 
 plt.ion()#打开画布,可视化更新过程
 #6.迭代训练
 for epoch in range(2):
  for step, (batch_x, batch_y) in enumerate(loader):
   out = net(batch_x)#输入训练集,获得当前迭代输出值
   loss = loss_func(out, batch_y)#获得当前迭代的损失
 
   optimizer.zero_grad()#清除上次迭代的更新梯度
   loss.backward()#反向传播
   optimizer.step()#更新权重
 
   if step%200==0:
    plt.cla()#清空之前画布上的内容
    entire_out = net(x)#测试整个训练集
    #获得当前softmax层最大概率对应的索引值
    pred = torch.max(F.softmax(entire_out), 1)[1]
    #将二维压缩为一维
    pred_y = pred.data.numpy().squeeze()
    label_y = y.data.numpy()
    plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
    accuracy = sum(pred_y == label_y)/y.size()
    print("第 %d 个epoch,第 %d 次迭代,准确率为 %.2f"%(epoch+1, step/200+1, accuracy))
    #在指定位置添加文本
    plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 15, 'color': 'red'})
    plt.pause(2)#图像显示时间
 
 #7.保存模型结构和参数
 torch.save(net, 'net.pkl')
 #7.只保存模型参数
 # torch.save(net.state_dict(), 'net_param.pkl')
 
 plt.ioff()#关闭画布
 plt.show()
 
if __name__ == '__main__':
 save()

2. 读取已训练好的模型测试数据

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
#制作数据
n_data = torch.ones( 100,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 100 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 100 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#将Tensor放入Variable中
x,y = Variable(x), Variable(y)
 
#载入模型和参数
def restore_net():
 net = torch.load('net.pkl')
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
#仅载入模型参数,需要先创建网络模型
def restore_param():
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 net.load_state_dict( torch.load('net_param.pkl') )
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
 
if __name__ =='__main__':
 # restore_net()
 restore_param()

以上这篇Pytorch实现神经网络的分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python抓取网页中的图片示例
Feb 28 Python
利用python实现简单的循环购物车功能示例代码
Jul 05 Python
Python基于hashlib模块的文件MD5一致性加密验证示例
Feb 10 Python
python 实现一次性在文件中写入多行的方法
Jan 28 Python
python Django编写接口并用Jmeter测试的方法
Jul 31 Python
Python使用itchat模块实现简单的微信控制电脑功能示例
Aug 26 Python
利用Python产生加密表和解密表的实现方法
Oct 15 Python
tensorflow的计算图总结
Jan 12 Python
python使用hdfs3模块对hdfs进行操作详解
Jun 06 Python
Python数据可视化实现多种图例代码详解
Jul 14 Python
属性与 @property 方法让你的python更高效
Sep 21 Python
pycharm中选中一个单词替换所有重复单词的实现方法
Nov 17 Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 #Python
基于python3抓取pinpoint应用信息入库
Jan 08 #Python
Python PyInstaller安装和使用教程详解
Jan 08 #Python
关于Pytorch的MLP模块实现方式
Jan 07 #Python
PyTorch 普通卷积和空洞卷积实例
Jan 07 #Python
Pytorch中膨胀卷积的用法详解
Jan 07 #Python
Python urlopen()和urlretrieve()用法解析
Jan 07 #Python
You might like
PHP获取163、gmail、126等邮箱联系人地址【已测试2009.10.10】
2009/10/11 PHP
php中强制下载文件的代码(解决了IE下中文文件名乱码问题)
2011/05/09 PHP
使用配置类定义Codeigniter全局变量
2014/06/12 PHP
javascript跨域刷新实现代码
2011/01/01 Javascript
EasyUI中的tree用法介绍
2011/11/01 Javascript
5个javascript的数字格式化函数分享
2011/12/07 Javascript
Js实现双击鼠标自动滚动屏幕的示例代码
2013/12/14 Javascript
javascript替换已有元素replaceChild()使用介绍
2014/04/03 Javascript
jquery实现select选中行、列合计示例
2014/04/25 Javascript
IE中JS跳转丢失referrer问题的2个解决方法
2014/07/18 Javascript
详解Windows下安装Nodejs步骤
2017/05/18 NodeJs
基于ES6 Array.of的用法(实例讲解)
2017/09/05 Javascript
使用Dropzone.js上传的示例代码
2017/10/10 Javascript
vue.js element-ui validate中代码不执行问题解决方法
2017/12/18 Javascript
微信小程序当前时间时段选择器插件使用方法详解
2018/12/28 Javascript
JavaScript从原型到原型链深入理解
2019/06/03 Javascript
JS前端知识点总结之页面加载事件,数组操作,DOM节点操作,循环和分支
2019/07/04 Javascript
Javascript节流函数throttle和防抖函数debounce
2020/12/03 Javascript
python基础教程之popen函数操作其它程序的输入和输出示例
2014/02/10 Python
python中pygame模块用法实例
2014/10/09 Python
理解python正则表达式
2016/01/15 Python
以一个投票程序的实例来讲解Python的Django框架使用
2016/02/18 Python
Python中模块pymysql查询结果后如何获取字段列表
2017/06/05 Python
关于Kotlin中SAM转换的那些事
2020/09/15 Python
简单介绍CSS3中Media Query的使用
2015/07/07 HTML / CSS
怎样比较两个类型为String的字符串
2016/08/17 面试题
医学专业本科毕业生自我鉴定
2013/12/28 职场文书
活动邀请函范文
2014/01/19 职场文书
自荐信格式简述
2014/01/25 职场文书
大学校运会广播稿
2014/02/03 职场文书
最经典的大学生职业生涯规划范文
2014/03/05 职场文书
调解协议书
2014/04/16 职场文书
2014年酒店年度工作总结
2014/12/10 职场文书
2014年学生资助工作总结
2014/12/18 职场文书
小学校本教研总结
2015/08/13 职场文书
关于python爬虫应用urllib库作用分析
2021/09/04 Python