pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python操作Mysql实例代码教程在线版(查询手册)
Feb 18 Python
Python的设计模式编程入门指南
Apr 02 Python
Python验证码识别的方法
Jul 10 Python
python合并已经存在的sheet数据到新sheet的方法
Dec 11 Python
Django框架模板注入操作示例【变量传递到模板】
Dec 19 Python
详解Python 调用C# dll库最简方法
Jun 20 Python
Python搭建代理IP池实现检测IP的方法
Oct 27 Python
python 经典数字滤波实例
Dec 16 Python
Python @property原理解析和用法实例
Feb 11 Python
关于Python字符编码与二进制不得不说的一些事
Oct 04 Python
PyCharm配置KBEngine快速处理代码提示冲突、配置命令问题
Apr 03 Python
Django显示可视化图表的实践
May 10 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
PHP的FTP学习(一)
2006/10/09 PHP
PHP 巧用数组降低程序的时间复杂度
2010/01/01 PHP
PHP+SQL 注入攻击的技术实现以及预防办法
2011/01/27 PHP
PHP解决中文乱码
2017/04/28 PHP
PHP生成随机数的方法总结
2018/03/01 PHP
JavaScript入门教程(1) 什么是JS
2009/01/31 Javascript
JavaScript实现页面滚动图片加载(仿lazyload效果)
2011/07/22 Javascript
基于JQuery 的消息提示框效果代码
2011/07/31 Javascript
Javascript开发之三数组对象实例介绍
2012/11/12 Javascript
jquery获得页面元素的坐标值实现思路及代码
2013/04/15 Javascript
利用javaScript实现点击输入框弹出窗体选择信息
2013/12/11 Javascript
使用forever管理nodejs应用教程
2014/06/03 NodeJs
angularjs的一些优化小技巧
2014/12/06 Javascript
zepto.js中tap事件阻止冒泡的实现方法
2015/02/12 Javascript
javascript中createElement的两种创建方式
2015/05/14 Javascript
Java中int与integer的区别(基本数据类型与引用数据类型)
2017/02/19 Javascript
详解Angular.js指令中scope类型的几种特殊情况
2017/02/21 Javascript
Kotlin学习第一步 kotlin语法特性
2017/05/25 Javascript
Bootstrap一款超好用的前端框架
2017/09/25 Javascript
在vue项目中优雅的使用SVG的方法实例详解
2018/12/03 Javascript
jQuery属性选择器用法实例分析
2019/06/28 jQuery
python爬虫入门教程之点点美女图片爬虫代码分享
2014/09/02 Python
bat和python批量重命名文件的实现代码
2016/05/19 Python
python使用minimax算法实现五子棋
2019/07/29 Python
jupyter 实现notebook中显示完整的行和列
2020/04/09 Python
python 中的命名空间,你真的了解吗?
2020/08/19 Python
python用tkinter实现一个简易能进行随机点名的界面
2020/09/27 Python
HTML5 贪吃蛇游戏实现思路及源代码
2013/09/03 HTML / CSS
网络维护中文求职信
2014/01/03 职场文书
社会实践先进工作者事迹材料
2014/05/06 职场文书
2014年国庆标语
2014/06/30 职场文书
安全生产知识竞赛活动总结
2014/07/07 职场文书
法人授权委托书公证范本
2014/09/14 职场文书
成人成长感言如何写?
2019/08/16 职场文书
Nginx访问日志及错误日志参数说明
2021/03/31 Servers
Win11右下角图标点了没反应怎么办?Win11点击右下角图标无反应解决方法汇总
2022/07/07 数码科技