pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python全局变量操作详解
Apr 14 Python
Django安装配置mysql的方法步骤
Oct 15 Python
Django页面数据的缓存与使用的具体方法
Apr 23 Python
python实现文件助手中查看微信撤回消息
Apr 29 Python
Python 用matplotlib画以时间日期为x轴的图像
Aug 06 Python
django中使用事务及接入支付宝支付功能
Sep 15 Python
python进程间通信Queue工作过程详解
Nov 01 Python
python实现在一个画布上画多个子图
Jan 19 Python
python GUI库图形界面开发之PyQt5输入对话框QInputDialog详细使用方法与实例
Feb 27 Python
Python包资源下载路径报404解决方案
Nov 05 Python
Pyqt助手安装PyQt5帮助文档过程图解
Nov 20 Python
pytorch中Schedule与warmup_steps的用法说明
May 24 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
php的一个登录的类 [推荐]
2007/03/16 PHP
快速配置PHPMyAdmin方法
2008/06/05 PHP
浅谈Eclipse PDT调试PHP程序
2014/06/09 PHP
php实现修改新闻时删除图片的方法
2015/05/12 PHP
php cookie用户登录的详解及实例代码
2017/01/03 PHP
PHP常量及变量区别原理详解
2020/08/14 PHP
Mootools 1.2教程 事件处理
2009/09/15 Javascript
Javascript new关键字的玄机 以及其它
2010/08/25 Javascript
从面试题学习Javascript 面向对象(创建对象)
2012/03/30 Javascript
javascript数组去重方法汇总
2015/04/23 Javascript
用自定义图片代替原生checkbox实现全选,删除以及提交的方法
2016/10/18 Javascript
半个小时学json(json传递示例)
2016/12/25 Javascript
Bootstrap 模态对话框只加载一次 remote 数据的完美解决办法
2017/07/09 Javascript
json字符串传到前台input的方法
2018/08/06 Javascript
NodeJS搭建HTTP服务器的实现步骤
2018/10/12 NodeJs
微信小程序canvas分享海报功能
2019/10/31 Javascript
[04:42]5分钟带你了解什么是DOTA2(第一期)
2017/02/07 DOTA
[01:03:22]LGD vs OG 2018国际邀请赛淘汰赛BO3 第一场 8.25
2018/08/29 DOTA
python新手经常遇到的17个错误分析
2014/07/30 Python
Pytorch修改ResNet模型全连接层进行直接训练实例
2019/09/10 Python
Python容器使用的5个技巧和2个误区总结
2019/09/26 Python
Python进程间通信multiprocess代码实例
2020/03/18 Python
Python调用接口合并Excel表代码实例
2020/03/31 Python
浅谈多卡服务器下隐藏部分 GPU 和 TensorFlow 的显存使用设置
2020/06/30 Python
CSS3 边框效果
2019/11/04 HTML / CSS
html5 视频播放解决方案
2016/11/06 HTML / CSS
班干部竞选演讲稿
2014/04/24 职场文书
精神文明单位申报材料
2014/05/02 职场文书
关于安全的演讲稿
2014/05/09 职场文书
演讲稿的格式及范文
2014/08/22 职场文书
学校党的群众路线教育实践活动整改措施
2014/10/25 职场文书
2015年办公室主任工作总结
2015/04/09 职场文书
2016年区委书记抓基层党建工作公开承诺书
2016/03/25 职场文书
nginx优化的六点方法
2021/03/31 Servers
Nginx配置https原理及实现过程详解
2021/03/31 Servers
Java中的Kafka为什么性能这么快及4大核心详析
2022/09/23 Java/Android