pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作摄像头截图实现远程监控的例子
Mar 25 Python
Python自定义scrapy中间模块避免重复采集的方法
Apr 07 Python
K-means聚类算法介绍与利用python实现的代码示例
Nov 13 Python
python+pillow绘制矩阵盖尔圆简单实例
Jan 16 Python
python 递归深度优先搜索与广度优先搜索算法模拟实现
Oct 22 Python
python datetime中strptime用法详解
Aug 29 Python
matplotlib.pyplot画图并导出保存的实例
Dec 07 Python
python和c语言哪个更适合初学者
Jun 22 Python
Python模块常用四种安装方式
Oct 20 Python
Python读取图像并显示灰度图的实现
Dec 01 Python
详解Python模块化编程与装饰器
Jan 16 Python
用Python将库打包发布到pypi
Apr 13 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
php中3种方法删除字符串中间的空格
2014/03/10 PHP
php实现的zip文件内容比较类
2014/09/24 PHP
ThinkPHP静态缓存简单配置和使用方法详解
2016/03/23 PHP
PHP yield关键字功能与用法分析
2019/01/03 PHP
浅谈JavaScript之事件绑定
2013/07/08 Javascript
JavaScript定时器详解及实例
2013/08/01 Javascript
JavaScript数据类型之基本类型和引用类型的值
2015/04/01 Javascript
jquery仿百度经验滑动切换浏览效果
2015/04/14 Javascript
JavaScript的removeChild()函数用法详解
2015/12/27 Javascript
jquery正则表达式验证(手机号、身份证号、中文名称)
2015/12/31 Javascript
JavaScript电子时钟倒计时第二款
2016/01/10 Javascript
js获取时间精确到秒(年月日)
2016/03/16 Javascript
jQuery解析XML 详解及方法总结
2016/09/28 Javascript
vue实现留言板todolist功能
2017/08/16 Javascript
微信小程序自定义音乐进度条的实例代码
2018/08/28 Javascript
vue.js仿hover效果的实现方法示例
2019/01/28 Javascript
[04:02]2014DOTA2国际邀请赛 BBC每日综述中国战队将再度登顶
2014/07/21 DOTA
[02:04]2014DOTA2国际邀请赛 DK一个时代的落幕
2014/07/21 DOTA
Python实现并行抓取整站40万条房价数据(可更换抓取城市)
2016/12/14 Python
Python设计模式之门面模式简单示例
2018/01/09 Python
python文本数据相似度的度量
2018/03/12 Python
PHP实现发送和接收JSON请求
2018/06/07 Python
Python企业编码生成系统之系统主要函数设计详解
2019/07/26 Python
python爬虫 线程池创建并获取文件代码实例
2019/09/28 Python
CSS3过渡transition效果实例介绍
2016/05/03 HTML / CSS
印尼在线旅游门户网站:NusaTrip
2019/11/01 全球购物
专业毕业生个性的自我评价
2013/10/03 职场文书
面包店的创业计划书范文
2014/01/16 职场文书
《桃花心木》教学反思
2014/02/17 职场文书
电台实习生求职信
2014/02/25 职场文书
小学生演讲稿大全
2014/04/25 职场文书
大学拉赞助协议书范文
2014/09/26 职场文书
家长通知书家长意见
2014/12/30 职场文书
党支部综合考察意见
2015/06/01 职场文书
法制教育主题班会
2015/08/13 职场文书
七夕情人节问候语
2015/11/11 职场文书