使用pytorch实现线性回归


Posted in Python onApril 11, 2021

线性回归都是包括以下几个步骤:定义模型、选择损失函数、选择优化函数、 训练数据、测试

import torch
import matplotlib.pyplot as plt
# 构建数据集
x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]])
y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0]])
#定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear= torch.nn.Linear(1,1) #表示输入输出都只有一层,相当于前向传播中的函数模型,因为我们一般都不知道函数是什么形式的
 
    def forward(self, x):
        y_pred= self.linear(x)
        return y_pred
model= LinearModel()
# 使用均方误差作为损失函数
criterion= torch.nn.MSELoss(size_average= False)
#使用梯度下降作为优化SGD
# 从下面几种优化器的生成结果图像可以看出,SGD和ASGD效果最好,因为他们的图像收敛速度最快
optimizer= torch.optim.SGD(model.parameters(),lr=0.01)
# ASGD
# optimizer= torch.optim.ASGD(model.parameters(),lr=0.01)
# optimizer= torch.optim.Adagrad(model.parameters(), lr= 0.01)
# optimizer= torch.optim.RMSprop(model.parameters(), lr= 0.01)
# optimizer= torch.optim.Adamax(model.parameters(),lr= 0.01)
# 训练
epoch_list=[]
loss_list=[]
for epoch in range(100):
    y_pred= model(x_data)
    loss= criterion(y_pred, y_data)
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    print(epoch, loss.item())
 
    optimizer.zero_grad() #梯度归零
    loss.backward()  #反向传播
    optimizer.step() #更新参数
 
print("w= ", model.linear.weight.item())
print("b= ",model.linear.bias.item())
 
x_test= torch.Tensor([[7.0]])
y_test= model(x_test)
print("y_pred= ",y_test.data)
 
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()

使用SGD优化器图像:                                                      

使用pytorch实现线性回归

使用ASGD优化器图像:

使用pytorch实现线性回归

使用Adagrad优化器图像:                                                 

使用pytorch实现线性回归

使用Adamax优化器图像:

使用pytorch实现线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中DOM方法的动态性
Apr 11 Python
详解python如何调用C/C++底层库与互相传值
Aug 10 Python
Linux下安装python3.6和第三方库的教程详解
Nov 09 Python
Python 日志logging模块用法简单示例
Oct 18 Python
Python SSL证书验证问题解决方案
Jan 13 Python
Python使用monkey.patch_all()解决协程阻塞问题
Apr 15 Python
使用OpenCV对车道进行实时检测的实现示例代码
Jun 19 Python
Python docutils文档编译过程方法解析
Jun 23 Python
Python获取excel内容及相关操作代码实例
Aug 10 Python
浅析python中的del用法
Sep 02 Python
有关pycharm登录github时有的时候会报错connection reset的问题
Sep 15 Python
python中的被动信息搜集
Apr 29 Python
pytorch实现线性回归以及多元回归
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
You might like
Zerg兵种介绍
2020/03/14 星际争霸
浅谈laravel 5.6 安装 windows上使用composer的安装过程
2019/10/18 PHP
javascript 学习笔记(八)javascript对象
2011/04/12 Javascript
JavaScript根据数据生成百分比图和柱状图的实例代码
2013/07/14 Javascript
js 剪切板的用法(clipboardData.setData)与js match函数介绍
2013/11/19 Javascript
js实现显示当前状态的导航效果代码
2015/08/28 Javascript
jquery+ajax实现注册实时验证实例详解
2015/12/08 Javascript
JS实现的倒计时效果实例(2则实例)
2015/12/23 Javascript
学习JavaScript事件流和事件处理程序
2016/01/25 Javascript
基于jquery实现智能表单验证操作
2016/05/09 Javascript
Javascript函数中的arguments.callee用法实例分析
2016/09/16 Javascript
Node.js的特点详解
2017/02/03 Javascript
javascript常用的设计模式
2017/02/09 Javascript
大白话讲解JavaScript的Promise
2017/04/06 Javascript
Angular 4.x+Ionic3踩坑之Ionic3.x pop反向传值详解
2018/03/13 Javascript
详解小程序开发经验:多页面数据同步
2019/05/18 Javascript
浅谈Vue.set实际上是什么
2019/10/17 Javascript
JS JQuery获取data-*属性值方法解析
2020/09/01 jQuery
python提取字典key列表的方法
2015/07/11 Python
详解Python迭代和迭代器
2016/03/28 Python
python爬虫之BeautifulSoup 使用select方法详解
2017/10/23 Python
基于DataFrame改变列类型的方法
2018/07/25 Python
python+pyqt5实现图片批量缩放工具
2019/03/18 Python
python实现大量图片重命名
2020/03/23 Python
使用Pyhton集合set()实现成果查漏的例子
2019/11/24 Python
python图形界面开发之wxPython树控件使用方法详解
2020/02/24 Python
python3通过udp实现组播数据的发送和接收操作
2020/05/05 Python
如何在Win10系统使用Python3连接Hive
2020/10/15 Python
canvas实现漂亮的下雨效果的示例
2018/04/18 HTML / CSS
你在项目中用到了xml技术的哪些方面?如何实现的?
2014/01/26 面试题
挂职学习心得体会
2014/09/09 职场文书
有限公司股东合作协议书
2014/10/29 职场文书
网络研修随笔感言
2015/11/18 职场文书
六五普法心得体会2016
2016/01/21 职场文书
如何用JS实现简单的数据监听
2021/05/06 Javascript
如何给HttpServletRequest增加消息头
2021/06/30 Java/Android