利用Pytorch实现简单的线性回归算法


Posted in Python onJanuary 15, 2020

最近听了张江老师的深度学习课程,用Pytorch实现神经网络预测,之前做Titanic生存率预测的时候稍微了解过Tensorflow,听说Tensorflow能做的Pyorch都可以做,而且更方便快捷,自己尝试了一下代码的逻辑确实比较简单。

Pytorch涉及的基本数据类型是tensor(张量)和Autograd(自动微分变量),对于这些概念我也是一知半解,tensor和向量,矩阵等概念都有交叉的部分,下次有时间好好补一下数学的基础知识,不过现阶段的任务主要是应用,学习掌握思维和方法即可,就不再深究了。tensor和ndarray可以相互转换,python的numpy库中的命令也都基本适用。

一些基本的代码:

import torch #导入torch包
x = torch.rand(5, 3) #产生一个5*3的tensor,在 [0,1) 之间随机取值
y = torch.ones(5, 3) #产生一个5*3的Tensor,元素都是1 
#和numpy的命令一致

#tensor的运算
z = x + y #两个tensor可以直接相加
q = x.mm(y.transpose(0, 1)) #x乘以y的转置 mm为矩阵的乘法,矩阵相乘必须某一个矩阵的行与另一个矩阵的列相等

##Tensor与numpy.ndarray之间的转换
import numpy as np  #导入numpy包
a = np.ones([5, 3])  #建立一个5*3全是1的二维数组(矩阵)
b = torch.from_numpy(a)  #利用from_numpy将其转换为tensor
c = torch.FloatTensor(a) #另外一种转换为tensor的方法,类型为FloatTensor,还可以使LongTensor,整型数据类型
b.numpy() #从一个tensor转化为numpy的多维数组

from torch.autograd import Variable #导入自动梯度的运算包,主要用Variable这个类
x = Variable(torch.ones(2, 2), requires_grad=True)  #创建一个Variable,包裹了一个2*2张量,将需要计算梯度属性置为True

下面用pytorch做一个简单的线性关系预测

线性关系是一种非常简单的变量之间的关系,因变量和自变量在线性关系的情况下,可以使用线性回归算法对一个或多个因变量和自变量间的线性关系进行建模,该模型的系数可以用最小二乘法进行求解。生活中的场景往往会比较复杂,需要考虑多元线性关系和非线性关系,用其他的回归分析方法求解。

这里po上代码:

#人为生成一些样本点作为原始数据
x = Variable(torch.linspace(0, 100).type(torch.FloatTensor)) 
rand = Variable(torch.randn(100)) * 10 #随机生成100个满足标准正态分布的随机数,均值为0,方差为1.将这个数字乘以10,标准方差变为10
y = x + rand #将x和rand相加,得到伪造的标签数据y。所以(x,y)应能近似地落在y=x这条直线上

import matplotlib.pyplot as plt #导入画图的程序包
plt.figure(figsize=(10,8)) #设定绘制窗口大小为10*8 inch
plt.plot(x.data.numpy(), y.data.numpy(), 'o') #绘制数据,考虑到x和y都是Variable,需要用data获取它们包裹的Tensor,并专成numpy
plt.xlabel('X') 
plt.ylabel('Y') 
plt.show()

图示:

利用Pytorch实现简单的线性回归算法

训练模型:

#a,b就是要构建的线性函数的系数
a = Variable(torch.rand(1), requires_grad = True) #创建a变量,并随机赋值初始化
b = Variable(torch.rand(1), requires_grad = True) #创建b变量,并随机赋值初始化
print('Initial parameters:', [a, b])

learning_rate = 0.0001 #设置学习率
for i in range(1000):
  ### 增加了这部分代码,清空存储在变量a,b中的梯度信息,以免在backward的过程中会反复不停地累加
  if (a.grad is not None) and (b.grad is not None): 
    a.grad.data.zero_() 
    b.grad.data.zero_() 
  predictions = a.expand_as(x) * x+ b.expand_as(x) #计算在当前a、b条件下的模型预测数值
  loss = torch.mean((predictions - y) ** 2) #通过与标签数据y比较,计算误差
  print('loss:', loss)

  loss.backward() #对损失函数进行梯度反传,backward的方向传播算法
  a.data.add_(- learning_rate * a.grad.data) #利用上一步计算中得到的a的梯度信息更新a中的data数值
  b.data.add_(- learning_rate * b.grad.data) #利用上一步计算中得到的b的梯度信息更新b中的data数值

##拟合
x_data = x.data.numpy() 
plt.figure(figsize = (10, 7))
xplot = plt.plot(x_data, y.data.numpy(), 'o') # 绘制原始数据
yplot = plt.plot(x_data, a.data.numpy() * x_data + b.data.numpy()) #绘制拟合数据
plt.xlabel('X') 
plt.ylabel('Y') 
str1 = str(a.data.numpy()[0]) + 'x +' + str(b.data.numpy()[0]) #图例信息
plt.legend([xplot, yplot],['Data', str1]) #绘制图例
plt.show()

图示:

利用Pytorch实现简单的线性回归算法

测试:

x_test = Variable(torch.FloatTensor([1, 2, 10, 100, 1000])) #随便选择一些点1,2,……,1000
predictions = a.expand_as(x_test) * x_test + b.expand_as(x_test) #计算模型的预测结果
predictions #输出

ok,大功告成,可以看到用pytorch做机器学习确实无论是准确度还是方便性都有优势,继续探索学习。

以上这篇利用Pytorch实现简单的线性回归算法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单文件操作python 修改文件指定行的方法
May 15 Python
python生成指定长度的随机数密码
Jan 23 Python
python实现每次处理一个字符的三种方法
Oct 09 Python
Python中Django 后台自定义表单控件
Mar 28 Python
Python 12306抢火车票脚本
Feb 07 Python
python获取磁盘号下盘符步骤详解
Jun 19 Python
在linux系统下安装python librtmp包的实现方法
Jul 22 Python
python numpy--数组的组合和分割实例
Feb 24 Python
python如何安装下载后的模块
Jul 03 Python
用Python 爬取猫眼电影数据分析《无名之辈》
Jul 24 Python
Python urllib3软件包的使用说明
Nov 18 Python
Python如何配置环境变量详解
May 18 Python
pytorch实现线性拟合方式
Jan 15 #Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
You might like
了解咖啡雨林联盟认证 什么是雨林认证 雨林认证是什么意思
2021/03/05 新手入门
利用文件属性结合Session实现在线人数统计
2006/10/09 PHP
PHPMailer使用教程(PHPMailer发送邮件实例分析)
2012/12/06 PHP
php二维数组排序与默认自然排序的方法介绍
2013/04/27 PHP
php仿微信红包分配算法的实现方法
2016/05/13 PHP
php实现构建排除当前元素的乘积数组方法
2018/10/06 PHP
php判断/计算闰年的方法小结【三种方法】
2019/07/06 PHP
javascript的数组和常用函数详解
2014/05/09 Javascript
js实现文章文字大小字号功能完整实例
2014/11/01 Javascript
JavaScript计算两个日期时间段内日期的方法
2015/03/16 Javascript
Javascript中arguments对象的详解与使用方法
2016/10/04 Javascript
input file上传 图片预览功能实例代码
2016/10/25 Javascript
JS实现鼠标移上去显示图片或微信二维码
2016/12/14 Javascript
js实现常见的工具条效果
2017/03/02 Javascript
详解node.js平台下Express的session与cookie模块包的配置
2017/04/26 Javascript
Vue的Options用法说明
2020/08/14 Javascript
[41:37]DOTA2北京网鱼队选拔赛——冲击职业之路
2015/04/13 DOTA
[01:03:41]DOTA2-DPC中国联赛 正赛 Dynasty vs XG BO3 第三场 2月2日
2021/03/11 DOTA
Python中的字符串查找操作方法总结
2016/06/27 Python
详解python脚本自动生成需要文件实例代码
2017/02/04 Python
Python迭代器定义与简单用法分析
2018/04/30 Python
漂亮的Django Markdown富文本app插件的实现
2019/01/02 Python
Django项目主urls导入应用中views的红线问题解决
2019/08/10 Python
3行Python代码实现图像照片抠图和换底色的方法
2019/10/10 Python
Python函数参数类型及排序原理总结
2019/12/19 Python
彻底搞懂python 迭代器和生成器
2020/09/07 Python
利用Python发送邮件或发带附件的邮件
2020/11/12 Python
法国和欧洲海边和滑雪度假:Pierre & Vacances
2017/01/04 全球购物
Grow Gorgeous美国官网:只要八天,体验唤醒毛囊后新生的茂密秀发
2018/06/04 全球购物
系统管理员的职责包括那些?管理的对象是什么?
2016/09/20 面试题
社会公德演讲稿
2014/05/20 职场文书
死亡证明书样本说明
2014/10/18 职场文书
如何制定一份可行的计划!
2019/06/21 职场文书
SQL Server 数据库实验课第五周——常用查询条件
2021/04/05 SQL Server
JavaScript高级程序设计之基本引用类型
2021/11/17 Javascript
Nginx工作模式及代理配置的使用细节
2022/03/21 Servers