pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解析xml文件实例分析
May 27 Python
在Django的通用视图中处理Context的方法
Jul 21 Python
python指定写入文件时的编码格式方法
Jun 07 Python
Python爬虫之网页图片抓取的方法
Jul 16 Python
Python解析、提取url关键字的实例详解
Dec 17 Python
python游戏地图最短路径求解
Jan 16 Python
Python3将jpg转为pdf文件的方法示例
Dec 13 Python
Python批量将图片灰度化的实现代码
Apr 11 Python
Django中的模型类设计及展示示例详解
May 29 Python
详解python with 上下文管理器
Sep 02 Python
Python如何使用logging为Flask增加logid
Mar 30 Python
python如何读取和存储dict()与.json格式文件
Jun 25 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
PHP中的日期及时间
2006/11/23 PHP
PHP error_log()将错误信息写入一个文件(定义和用法)
2013/10/25 PHP
php汉字转拼音的示例
2014/02/27 PHP
PHP中抽象类、接口的区别与选择分析
2016/03/29 PHP
php微信公众号开发之答题连闯三关
2018/10/20 PHP
PHP实现微信公众号验证Token的示例代码
2019/12/16 PHP
CSS心形加载的动画源码的实现
2021/03/09 HTML / CSS
filemanage功能中用到的lib.js
2007/04/08 Javascript
jQuery的slideToggle方法实例
2013/05/07 Javascript
javaScript 动态访问JSon元素示例代码
2013/08/30 Javascript
JS实现响应鼠标点击动画渐变弹出层效果代码
2016/03/25 Javascript
Angular 通过注入 $location 获取与修改当前页面URL的实例
2017/05/31 Javascript
Vue Transition实现类原生组件跳转过渡动画的示例
2017/08/19 Javascript
JavaScript中数组常见操作技巧
2017/09/01 Javascript
VueJS组件之间通过props交互及验证的方式
2017/09/04 Javascript
vue 父组件调用子组件方法及事件
2018/03/29 Javascript
vue+axios 前端实现登录拦截的两种方式(路由拦截、http拦截)
2018/10/24 Javascript
利用JavaScript缓存远程窃取Wi-Fi密码的思路详解
2018/11/05 Javascript
node.js使用express框架进行文件上传详解
2019/03/03 Javascript
layui使用数据表格实现购物车功能
2019/07/26 Javascript
简述vue-cli中chainWebpack的使用方法
2019/07/30 Javascript
详解element上传组件before-remove钩子问题解决
2020/04/08 Javascript
js cavans实现静态滚动弹幕
2020/05/21 Javascript
理解JavaScript中的Proxy 与 Reflection API
2020/09/21 Javascript
python实现IOU计算案例
2020/04/12 Python
django使用graphql的实例
2020/09/02 Python
python性能测试工具locust的使用
2020/12/28 Python
详解移动端HTML5音频与视频问题及解决方案
2018/08/22 HTML / CSS
爱尔兰最大的体育零售商:Life Style Sports
2019/06/12 全球购物
材料成型专业个人求职信范文
2013/09/25 职场文书
聘任书模板
2014/03/29 职场文书
员工试用期自我鉴定范文
2014/09/15 职场文书
员工2014年度工作总结
2014/12/09 职场文书
承诺函范文
2015/01/21 职场文书
2019广播稿怎么写
2019/04/17 职场文书
微信小程序和php的登录实现
2021/04/01 PHP