使用pytorch实现线性回归


Posted in Python onApril 11, 2021

线性回归都是包括以下几个步骤:定义模型、选择损失函数、选择优化函数、 训练数据、测试

import torch
import matplotlib.pyplot as plt
# 构建数据集
x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]])
y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0]])
#定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear= torch.nn.Linear(1,1) #表示输入输出都只有一层,相当于前向传播中的函数模型,因为我们一般都不知道函数是什么形式的
 
    def forward(self, x):
        y_pred= self.linear(x)
        return y_pred
model= LinearModel()
# 使用均方误差作为损失函数
criterion= torch.nn.MSELoss(size_average= False)
#使用梯度下降作为优化SGD
# 从下面几种优化器的生成结果图像可以看出,SGD和ASGD效果最好,因为他们的图像收敛速度最快
optimizer= torch.optim.SGD(model.parameters(),lr=0.01)
# ASGD
# optimizer= torch.optim.ASGD(model.parameters(),lr=0.01)
# optimizer= torch.optim.Adagrad(model.parameters(), lr= 0.01)
# optimizer= torch.optim.RMSprop(model.parameters(), lr= 0.01)
# optimizer= torch.optim.Adamax(model.parameters(),lr= 0.01)
# 训练
epoch_list=[]
loss_list=[]
for epoch in range(100):
    y_pred= model(x_data)
    loss= criterion(y_pred, y_data)
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    print(epoch, loss.item())
 
    optimizer.zero_grad() #梯度归零
    loss.backward()  #反向传播
    optimizer.step() #更新参数
 
print("w= ", model.linear.weight.item())
print("b= ",model.linear.bias.item())
 
x_test= torch.Tensor([[7.0]])
y_test= model(x_test)
print("y_pred= ",y_test.data)
 
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()

使用SGD优化器图像:                                                      

使用pytorch实现线性回归

使用ASGD优化器图像:

使用pytorch实现线性回归

使用Adagrad优化器图像:                                                 

使用pytorch实现线性回归

使用Adamax优化器图像:

使用pytorch实现线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3.3实现乘法表示例
Feb 07 Python
Python Web框架Flask中使用新浪SAE云存储实例
Feb 08 Python
Python中不同进制的语法及转换方法分析
Jul 27 Python
Python实现翻转数组功能示例
Jan 12 Python
Python基于递归实现电话号码映射功能示例
Apr 13 Python
使用pandas批量处理矢量化字符串的实例讲解
Jul 10 Python
python超时重新请求解决方案
Oct 21 Python
Python实现猜年龄游戏代码实例
Mar 25 Python
Django nginx配置实现过程详解
Sep 10 Python
python中pyplot基础图标函数整理
Nov 10 Python
Python使用Opencv实现边缘检测以及轮廓检测的实现
Dec 31 Python
python数字图像处理:图像的绘制
Jun 28 Python
pytorch实现线性回归以及多元回归
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
You might like
关于php程序报date()警告的处理(date_default_timezone_set)
2013/10/22 PHP
thinkPHP基于ajax实现的菜单与分页示例
2016/07/12 PHP
PHP 信号管理知识整理汇总
2017/02/19 PHP
PHP设计模式之单例模式原理与实现方法分析
2018/04/25 PHP
Extjs显示从数据库取出时间转换JSON后的出现问题
2012/11/20 Javascript
JavaScript设置首页和收藏页面的小例子
2013/11/11 Javascript
firefox下jquery ajax返回object XMLDocument处理方法
2014/01/26 Javascript
22点关于jquery性能优化的建议
2014/05/28 Javascript
js中直接声明一个对象的方法
2014/08/10 Javascript
jquery插件推荐浏览器嗅探userAgent
2014/11/09 Javascript
php常见的页面跳转方法汇总
2015/04/15 Javascript
JavaScript中的函数嵌套使用
2015/06/04 Javascript
jQuery插件animateSlide制作多点滑动幻灯片
2015/06/11 Javascript
jQuery EasyUI Dialog拖不下来如何解决
2015/09/28 Javascript
深入理解jquery跨域请求方法
2016/05/18 Javascript
node.js路径处理方法以及绝对路径详解
2021/03/04 Javascript
微信小程序 动态传参实例详解
2017/04/27 Javascript
详解vue2父组件传递props异步数据到子组件的问题
2017/06/29 Javascript
Angular 封装并发布组件的方法示例
2018/04/19 Javascript
用Node编写RESTful API接口的示例代码
2018/07/04 Javascript
在vue中使用echarts(折线图的demo,markline用法)
2020/07/20 Javascript
针对Vue路由history模式下Nginx后台配置操作
2020/10/22 Javascript
Python  __getattr__与__setattr__使用方法
2008/09/06 Python
pycharm 使用心得(七)一些实用功能介绍
2014/06/06 Python
详解 Python中LEGB和闭包及装饰器
2017/08/03 Python
Python设计模式之享元模式原理与用法实例分析
2019/01/11 Python
python绘制雪景图
2019/12/16 Python
用Python绘制漫步图实例讲解
2020/02/26 Python
Django模型中字段属性choice使用说明
2020/03/30 Python
html5 figure和figcaption的使用方法
2018/09/10 HTML / CSS
ProBikeKit英国:在线公路自行车之家
2017/02/10 全球购物
阿联酋航空丹麦官方网站:Emirates DK
2019/08/25 全球购物
社区文明创建工作总结2015
2015/04/21 职场文书
pandas提升计算效率的一些方法汇总
2021/05/30 Python
苹果可能正在打击不进行更新的 App
2022/04/24 数码科技
Springboot中如何自动转JSON输出
2022/06/16 Java/Android