使用pytorch实现线性回归


Posted in Python onApril 11, 2021

线性回归都是包括以下几个步骤:定义模型、选择损失函数、选择优化函数、 训练数据、测试

import torch
import matplotlib.pyplot as plt
# 构建数据集
x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]])
y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0]])
#定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear= torch.nn.Linear(1,1) #表示输入输出都只有一层,相当于前向传播中的函数模型,因为我们一般都不知道函数是什么形式的
 
    def forward(self, x):
        y_pred= self.linear(x)
        return y_pred
model= LinearModel()
# 使用均方误差作为损失函数
criterion= torch.nn.MSELoss(size_average= False)
#使用梯度下降作为优化SGD
# 从下面几种优化器的生成结果图像可以看出,SGD和ASGD效果最好,因为他们的图像收敛速度最快
optimizer= torch.optim.SGD(model.parameters(),lr=0.01)
# ASGD
# optimizer= torch.optim.ASGD(model.parameters(),lr=0.01)
# optimizer= torch.optim.Adagrad(model.parameters(), lr= 0.01)
# optimizer= torch.optim.RMSprop(model.parameters(), lr= 0.01)
# optimizer= torch.optim.Adamax(model.parameters(),lr= 0.01)
# 训练
epoch_list=[]
loss_list=[]
for epoch in range(100):
    y_pred= model(x_data)
    loss= criterion(y_pred, y_data)
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    print(epoch, loss.item())
 
    optimizer.zero_grad() #梯度归零
    loss.backward()  #反向传播
    optimizer.step() #更新参数
 
print("w= ", model.linear.weight.item())
print("b= ",model.linear.bias.item())
 
x_test= torch.Tensor([[7.0]])
y_test= model(x_test)
print("y_pred= ",y_test.data)
 
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()

使用SGD优化器图像:                                                      

使用pytorch实现线性回归

使用ASGD优化器图像:

使用pytorch实现线性回归

使用Adagrad优化器图像:                                                 

使用pytorch实现线性回归

使用Adamax优化器图像:

使用pytorch实现线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的pprint折腾记
Jan 21 Python
python简单程序读取串口信息的方法
Mar 13 Python
python和bash统计CPU利用率的方法
Jul 10 Python
在Django的模型中执行原始SQL查询的方法
Jul 21 Python
django 常用orm操作详解
Sep 13 Python
PyQt5利用QPainter绘制各种图形的实例
Oct 19 Python
Python遍历pandas数据方法总结
Feb 09 Python
异步任务队列Celery在Django中的使用方法
Jun 07 Python
Python 从列表中取值和取索引的方法
Dec 25 Python
Python使用字典的嵌套功能详解
Feb 27 Python
Python生成器实现简单"生产者消费者"模型代码实例
Mar 27 Python
浅谈PyTorch中in-place operation的含义
Jun 27 Python
pytorch实现线性回归以及多元回归
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
You might like
php实现MD5加密16位(不要默认的32位)
2013/08/12 PHP
PHP进行批量任务处理不超时的解决方法
2016/07/11 PHP
php使用curl获取header检测开启GZip压缩的方法
2018/08/15 PHP
PHP实现笛卡尔积算法的实例讲解
2019/12/22 PHP
统计jQuery中各字符串出现次数的工具
2012/05/03 Javascript
了解一点js的Eval函数
2012/07/26 Javascript
jQuery实现的数值范围range2dslider选取插件特效多款代码分享
2015/08/27 Javascript
原生js实现回复评论功能
2017/01/18 Javascript
Flask中获取小程序Request数据的两种方法
2017/05/12 Javascript
yarn的使用与升级Node.js的方法详解
2017/06/04 Javascript
vue.js评论发布信息可插入QQ表情功能
2017/08/08 Javascript
微信小程序使用request网络请求操作实例
2017/12/15 Javascript
vue2.0自定义指令示例代码详解
2019/04/25 Javascript
JS控制GIF图片的停止与显示
2019/10/24 Javascript
用Python进行一些简单的自然语言处理的教程
2015/03/31 Python
python字符串编码识别模块chardet简单应用
2015/06/15 Python
Python中的Descriptor描述符学习教程
2016/06/02 Python
SQLite3中文编码 Python的实现
2017/01/11 Python
详解python调度框架APScheduler使用
2017/03/28 Python
Python 中的Selenium异常处理实例代码
2018/05/03 Python
Python实现压缩文件夹与解压缩zip文件的方法
2018/09/01 Python
Python3.8中使用f-strings调试
2019/05/22 Python
python实现连续变量最优分箱详解--CART算法
2019/11/22 Python
Python基于stuck实现scoket文件传输
2020/04/02 Python
200行python代码实现贪吃蛇游戏
2020/04/24 Python
Python 爬取淘宝商品信息栏目的实现
2021/02/06 Python
校园文化建设方案
2014/02/03 职场文书
《鹬蚌相争》教学反思
2014/04/22 职场文书
高中毕业典礼演讲稿
2014/09/09 职场文书
国庆横幅标语
2014/10/08 职场文书
考试作弊检讨书范文
2015/01/27 职场文书
毕业典礼邀请函
2015/01/31 职场文书
公司禁烟通知
2015/04/23 职场文书
2019职场单身人才调研报告:互联网行业单身比例最高
2019/08/07 职场文书
Java9新特性对HTTP2协议支持与非阻塞HTTP API
2022/03/16 Java/Android
Docker安装MySql8并远程访问的实现
2022/07/07 Servers