pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用摄像头实现简单的延时摄影技术
Mar 27 Python
python单例模式实例分析
Apr 08 Python
利用Python实现图书超期提醒
Aug 02 Python
Python简单遍历字典及删除元素的方法
Sep 18 Python
Python生成随机密码的方法
Jun 16 Python
Mac 上切换Python多版本
Jun 17 Python
Django框架组成结构、基本概念与文件功能分析
Jul 30 Python
Python2与Python3关于字符串编码处理的差别总结
Sep 07 Python
python产生模拟数据faker库的使用详解
Nov 04 Python
tensorflow2.0教程之Keras快速入门
Feb 20 Python
Python+Selenium自动化环境搭建与操作基础详解
Mar 13 Python
Pandas实现DataFrame的简单运算、统计与排序
Mar 31 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
PHP验证码无法显示的原因及解决办法
2017/08/11 PHP
ThinkPHP实现转换数据库查询结果数据到对应类型的方法
2017/11/16 PHP
php中通用的excel导出方法实例
2017/12/30 PHP
Laravel框架中自定义模板指令总结
2017/12/17 PHP
PHP中将一个字符串部分字符用星号*替代隐藏的实现代码
2019/09/08 PHP
JavaScript库 开发规则
2009/01/31 Javascript
Chosen 基于jquery的选择框插件使用方法
2012/05/30 Javascript
js显示时间 js显示最后修改时间
2013/01/02 Javascript
jquery实现微博文字输入框 输入时显示输入字数 效果实现
2013/07/12 Javascript
javascript ready和load事件的区别示例介绍
2013/08/30 Javascript
IE下双击checkbox反应延迟问题的解决方法
2014/03/27 Javascript
BootStrap响应式导航条实例介绍
2016/05/06 Javascript
JavaScript数组方法大全(推荐)
2016/07/05 Javascript
jQuery 利用$.ajax 时获取原生XMLHttpRequest 对象的方法
2016/08/25 Javascript
Javascript 链式作用域详细介绍
2017/02/23 Javascript
Vue.js对象转换实例
2017/06/07 Javascript
Angularjs 事件指令详细整理
2017/07/27 Javascript
Vue之mixin全局的用法详解
2018/08/22 Javascript
解决vue项目打包上服务器显示404错误,本地没出错的问题
2020/11/03 Javascript
[05:08]2014DOTA2国际邀请赛 Hao专访复仇的胜利很爽
2014/07/15 DOTA
python字符类型的一些方法小结
2016/05/16 Python
Python实现发送与接收邮件的方法详解
2018/03/28 Python
Python利用公共键如何对字典列表进行排序详解
2018/05/19 Python
利用Python+阿里云实现DDNS动态域名解析的方法
2019/04/01 Python
对python中 math模块下 atan 和 atan2的区别详解
2020/01/17 Python
CSS3中currentColor关键字的妙用
2016/02/27 HTML / CSS
NIHAOMARKET官方海外旗舰店:意大利你好华人超市
2018/01/27 全球购物
国际领先的在线时尚服装和配饰店:DressLily
2019/03/03 全球购物
迎新晚会邀请函
2014/02/01 职场文书
岗位廉洁从政承诺书
2014/03/27 职场文书
数学高效课堂实施方案
2014/03/29 职场文书
期末评语大全
2014/05/04 职场文书
给校长的建议书600字
2014/05/15 职场文书
留学推荐信中文范文
2015/03/26 职场文书
离职证明格式样本
2015/06/12 职场文书
javascript的var与let,const之间的区别详解
2022/02/18 Javascript