利用Pytorch实现简单的线性回归算法


Posted in Python onJanuary 15, 2020

最近听了张江老师的深度学习课程,用Pytorch实现神经网络预测,之前做Titanic生存率预测的时候稍微了解过Tensorflow,听说Tensorflow能做的Pyorch都可以做,而且更方便快捷,自己尝试了一下代码的逻辑确实比较简单。

Pytorch涉及的基本数据类型是tensor(张量)和Autograd(自动微分变量),对于这些概念我也是一知半解,tensor和向量,矩阵等概念都有交叉的部分,下次有时间好好补一下数学的基础知识,不过现阶段的任务主要是应用,学习掌握思维和方法即可,就不再深究了。tensor和ndarray可以相互转换,python的numpy库中的命令也都基本适用。

一些基本的代码:

import torch #导入torch包
x = torch.rand(5, 3) #产生一个5*3的tensor,在 [0,1) 之间随机取值
y = torch.ones(5, 3) #产生一个5*3的Tensor,元素都是1 
#和numpy的命令一致

#tensor的运算
z = x + y #两个tensor可以直接相加
q = x.mm(y.transpose(0, 1)) #x乘以y的转置 mm为矩阵的乘法,矩阵相乘必须某一个矩阵的行与另一个矩阵的列相等

##Tensor与numpy.ndarray之间的转换
import numpy as np  #导入numpy包
a = np.ones([5, 3])  #建立一个5*3全是1的二维数组(矩阵)
b = torch.from_numpy(a)  #利用from_numpy将其转换为tensor
c = torch.FloatTensor(a) #另外一种转换为tensor的方法,类型为FloatTensor,还可以使LongTensor,整型数据类型
b.numpy() #从一个tensor转化为numpy的多维数组

from torch.autograd import Variable #导入自动梯度的运算包,主要用Variable这个类
x = Variable(torch.ones(2, 2), requires_grad=True)  #创建一个Variable,包裹了一个2*2张量,将需要计算梯度属性置为True

下面用pytorch做一个简单的线性关系预测

线性关系是一种非常简单的变量之间的关系,因变量和自变量在线性关系的情况下,可以使用线性回归算法对一个或多个因变量和自变量间的线性关系进行建模,该模型的系数可以用最小二乘法进行求解。生活中的场景往往会比较复杂,需要考虑多元线性关系和非线性关系,用其他的回归分析方法求解。

这里po上代码:

#人为生成一些样本点作为原始数据
x = Variable(torch.linspace(0, 100).type(torch.FloatTensor)) 
rand = Variable(torch.randn(100)) * 10 #随机生成100个满足标准正态分布的随机数,均值为0,方差为1.将这个数字乘以10,标准方差变为10
y = x + rand #将x和rand相加,得到伪造的标签数据y。所以(x,y)应能近似地落在y=x这条直线上

import matplotlib.pyplot as plt #导入画图的程序包
plt.figure(figsize=(10,8)) #设定绘制窗口大小为10*8 inch
plt.plot(x.data.numpy(), y.data.numpy(), 'o') #绘制数据,考虑到x和y都是Variable,需要用data获取它们包裹的Tensor,并专成numpy
plt.xlabel('X') 
plt.ylabel('Y') 
plt.show()

图示:

利用Pytorch实现简单的线性回归算法

训练模型:

#a,b就是要构建的线性函数的系数
a = Variable(torch.rand(1), requires_grad = True) #创建a变量,并随机赋值初始化
b = Variable(torch.rand(1), requires_grad = True) #创建b变量,并随机赋值初始化
print('Initial parameters:', [a, b])

learning_rate = 0.0001 #设置学习率
for i in range(1000):
  ### 增加了这部分代码,清空存储在变量a,b中的梯度信息,以免在backward的过程中会反复不停地累加
  if (a.grad is not None) and (b.grad is not None): 
    a.grad.data.zero_() 
    b.grad.data.zero_() 
  predictions = a.expand_as(x) * x+ b.expand_as(x) #计算在当前a、b条件下的模型预测数值
  loss = torch.mean((predictions - y) ** 2) #通过与标签数据y比较,计算误差
  print('loss:', loss)

  loss.backward() #对损失函数进行梯度反传,backward的方向传播算法
  a.data.add_(- learning_rate * a.grad.data) #利用上一步计算中得到的a的梯度信息更新a中的data数值
  b.data.add_(- learning_rate * b.grad.data) #利用上一步计算中得到的b的梯度信息更新b中的data数值

##拟合
x_data = x.data.numpy() 
plt.figure(figsize = (10, 7))
xplot = plt.plot(x_data, y.data.numpy(), 'o') # 绘制原始数据
yplot = plt.plot(x_data, a.data.numpy() * x_data + b.data.numpy()) #绘制拟合数据
plt.xlabel('X') 
plt.ylabel('Y') 
str1 = str(a.data.numpy()[0]) + 'x +' + str(b.data.numpy()[0]) #图例信息
plt.legend([xplot, yplot],['Data', str1]) #绘制图例
plt.show()

图示:

利用Pytorch实现简单的线性回归算法

测试:

x_test = Variable(torch.FloatTensor([1, 2, 10, 100, 1000])) #随便选择一些点1,2,……,1000
predictions = a.expand_as(x_test) * x_test + b.expand_as(x_test) #计算模型的预测结果
predictions #输出

ok,大功告成,可以看到用pytorch做机器学习确实无论是准确度还是方便性都有优势,继续探索学习。

以上这篇利用Pytorch实现简单的线性回归算法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python数据结构之二叉树的遍历实例
Apr 29 Python
使用python实现生成用户信息
Mar 20 Python
Python微信企业号开发之回调模式接收微信端客户端发送消息及被动返回消息示例
Aug 21 Python
解决python使用open打开文件中文乱码的问题
Dec 29 Python
python3实现磁盘空间监控
Jun 21 Python
python调用百度语音识别api
Aug 30 Python
python实现Zabbix-API监控
Sep 17 Python
Python 微信之获取好友昵称并制作wordcloud的实例
Feb 21 Python
Python Pickle 实现在同一个文件中序列化多个对象
Dec 30 Python
基于python检查矩阵计算结果
May 21 Python
关于tf.matmul() 和tf.multiply() 的区别说明
Jun 18 Python
python基础之//、/与%的区别详解
Jun 10 Python
pytorch实现线性拟合方式
Jan 15 #Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
You might like
php park、unpark、ord 函数使用方法(二进制流接口应用实例)
2010/10/19 PHP
php实现window平台的checkdnsrr函数
2015/05/27 PHP
php倒计时出现-0情况的解决方法
2016/07/28 PHP
PHP用FTP类上传文件视频等的简单实现方法
2016/09/23 PHP
PHP随机数函数rand()与mt_rand()的讲解
2019/03/25 PHP
js正确获取元素样式详解
2009/08/07 Javascript
用 Javascript 验证表单(form)中的单选(radio)值
2009/09/08 Javascript
基于jQuery实现的图片切换焦点图整理
2014/12/07 Javascript
利用jQuery和CSS将背景图片拉伸
2015/10/16 Javascript
javascript实现很浪漫的气泡冒出特效
2020/09/05 Javascript
Google 地图类型详解及示例代码
2016/08/06 Javascript
深入理解jQuery layui分页控件的使用
2016/08/17 Javascript
基于JS代码实现简单易用的倒计时 x 天 x 时 x 分 x 秒效果
2017/07/13 Javascript
微信小程序实现页面跳转传值的方法
2017/10/12 Javascript
详解使用Next.js构建服务端渲染应用
2018/07/10 Javascript
微信小程序中使用wxss加载图片并实现动画效果
2018/08/13 Javascript
Vue2.x Todo之自定义指令实现自动聚焦的方法
2019/01/08 Javascript
WebGL学习教程之Three.js学习笔记(第一篇)
2019/04/25 Javascript
javascript实现简易数码时钟
2020/03/30 Javascript
微信小程序报错: thirdScriptError的错误问题
2020/06/19 Javascript
Vue+Bootstrap实现简易学生管理系统
2021/02/09 Vue.js
[03:11]不朽宝藏三外观展示
2020/09/18 DOTA
python爬虫实战之爬取京东商城实例教程
2017/04/24 Python
Windows下PyCharm安装图文教程
2018/08/27 Python
python时间序列按频率生成日期的方法
2019/05/14 Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
2020/01/20 Python
在django中使用post方法时,需要增加csrftoken的例子
2020/03/13 Python
HTML5和以前HTML4的区别整理
2013/10/20 HTML / CSS
AmazeUI 输入框组的示例代码
2020/08/14 HTML / CSS
FOREO斐珞尔官方旗舰店:LUNA露娜洁面仪
2018/03/11 全球购物
吉尔德利巧克力公司:Ghirardelli Chocolate Company
2019/03/27 全球购物
中专生的个人自我评价
2013/12/11 职场文书
2015年禁毒宣传活动总结
2015/03/25 职场文书
2015年党务工作者个人工作总结
2015/10/22 职场文书
python 标准库原理与用法详解之os.path篇
2021/10/24 Python
漫画「你在春天醒来」第10卷封面公开
2022/03/21 日漫