pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python字典序问题实例
Sep 26 Python
Python基于PycURL自动处理cookie的方法
Jul 25 Python
Python实现二分查找与bisect模块详解
Jan 13 Python
Python实现pdf文档转txt的方法示例
Jan 19 Python
django admin 后台实现三级联动的示例代码
Jun 22 Python
python查找重复图片并删除(图片去重)
Jul 16 Python
Python测试模块doctest使用解析
Aug 10 Python
python 字段拆分详解
Dec 17 Python
tensorflow 模型权重导出实例
Jan 24 Python
Python绘图之二维图与三维图详解
Aug 04 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 Python
如何利用pygame实现打飞机小游戏
May 30 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
使用PHP维护文件系统
2006/10/09 PHP
snoopy 强大的PHP采集类使用实例代码
2010/12/09 PHP
php下将多个数组合并成一个数组的方法与实例代码
2011/02/03 PHP
探讨Smarty中如何获取数组的长度以及smarty调用php函数的详解
2013/06/20 PHP
一个PHP的远程图片抓取函数分享
2013/09/25 PHP
CodeIgniter框架中_remap()使用方法2例
2014/03/10 PHP
php版微信公众平台接口参数调试实现判断用户行为的方法
2016/09/23 PHP
如何实现iframe(嵌入式帧)的自适应高度
2006/07/26 Javascript
javascript中巧用“闭包”实现程序的暂停执行功能
2007/04/04 Javascript
Flash+XML滚动新闻代码 无图片 附源码下载
2007/11/22 Javascript
JavaScript自动设置IFrame高度的小例子
2013/06/08 Javascript
js 自动播放的实例代码
2013/11/19 Javascript
PhotoShop给图片自动添加边框及EXIF信息的JS脚本
2015/02/15 Javascript
jQuery删除一个元素后淡出效果展示删除过程的方法
2015/03/18 Javascript
微信小程序 教程之条件渲染
2016/10/18 Javascript
JS正则表达式修饰符global(/g)用法分析
2016/12/27 Javascript
JS实现的点击表头排序功能示例
2017/03/27 Javascript
webpack进阶——缓存与独立打包的用法
2017/08/02 Javascript
import与export在node.js中的使用详解
2017/09/28 Javascript
AngularJS实现注册表单验证功能
2017/10/16 Javascript
Vue 莹石摄像头直播视频实例代码
2018/08/31 Javascript
webpack DllPlugin xxx is not defined解决办法
2019/12/13 Javascript
[01:09:50]VP vs Pain 2018国际邀请赛小组赛BO2 第二场
2018/08/20 DOTA
python正则表达式去除两个特殊字符间的内容方法
2018/12/24 Python
Django 实现admin后台显示图片缩略图的例子
2019/07/28 Python
基于python+selenium的二次封装的实现
2020/01/06 Python
python如何查看网页代码
2020/06/07 Python
HTML5 新表单类型示例代码
2018/03/20 HTML / CSS
福克斯租车:Fox Rent A Car
2017/04/13 全球购物
仓库主管的岗位职责
2013/12/04 职场文书
一年级家长会邀请函
2014/01/25 职场文书
接受捐赠答谢词
2014/01/27 职场文书
结婚通知短信大全
2015/04/17 职场文书
城南旧事观后感
2015/06/11 职场文书
详解Python小数据池和代码块缓存机制
2021/04/07 Python
如何理解Vue简单状态管理之store模式
2021/05/15 Vue.js