pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用在线API查询IP对应的地理位置信息实例
Jun 01 Python
python比较两个列表大小的方法
Jul 11 Python
Python数组遍历的简单实现方法小结
Apr 27 Python
Django项目开发中cookies和session的常用操作分析
Jul 03 Python
Python元组知识点总结
Feb 18 Python
python环境路径配置以及命令行运行脚本
Apr 02 Python
深入了解Django View(视图系统)
Jul 23 Python
详细介绍Python进度条tqdm的使用
Jul 31 Python
Pycharm IDE的安装和使用教程详解
Apr 30 Python
序列化Python对象的方法
Aug 01 Python
基于python调用jenkins-cli实现快速发布
Aug 14 Python
Python学习之时间包使用教程详解
Mar 21 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
php加水印的代码(支持半透明透明打水印,支持png透明背景)
2013/01/17 PHP
如何阻止网站被恶意反向代理访问(防网站镜像)
2014/03/18 PHP
Ubuntu VPS中wordpress网站打开时提示”建立数据库连接错误”的解决办法
2016/11/03 PHP
jQuery TextBox自动完成条
2009/07/22 Javascript
分享一个自定义的console类 让你不再纠结JS中的调试代码的兼容
2012/04/20 Javascript
js中cookie的添加、取值、删除示例代码
2013/10/21 Javascript
js使用正则实现ReplaceAll全部替换的方法
2014/08/22 Javascript
js实现tab切换效果实例
2015/09/16 Javascript
以Python代码实例展示kNN算法的实际运用
2015/10/26 Javascript
微信小程序使用第三方库Underscore.js步骤详解
2016/09/27 Javascript
JS两种类型的表单提交方法实例分析
2016/11/28 Javascript
神级程序员JavaScript300行代码搞定汉字转拼音
2017/05/20 Javascript
微信小程序日期时间选择器使用方法
2018/02/01 Javascript
微信小程序点击保存图片到本机功能
2019/12/13 Javascript
微信小程序实现简单的select下拉框
2020/11/23 Javascript
本地文件上传到七牛云服务器示例(七牛云存储)
2014/01/11 Python
python 示例分享---逻辑推理编程解决八皇后
2014/07/20 Python
Python查找函数f(x)=0根的解决方法
2015/05/07 Python
简单介绍Python的Django框架的dj-scaffold项目
2015/05/30 Python
Flask框架信号用法实例分析
2018/07/24 Python
python itchat实现调用微信接口的第三方模块方法
2019/06/11 Python
通过selenium抓取某东的TT购买记录并分析趋势过程解析
2019/08/15 Python
Python SELENIUM上传文件或图片实现过程
2019/10/28 Python
python类中super() 的使用解析
2019/12/19 Python
Python对称的二叉树多种思路实现方法
2020/02/28 Python
解决pycharm下pyuic工具使用的问题
2020/04/08 Python
Python小白学习爬虫常用请求报头
2020/06/03 Python
Django启动时找不到mysqlclient问题解决方案
2020/11/11 Python
python 使用openpyxl读取excel数据
2021/02/18 Python
html5的自定义data-*属性与jquery的data()方法的使用
2014/07/02 HTML / CSS
HTML最新标准HTML5总结(必看)
2016/06/13 HTML / CSS
data:image data url 文件转为Blob上传后端的方法
2019/07/16 HTML / CSS
Gap加拿大官网:Gap Canada
2017/08/24 全球购物
面向对象设计的原则是什么
2013/02/13 面试题
作风整顿剖析材料
2014/09/30 职场文书
日本十大血腥动漫,那些被禁播的动漫盘点
2022/03/21 日漫