pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python遍历类中所有成员的方法
Mar 18 Python
简洁的十分钟Python入门教程
Apr 03 Python
Python的Django框架中TEMPLATES项的设置教程
May 29 Python
python实现线程池的方法
Jun 30 Python
Python与Java间Socket通信实例代码
Mar 06 Python
Python爬虫实例扒取2345天气预报
Mar 04 Python
下载python中Crypto库报错:ModuleNotFoundError: No module named ‘Crypto’的解决
Apr 23 Python
Python爬虫获取图片并下载保存至本地的实例
Jun 01 Python
python实现将一个数组逆序输出的方法
Jun 25 Python
python接口自动化(十六)--参数关联接口后传(详解)
Apr 16 Python
python使用列表的最佳方案
Aug 12 Python
Django操作cookie的实现
May 26 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
深入array multisort排序原理的详解
2013/06/18 PHP
制作安全性高的PHP网站的几个实用要点
2014/12/30 PHP
php获取远程文件内容的函数
2015/11/02 PHP
PHP中abstract(抽象)、final(最终)和static(静态)原理与用法详解
2020/06/05 PHP
Aster vs KG BO3 第三场2.19
2021/03/10 DOTA
Jquery 基础学习笔记
2009/05/29 Javascript
Javascript Cookie读写删除操作的函数
2010/03/02 Javascript
一个js的tab切换效果代码[代码分离]
2010/04/11 Javascript
克隆javascript对象的三个方法小结
2011/01/12 Javascript
数组方法解决JS字符串连接性能问题有争议
2011/01/12 Javascript
js单向链表的具体实现实例
2013/06/21 Javascript
JavaScript判断变量是否为数组的方法(Array)
2016/02/24 Javascript
jQuery Easyui 下拉树组件combotree
2016/12/16 Javascript
微信小程序 获取javascript 里的数据
2017/08/17 Javascript
JavaScript设计模式之策略模式实现原理详解
2020/05/29 Javascript
vue如何在项目中调用腾讯云的滑动验证码
2020/07/15 Javascript
Python的Django框架可适配的各种数据库介绍
2015/07/15 Python
Python基于pygame模块播放MP3的方法示例
2017/09/30 Python
致Python初学者 Anaconda入门使用指南完整版
2018/04/05 Python
python队列queue模块详解
2018/04/27 Python
python os.path模块常用方法实例详解
2018/09/16 Python
解决pyinstaller打包exe文件出现命令窗口一闪而过的问题
2018/10/31 Python
Python之循环结构
2019/01/15 Python
Python 调用有道翻译接口实现翻译
2020/03/02 Python
Numpy 多维数据数组的实现
2020/06/18 Python
美国医疗用品、医疗设备和家庭保健用品商店:Medical Supply Depot
2018/07/08 全球购物
麦当劳辞职信范文
2014/01/18 职场文书
2014年小班元旦活动方案
2014/02/16 职场文书
社会实践活动总结范文
2014/07/03 职场文书
2014年行政助理工作总结
2014/11/19 职场文书
2014年政协工作总结
2014/12/09 职场文书
优秀共产党员推荐材料
2014/12/18 职场文书
一个独生女的故事观后感
2015/06/04 职场文书
2016孝老爱亲模范事迹材料
2016/02/26 职场文书
html form表单基础入门案例讲解
2021/07/15 HTML / CSS
Python用any()函数检查字符串中的字母以及如何使用all()函数
2022/04/14 Python