pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python访问类中docstring注释的实现方法
May 04 Python
python实现时间o(1)的最小栈的实例代码
Jul 23 Python
使用python将时间转换为指定的格式方法
Nov 12 Python
快速排序的四种python实现(推荐)
Apr 03 Python
Apache部署Django项目图文详解
Jul 30 Python
利用Python绘制Jazz网络图的例子
Nov 21 Python
简单了解python字符串前面加r,u的含义
Dec 26 Python
python基于三阶贝塞尔曲线的数据平滑算法
Dec 27 Python
将tf.batch_matmul替换成tf.matmul的实现
Jun 18 Python
python中编写函数并调用的知识点总结
Jan 13 Python
python控制台打印log输出重复的解决方法
May 14 Python
教你怎么用python selenium实现自动化测试
May 27 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
用PHP和ACCESS写聊天室(三)
2006/10/09 PHP
php下几个常用的去空、分组、调试数组函数
2009/02/22 PHP
php session_start()关于Cannot send session cache limiter - headers already sent错误解决方法
2009/11/27 PHP
常见php数据文件缓存类汇总
2014/12/05 PHP
PHP四种排序算法实现及效率分析【冒泡排序,插入排序,选择排序和快速排序】
2018/04/27 PHP
PHP创建对象的六种方式实例总结
2019/06/27 PHP
js鼠标左右键 键盘值小结
2010/06/11 Javascript
Jquery判断IE6等浏览器的代码
2011/04/05 Javascript
Package.js  现代化的JavaScript项目make工具
2012/05/23 Javascript
JavaScript数据类型详解
2015/04/01 Javascript
JS深度拷贝Object Array实例分析
2016/03/31 Javascript
JS实现的RGB网页颜色在线取色器完整实例
2016/12/21 Javascript
用js制作淘宝放大镜效果
2020/10/28 Javascript
Vue.js开发环境快速搭建教程
2017/03/17 Javascript
vue事件修饰符和按键修饰符用法总结
2017/07/25 Javascript
JS执行控制之节流模式实例分析
2018/12/21 Javascript
JS module的导出和导入的实现代码
2019/02/25 Javascript
详解elementui之el-image-viewer(图片查看器)
2019/08/30 Javascript
Vue 实现CLI 3.0 + momentjs + lodash打包时优化
2019/11/13 Javascript
Python中使用PyQt把网页转换成PDF操作代码实例
2015/04/23 Python
python利用datetime模块计算时间差
2015/08/04 Python
深入理解python中函数传递参数是值传递还是引用传递
2017/11/07 Python
关于Python数据结构中字典的心得
2017/12/04 Python
python实现俄罗斯方块游戏
2020/03/25 Python
Python实现获取当前目录下文件名代码详解
2020/03/10 Python
python 中 .py文件 转 .pyd文件的操作
2021/03/04 Python
让IE支持HTML5的方法
2012/12/11 HTML / CSS
浅谈HTML5 Web Worker的使用
2018/01/05 HTML / CSS
美国在线家居装饰店:Belle&June
2018/10/24 全球购物
名人演讲稿范文
2013/12/28 职场文书
开业庆典邀请函
2014/01/08 职场文书
《桃林那间小木屋》教学反思
2014/05/01 职场文书
党的群众路线教育实践活动领导班子整改方案
2014/10/25 职场文书
2014年个人业务工作总结
2014/11/17 职场文书
超级详细实用的pycharm常用快捷键
2021/05/12 Python
Python必备技巧之函数的使用详解
2022/04/04 Python