利用Pytorch实现简单的线性回归算法


Posted in Python onJanuary 15, 2020

最近听了张江老师的深度学习课程,用Pytorch实现神经网络预测,之前做Titanic生存率预测的时候稍微了解过Tensorflow,听说Tensorflow能做的Pyorch都可以做,而且更方便快捷,自己尝试了一下代码的逻辑确实比较简单。

Pytorch涉及的基本数据类型是tensor(张量)和Autograd(自动微分变量),对于这些概念我也是一知半解,tensor和向量,矩阵等概念都有交叉的部分,下次有时间好好补一下数学的基础知识,不过现阶段的任务主要是应用,学习掌握思维和方法即可,就不再深究了。tensor和ndarray可以相互转换,python的numpy库中的命令也都基本适用。

一些基本的代码:

import torch #导入torch包
x = torch.rand(5, 3) #产生一个5*3的tensor,在 [0,1) 之间随机取值
y = torch.ones(5, 3) #产生一个5*3的Tensor,元素都是1 
#和numpy的命令一致

#tensor的运算
z = x + y #两个tensor可以直接相加
q = x.mm(y.transpose(0, 1)) #x乘以y的转置 mm为矩阵的乘法,矩阵相乘必须某一个矩阵的行与另一个矩阵的列相等

##Tensor与numpy.ndarray之间的转换
import numpy as np  #导入numpy包
a = np.ones([5, 3])  #建立一个5*3全是1的二维数组(矩阵)
b = torch.from_numpy(a)  #利用from_numpy将其转换为tensor
c = torch.FloatTensor(a) #另外一种转换为tensor的方法,类型为FloatTensor,还可以使LongTensor,整型数据类型
b.numpy() #从一个tensor转化为numpy的多维数组

from torch.autograd import Variable #导入自动梯度的运算包,主要用Variable这个类
x = Variable(torch.ones(2, 2), requires_grad=True)  #创建一个Variable,包裹了一个2*2张量,将需要计算梯度属性置为True

下面用pytorch做一个简单的线性关系预测

线性关系是一种非常简单的变量之间的关系,因变量和自变量在线性关系的情况下,可以使用线性回归算法对一个或多个因变量和自变量间的线性关系进行建模,该模型的系数可以用最小二乘法进行求解。生活中的场景往往会比较复杂,需要考虑多元线性关系和非线性关系,用其他的回归分析方法求解。

这里po上代码:

#人为生成一些样本点作为原始数据
x = Variable(torch.linspace(0, 100).type(torch.FloatTensor)) 
rand = Variable(torch.randn(100)) * 10 #随机生成100个满足标准正态分布的随机数,均值为0,方差为1.将这个数字乘以10,标准方差变为10
y = x + rand #将x和rand相加,得到伪造的标签数据y。所以(x,y)应能近似地落在y=x这条直线上

import matplotlib.pyplot as plt #导入画图的程序包
plt.figure(figsize=(10,8)) #设定绘制窗口大小为10*8 inch
plt.plot(x.data.numpy(), y.data.numpy(), 'o') #绘制数据,考虑到x和y都是Variable,需要用data获取它们包裹的Tensor,并专成numpy
plt.xlabel('X') 
plt.ylabel('Y') 
plt.show()

图示:

利用Pytorch实现简单的线性回归算法

训练模型:

#a,b就是要构建的线性函数的系数
a = Variable(torch.rand(1), requires_grad = True) #创建a变量,并随机赋值初始化
b = Variable(torch.rand(1), requires_grad = True) #创建b变量,并随机赋值初始化
print('Initial parameters:', [a, b])

learning_rate = 0.0001 #设置学习率
for i in range(1000):
  ### 增加了这部分代码,清空存储在变量a,b中的梯度信息,以免在backward的过程中会反复不停地累加
  if (a.grad is not None) and (b.grad is not None): 
    a.grad.data.zero_() 
    b.grad.data.zero_() 
  predictions = a.expand_as(x) * x+ b.expand_as(x) #计算在当前a、b条件下的模型预测数值
  loss = torch.mean((predictions - y) ** 2) #通过与标签数据y比较,计算误差
  print('loss:', loss)

  loss.backward() #对损失函数进行梯度反传,backward的方向传播算法
  a.data.add_(- learning_rate * a.grad.data) #利用上一步计算中得到的a的梯度信息更新a中的data数值
  b.data.add_(- learning_rate * b.grad.data) #利用上一步计算中得到的b的梯度信息更新b中的data数值

##拟合
x_data = x.data.numpy() 
plt.figure(figsize = (10, 7))
xplot = plt.plot(x_data, y.data.numpy(), 'o') # 绘制原始数据
yplot = plt.plot(x_data, a.data.numpy() * x_data + b.data.numpy()) #绘制拟合数据
plt.xlabel('X') 
plt.ylabel('Y') 
str1 = str(a.data.numpy()[0]) + 'x +' + str(b.data.numpy()[0]) #图例信息
plt.legend([xplot, yplot],['Data', str1]) #绘制图例
plt.show()

图示:

利用Pytorch实现简单的线性回归算法

测试:

x_test = Variable(torch.FloatTensor([1, 2, 10, 100, 1000])) #随便选择一些点1,2,……,1000
predictions = a.expand_as(x_test) * x_test + b.expand_as(x_test) #计算模型的预测结果
predictions #输出

ok,大功告成,可以看到用pytorch做机器学习确实无论是准确度还是方便性都有优势,继续探索学习。

以上这篇利用Pytorch实现简单的线性回归算法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
videocapture库制作python视频高速传输程序
Dec 23 Python
python的urllib模块显示下载进度示例
Jan 17 Python
python计数排序和基数排序算法实例
Apr 25 Python
Python中AND、OR的一个使用小技巧
Feb 18 Python
浅谈python为什么不需要三目运算符和switch
Jun 17 Python
TF-IDF算法解析与Python实现方法详解
Nov 16 Python
详解python上传文件和字符到PHP服务器
Nov 24 Python
python Django编写接口并用Jmeter测试的方法
Jul 31 Python
pytorch查看torch.Tensor和model是否在CUDA上的实例
Jan 03 Python
浅谈Python访问MySQL的正确姿势
Jan 07 Python
Python 字典中的所有方法及用法
Jun 10 Python
scrapy-splash简单使用详解
Feb 21 Python
pytorch实现线性拟合方式
Jan 15 #Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
You might like
PHP 面向对象实现代码
2009/11/11 PHP
深入PHP autoload机制的详解
2013/06/09 PHP
解析thinkphp中的M()与D()方法的区别
2013/06/22 PHP
解析PHPExcel使用的常用说明以及把PHPExcel整合进CI框架的介绍
2013/06/24 PHP
php实现在线通讯录功能(附源码)
2016/05/13 PHP
Yii2中cookie用法示例分析
2016/07/18 PHP
Thinkphp实现短信验证注册功能
2016/10/18 PHP
php实现微信公众号企业转账功能
2018/10/01 PHP
PHP实现字符串的全排列详解
2019/04/24 PHP
javascript算法学习(直接插入排序)
2011/04/12 Javascript
js中判断文本框是否为空的两种方法
2011/07/31 Javascript
Javascript 鼠标移动上去 滑块跟随效果代码分享
2013/11/23 Javascript
JS,Jquery获取select,dropdownlist,checkbox下拉列表框的值(示例代码)
2014/01/11 Javascript
jQuery操作元素css样式的三种方法
2014/06/04 Javascript
只需五句话搞定JavaScript作用域(经典)
2016/07/26 Javascript
用jquery获取select标签中选中的option值及文本的示例
2018/01/25 jQuery
Javasript设计模式之链式调用详解
2018/04/26 Javascript
js 数组详细操作方法及解析合集
2018/06/01 Javascript
微信小程序全局变量功能与用法详解
2019/01/22 Javascript
jQuery控制input只能输入数字和两位小数的方法
2019/05/16 jQuery
深入浅出了解Node.js Streams
2019/05/27 Javascript
javascript数组元素删除方法delete和splice解析
2019/12/09 Javascript
Vue-axios-post数据后端接不到问题解决
2020/01/09 Javascript
微信小程序实现上传多张图片、删除图片
2020/07/29 Javascript
Python中列表和元组的相关语句和方法讲解
2015/08/20 Python
python中星号变量的几种特殊用法
2016/09/07 Python
tensorflow创建变量以及根据名称查找变量
2018/03/10 Python
在python中pandas的series合并方法
2018/11/12 Python
TensorFlow2.0矩阵与向量的加减乘实例
2020/02/07 Python
Pycharm Available Package无法显示/安装包的问题Error Loading Package List解决
2020/09/18 Python
Python3爬虫RedisDump的安装步骤
2021/02/20 Python
读书活动总结
2014/04/28 职场文书
活动总结报告格式
2014/05/09 职场文书
美容院合作经营协议书
2014/10/10 职场文书
第二批党的群众路线教育实践活动个人整改方案
2014/10/31 职场文书
SQL实现LeetCode(176.第二高薪水)
2021/08/04 MySQL