使用pytorch实现线性回归


Posted in Python onApril 11, 2021

线性回归都是包括以下几个步骤:定义模型、选择损失函数、选择优化函数、 训练数据、测试

import torch
import matplotlib.pyplot as plt
# 构建数据集
x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]])
y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0]])
#定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear= torch.nn.Linear(1,1) #表示输入输出都只有一层,相当于前向传播中的函数模型,因为我们一般都不知道函数是什么形式的
 
    def forward(self, x):
        y_pred= self.linear(x)
        return y_pred
model= LinearModel()
# 使用均方误差作为损失函数
criterion= torch.nn.MSELoss(size_average= False)
#使用梯度下降作为优化SGD
# 从下面几种优化器的生成结果图像可以看出,SGD和ASGD效果最好,因为他们的图像收敛速度最快
optimizer= torch.optim.SGD(model.parameters(),lr=0.01)
# ASGD
# optimizer= torch.optim.ASGD(model.parameters(),lr=0.01)
# optimizer= torch.optim.Adagrad(model.parameters(), lr= 0.01)
# optimizer= torch.optim.RMSprop(model.parameters(), lr= 0.01)
# optimizer= torch.optim.Adamax(model.parameters(),lr= 0.01)
# 训练
epoch_list=[]
loss_list=[]
for epoch in range(100):
    y_pred= model(x_data)
    loss= criterion(y_pred, y_data)
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    print(epoch, loss.item())
 
    optimizer.zero_grad() #梯度归零
    loss.backward()  #反向传播
    optimizer.step() #更新参数
 
print("w= ", model.linear.weight.item())
print("b= ",model.linear.bias.item())
 
x_test= torch.Tensor([[7.0]])
y_test= model(x_test)
print("y_pred= ",y_test.data)
 
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()

使用SGD优化器图像:                                                      

使用pytorch实现线性回归

使用ASGD优化器图像:

使用pytorch实现线性回归

使用Adagrad优化器图像:                                                 

使用pytorch实现线性回归

使用Adamax优化器图像:

使用pytorch实现线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用fileinput模块实现逐行读取文件的方法
Apr 29 Python
Python2.x版本中maketrans()方法的使用介绍
May 19 Python
django 2.0更新的10条注意事项总结
Jan 05 Python
详解用python生成随机数的几种方法
Aug 04 Python
有关Tensorflow梯度下降常用的优化方法分享
Feb 04 Python
python numpy库linspace相同间隔采样的实现
Feb 25 Python
Python selenium抓取虎牙短视频代码实例
Mar 02 Python
python 实现任务管理清单案例
Apr 25 Python
python语言的优势是什么
Jun 17 Python
python属于软件吗
Jun 18 Python
python压包的概念及实例详解
Feb 17 Python
Python采集爬取京东商品信息和评论并存入MySQL
Apr 12 Python
pytorch实现线性回归以及多元回归
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
You might like
攻克CakePHP系列三 表单数据增删改
2008/10/22 PHP
PHP命名空间namespace用法实例分析
2016/09/27 PHP
JAVASCRIPT 对象的创建与使用
2021/03/09 Javascript
offsetParent 算法分析
2010/04/05 Javascript
Fastest way to build an HTML string(拼装html字符串的最快方法)
2011/08/20 Javascript
JS冒泡事件的快速解决方法
2013/12/16 Javascript
解决JS中乘法的浮点错误的方法
2014/01/03 Javascript
JavaScript中getUTCSeconds()方法的使用详解
2015/06/11 Javascript
canvas 弹幕效果(实例分享)
2017/01/11 Javascript
Angular.js中ng-if、ng-show和ng-hide的区别介绍
2017/01/20 Javascript
javascript 面向对象function详解及实例代码
2017/02/28 Javascript
Angular 4.x 路由快速入门学习
2017/05/03 Javascript
jQuery中ajax获取数据赋值给页面的实例
2017/12/31 jQuery
Vue.js 中 axios 跨域访问错误问题及解决方法
2018/11/21 Javascript
JS Ajax请求会话过期处理问题解决方法分析
2019/11/16 Javascript
[03:07]DOTA2英雄基础教程 冰霜诅咒极寒幽魂
2013/12/06 DOTA
python模拟登录百度代码分享(获取百度贴吧等级)
2013/12/27 Python
Python中endswith()函数的基本使用
2015/04/07 Python
Python编程中的异常处理教程
2015/08/21 Python
python3.6+django2.0+mysql搭建网站过程详解
2019/07/24 Python
Python装饰器使用你可能不知道的几种姿势
2019/10/25 Python
深入剖析HTML5 内联框架iFrame
2016/05/04 HTML / CSS
应届毕业生个人自我评价
2013/09/20 职场文书
甜品店的创业计划书范文
2014/01/02 职场文书
学校岗位设置方案
2014/01/16 职场文书
交通事故私了协议书
2014/04/16 职场文书
给校长的建议书200字
2014/05/16 职场文书
俞敏洪北大演讲稿
2014/05/22 职场文书
酒店开业策划方案
2014/06/02 职场文书
老兵退伍标语
2014/10/07 职场文书
2015年挂职锻炼工作总结
2014/12/12 职场文书
语文复习计划
2015/01/19 职场文书
大学生毕业个人总结
2015/02/15 职场文书
党小组考察意见
2015/06/02 职场文书
企业宣传稿范文
2015/07/23 职场文书
教师反邪教心得体会
2016/01/15 职场文书