使用pytorch实现线性回归


Posted in Python onApril 11, 2021

线性回归都是包括以下几个步骤:定义模型、选择损失函数、选择优化函数、 训练数据、测试

import torch
import matplotlib.pyplot as plt
# 构建数据集
x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]])
y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0]])
#定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear= torch.nn.Linear(1,1) #表示输入输出都只有一层,相当于前向传播中的函数模型,因为我们一般都不知道函数是什么形式的
 
    def forward(self, x):
        y_pred= self.linear(x)
        return y_pred
model= LinearModel()
# 使用均方误差作为损失函数
criterion= torch.nn.MSELoss(size_average= False)
#使用梯度下降作为优化SGD
# 从下面几种优化器的生成结果图像可以看出,SGD和ASGD效果最好,因为他们的图像收敛速度最快
optimizer= torch.optim.SGD(model.parameters(),lr=0.01)
# ASGD
# optimizer= torch.optim.ASGD(model.parameters(),lr=0.01)
# optimizer= torch.optim.Adagrad(model.parameters(), lr= 0.01)
# optimizer= torch.optim.RMSprop(model.parameters(), lr= 0.01)
# optimizer= torch.optim.Adamax(model.parameters(),lr= 0.01)
# 训练
epoch_list=[]
loss_list=[]
for epoch in range(100):
    y_pred= model(x_data)
    loss= criterion(y_pred, y_data)
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    print(epoch, loss.item())
 
    optimizer.zero_grad() #梯度归零
    loss.backward()  #反向传播
    optimizer.step() #更新参数
 
print("w= ", model.linear.weight.item())
print("b= ",model.linear.bias.item())
 
x_test= torch.Tensor([[7.0]])
y_test= model(x_test)
print("y_pred= ",y_test.data)
 
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()

使用SGD优化器图像:                                                      

使用pytorch实现线性回归

使用ASGD优化器图像:

使用pytorch实现线性回归

使用Adagrad优化器图像:                                                 

使用pytorch实现线性回归

使用Adamax优化器图像:

使用pytorch实现线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之集合的关系
Sep 24 Python
python根据出生日期返回年龄的方法
Mar 26 Python
Python排序搜索基本算法之选择排序实例分析
Dec 09 Python
python3.6连接MySQL和表的创建与删除实例代码
Dec 28 Python
Centos 升级到python3后pip 无法使用的解决方法
Jun 12 Python
PyQt5创建一个新窗口的实例
Jun 20 Python
Python读取xlsx文件的实现方法
Jul 04 Python
np.newaxis 实现为 numpy.ndarray(多维数组)增加一个轴
Nov 30 Python
Python基本类型的连接组合和互相转换方式(13种)
Dec 16 Python
Python如何生成xml文件
Jun 04 Python
搭建pypi私有仓库实现过程详解
Nov 25 Python
python爬不同图片分别保存在不同文件夹中的实现
Apr 02 Python
pytorch实现线性回归以及多元回归
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
You might like
深入php socket的讲解与实例分析
2013/06/13 PHP
解析PHP SPL标准库的用法(遍历目录,查找固定条件的文件)
2013/06/18 PHP
php使用ffmpeg获取视频信息并截图的实现方法
2016/05/03 PHP
基于CI(CodeIgniter)框架实现购物车功能的方法
2018/04/09 PHP
JS调用页面表格导出excel示例代码
2014/03/18 Javascript
javascript中的this详解
2014/12/08 Javascript
node.js中的emitter.emit方法使用说明
2014/12/10 Javascript
JS+CSS实现可拖拽的漂亮圆角特效弹出层完整实例
2015/02/13 Javascript
详解JavaScript编程中的数组结构
2015/10/24 Javascript
JavaScript+html5 canvas制作的百花齐放效果完整实例
2016/01/26 Javascript
Angular实现form自动布局
2016/01/28 Javascript
浅析Javascript匿名函数与自执行函数
2016/02/06 Javascript
js检测离开或刷新页面时表单数据是否更改的方法
2016/08/02 Javascript
jQuery Ajax 实现在html页面实时显示用户登录状态
2016/12/30 Javascript
nodejs中解决异步嵌套循环和循环嵌套异步的问题
2017/07/12 NodeJs
JavaScript 值类型和引用类型的初次研究(推荐)
2017/07/19 Javascript
原生JS实现移动端web轮播图详解(结合Tween算法造轮子)
2017/09/10 Javascript
jQuery实现定时隐藏对话框的方法分析
2018/02/12 jQuery
微信小程序实现的图片保存功能示例
2019/04/24 Javascript
js实现GIF动图分解成多帧图片上传
2019/10/24 Javascript
Bootstrap实现前端登录页面带验证码功能完整示例
2020/03/26 Javascript
Python的ORM框架SQLObject入门实例
2014/04/28 Python
Python编码爬坑指南(必看)
2016/06/10 Python
深入理解Django自定义信号(signals)
2018/10/15 Python
Pycharm 文件更改目录后,执行路径未更新的解决方法
2019/07/19 Python
Python3 解决读取中文文件txt编码的问题
2019/12/20 Python
Python如何使用paramiko模块连接linux
2020/03/18 Python
python实现TCP文件传输
2020/03/20 Python
如何学习Python time模块
2020/06/03 Python
39美元购买一副眼镜或太阳镜:39DollarGlasses.com
2018/06/17 全球购物
女装和独特珠宝:Sundance Catalog
2018/09/19 全球购物
小学生检讨书大全
2014/02/06 职场文书
幼儿园小班个人工作总结
2015/02/12 职场文书
2015年大学社团工作总结
2015/04/09 职场文书
社区禁毒宣传活动总结
2015/05/07 职场文书
鸿蒙3.0体验感怎么样? 鸿蒙3.0系统评测向
2022/08/14 数码科技