Pytorch实现神经网络的分类方式


Posted in Python onJanuary 08, 2020

本文用于利用Pytorch实现神经网络的分类!!!

1.训练神经网络分类模型

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as Data
torch.manual_seed(1)#设置随机种子,使得每次生成的随机数是确定的
BATCH_SIZE = 5#设置batch size
 
#1.制作两类数据
n_data = torch.ones( 1000,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 1000 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 1000 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#当不使用batch size训练数据时,将Tensor放入Variable中
# x,y = Variable(x), Variable(y)
#绘制训练数据
# plt.scatter( x.data.numpy()[:,0], x.data.numpy()[:,1], c=y.data.numpy())
# plt.show()
 
#当使用batch size训练数据时,首先将tensor转化为Dataset格式
torch_dataset = Data.TensorDataset(x, y)
 
#将dataset放入DataLoader中
loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size = BATCH_SIZE,#设置batch size
 shuffle=True,#打乱数据
 num_workers=2#多线程读取数据
)
 
#2.前向传播过程
class Net(torch.nn.Module):#继承基类Module的属性和方法
 def __init__(self, input, hidden, output):
  super(Net, self).__init__()#继承__init__功能
  self.hidden = torch.nn.Linear(input, hidden)#隐层的线性输出
  self.out = torch.nn.Linear(hidden, output)#输出层线性输出
 def forward(self, x):
  x = F.relu(self.hidden(x))
  x = self.out(x)
  return x
 
# 训练模型的同时保存网络模型参数
def save():
 #3.利用自定义的前向传播过程设计网络,设置各层神经元数量
 # net = Net(input=2, hidden=10, output=2)
 # print("神经网络结构:",net)
 
 #3.快速搭建神经网络模型
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 #4.设置优化算法、学习率
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2 )
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2, momentum=0.8 )
 # optimizer = torch.optim.RMSprop( net.parameters(), lr=0.2, alpha=0.9 )
 optimizer = torch.optim.Adam( net.parameters(), lr=0.2, betas=(0.9,0.99) )
 
 #5.设置损失函数
 loss_func = torch.nn.CrossEntropyLoss()
 
 plt.ion()#打开画布,可视化更新过程
 #6.迭代训练
 for epoch in range(2):
  for step, (batch_x, batch_y) in enumerate(loader):
   out = net(batch_x)#输入训练集,获得当前迭代输出值
   loss = loss_func(out, batch_y)#获得当前迭代的损失
 
   optimizer.zero_grad()#清除上次迭代的更新梯度
   loss.backward()#反向传播
   optimizer.step()#更新权重
 
   if step%200==0:
    plt.cla()#清空之前画布上的内容
    entire_out = net(x)#测试整个训练集
    #获得当前softmax层最大概率对应的索引值
    pred = torch.max(F.softmax(entire_out), 1)[1]
    #将二维压缩为一维
    pred_y = pred.data.numpy().squeeze()
    label_y = y.data.numpy()
    plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
    accuracy = sum(pred_y == label_y)/y.size()
    print("第 %d 个epoch,第 %d 次迭代,准确率为 %.2f"%(epoch+1, step/200+1, accuracy))
    #在指定位置添加文本
    plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 15, 'color': 'red'})
    plt.pause(2)#图像显示时间
 
 #7.保存模型结构和参数
 torch.save(net, 'net.pkl')
 #7.只保存模型参数
 # torch.save(net.state_dict(), 'net_param.pkl')
 
 plt.ioff()#关闭画布
 plt.show()
 
if __name__ == '__main__':
 save()

2. 读取已训练好的模型测试数据

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
#制作数据
n_data = torch.ones( 100,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 100 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 100 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#将Tensor放入Variable中
x,y = Variable(x), Variable(y)
 
#载入模型和参数
def restore_net():
 net = torch.load('net.pkl')
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
#仅载入模型参数,需要先创建网络模型
def restore_param():
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 net.load_state_dict( torch.load('net_param.pkl') )
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
 
if __name__ =='__main__':
 # restore_net()
 restore_param()

以上这篇Pytorch实现神经网络的分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3.x和Python2.x的区别介绍
Feb 12 Python
将图片文件嵌入到wxpython代码中的实现方法
Aug 11 Python
python 全文检索引擎详解
Apr 25 Python
Python学习小技巧之列表项的推导式与过滤操作
May 20 Python
python实现ID3决策树算法
Dec 20 Python
pycharm+django创建一个搜索网页实例代码
Jan 24 Python
Python RabbitMQ消息队列实现rpc
May 30 Python
Python3随机漫步生成数据并绘制
Aug 27 Python
Numpy 中的矩阵求逆实例
Aug 26 Python
Python requests获取网页常用方法解析
Feb 20 Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 Python
python析构函数用法及注意事项
Jun 22 Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 #Python
基于python3抓取pinpoint应用信息入库
Jan 08 #Python
Python PyInstaller安装和使用教程详解
Jan 08 #Python
关于Pytorch的MLP模块实现方式
Jan 07 #Python
PyTorch 普通卷积和空洞卷积实例
Jan 07 #Python
Pytorch中膨胀卷积的用法详解
Jan 07 #Python
Python urlopen()和urlretrieve()用法解析
Jan 07 #Python
You might like
php自定义session示例分享
2014/04/22 PHP
php把数据表导出为Excel表的最简单、最快的方法(不用插件)
2014/05/10 PHP
ThinkPHP使用心得分享-上传类UploadFile的使用
2014/05/15 PHP
使用配置类定义Codeigniter全局变量
2014/06/12 PHP
php简单实现发送带附件的邮件
2015/06/10 PHP
PHP中模拟链表和链表的基本操作示例
2016/02/27 PHP
PHPCMS忘记后台密码的解决办法
2016/10/30 PHP
PHP对象相关知识总结
2017/04/09 PHP
PHP在同一域名下两个不同的项目做独立登录机制详解
2017/09/22 PHP
php中file_get_contents()函数用法实例
2019/02/21 PHP
javascript中利用数组实现的循环队列代码
2010/01/24 Javascript
JS、DOM和JQuery之间的关系示例分析
2014/04/09 Javascript
js中数组排序sort方法的原理分析
2014/11/20 Javascript
浅析JavaScript Array和string的转换(推荐)
2016/05/20 Javascript
JavaScript函数节流概念与用法实例详解
2016/06/20 Javascript
简单三步实现报表页面集成天气
2016/12/15 Javascript
JS基于正则截取替换特定字符之间字符串操作示例
2017/02/03 Javascript
如何使用bootstrap框架 bootstrap入门必看!
2017/04/13 Javascript
Vue 兄弟组件通信的方法(不使用Vuex)
2017/10/26 Javascript
jquery实现回车键触发事件(实例讲解)
2017/11/21 jQuery
详解如何理解vue的key属性
2019/04/14 Javascript
微信域名检测接口调用演示步骤(含PHP、Python)
2019/12/08 Javascript
解决Vue使用bus总线时,第一次路由跳转时数据没成功传递问题
2020/07/28 Javascript
Vue的Options用法说明
2020/08/14 Javascript
基于react项目打包css引用路径错误解决方案
2020/10/28 Javascript
java直接调用python脚本的例子
2014/02/16 Python
pytorch中tensor的合并与截取方法
2018/07/26 Python
Python使用装饰器模拟用户登陆验证功能示例
2018/08/24 Python
Python: glob匹配文件的操作
2020/12/11 Python
有关HTML5中背景音乐的自动播放功能
2017/10/16 HTML / CSS
英国领先的葡萄酒专家:Majestic Wine
2017/05/30 全球购物
市场营销专科应届生求职信
2013/11/24 职场文书
办公室主任竞聘演讲稿
2014/05/15 职场文书
幼儿园大班开学寄语
2014/08/02 职场文书
秋季运动会开幕词
2015/01/28 职场文书
优秀乡村医生事迹材料(2016精选版)
2016/02/29 职场文书