pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python dict remove数组删除(del,pop)
Mar 24 Python
MySQL最常见的操作语句小结
May 07 Python
浅谈Python数据类型判断及列表脚本操作
Nov 04 Python
使用numba对Python运算加速的方法
Oct 15 Python
对python内置map和six.moves.map的区别详解
Dec 19 Python
python写入数据到csv或xlsx文件的3种方法
Aug 23 Python
python3使用GUI统计代码量
Sep 18 Python
python脚本监控logstash进程并邮件告警实例
Apr 28 Python
python3.7添加dlib模块的方法
Jul 01 Python
Windows下Sqlmap环境安装教程详解
Aug 04 Python
python语言实现贪吃蛇游戏
Nov 13 Python
Django多个app urls配置代码实例
Nov 26 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
一个简单的MySQL数据浏览器
2006/10/09 PHP
基于PHP5魔术常量与魔术方法的详解
2013/06/13 PHP
php实现window平台的checkdnsrr函数
2015/05/27 PHP
PHP实现微信发红包程序
2015/08/24 PHP
Yii2数据库操作常用方法小结
2017/05/04 PHP
Laravel中任务调度console使用方法小结
2017/05/07 PHP
ArrayList类(增强版)
2007/04/04 Javascript
Javascript 中的 && 和 || 使用小结
2010/04/25 Javascript
基于Jquery的文字滚动跑马灯插件(一个页面多个滚动区)
2010/07/26 Javascript
jquery ajax请求实例深入解析
2012/11/26 Javascript
防止浏览器记住用户名及密码的简单实用方法
2013/04/22 Javascript
推荐10 个很棒的 jQuery 特效代码
2015/10/04 Javascript
Javascript基础_标记文字的实现方法
2016/06/14 Javascript
JavaScript实战之菜单特效
2016/08/16 Javascript
jQuery和JavaScript节点插入元素的方法对比
2016/11/18 Javascript
JS中input表单隐藏域及其使用方法
2017/02/13 Javascript
20行JS代码实现网页刮刮乐效果
2017/06/23 Javascript
vue 页面加载进度条组件实例
2018/02/05 Javascript
Vue如何获取数据列表展示
2019/12/11 Javascript
微信小程序按顺序同步执行的两种方式
2019/12/20 Javascript
[01:15:15]VG VS EG Supermajor小组赛B组胜者组第一轮 BO3第二场 6.2
2018/06/03 DOTA
详解Python的Django框架中inclusion_tag的使用
2015/07/21 Python
使用Django的模版来配合字符串翻译工作
2015/07/27 Python
Python和Perl绘制中国北京跑步地图的方法
2016/03/03 Python
python发送多人邮件没有展示收件人问题的解决方法
2019/06/21 Python
基于Tensorflow高阶读写教程
2020/02/10 Python
python数据爬下来保存的位置
2020/02/17 Python
利用HTML5画出一个坦克的形状具体实现代码
2013/06/20 HTML / CSS
优秀实习生感言
2014/03/01 职场文书
学校节能减排倡议书
2014/05/16 职场文书
媒体宣传策划方案
2014/05/25 职场文书
李强优秀员工观后感
2015/06/16 职场文书
Mysql效率优化定位较低sql的两种方式
2021/05/26 MySQL
python中tkinter复选框使用操作
2021/11/11 Python
使用Nginx的访问日志统计PV与UV
2022/05/06 Servers
解决spring.thymeleaf.cache=false不起作用的问题
2022/06/10 Java/Android