使用pytorch实现线性回归


Posted in Python onApril 11, 2021

线性回归都是包括以下几个步骤:定义模型、选择损失函数、选择优化函数、 训练数据、测试

import torch
import matplotlib.pyplot as plt
# 构建数据集
x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]])
y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0]])
#定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear= torch.nn.Linear(1,1) #表示输入输出都只有一层,相当于前向传播中的函数模型,因为我们一般都不知道函数是什么形式的
 
    def forward(self, x):
        y_pred= self.linear(x)
        return y_pred
model= LinearModel()
# 使用均方误差作为损失函数
criterion= torch.nn.MSELoss(size_average= False)
#使用梯度下降作为优化SGD
# 从下面几种优化器的生成结果图像可以看出,SGD和ASGD效果最好,因为他们的图像收敛速度最快
optimizer= torch.optim.SGD(model.parameters(),lr=0.01)
# ASGD
# optimizer= torch.optim.ASGD(model.parameters(),lr=0.01)
# optimizer= torch.optim.Adagrad(model.parameters(), lr= 0.01)
# optimizer= torch.optim.RMSprop(model.parameters(), lr= 0.01)
# optimizer= torch.optim.Adamax(model.parameters(),lr= 0.01)
# 训练
epoch_list=[]
loss_list=[]
for epoch in range(100):
    y_pred= model(x_data)
    loss= criterion(y_pred, y_data)
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    print(epoch, loss.item())
 
    optimizer.zero_grad() #梯度归零
    loss.backward()  #反向传播
    optimizer.step() #更新参数
 
print("w= ", model.linear.weight.item())
print("b= ",model.linear.bias.item())
 
x_test= torch.Tensor([[7.0]])
y_test= model(x_test)
print("y_pred= ",y_test.data)
 
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()

使用SGD优化器图像:                                                      

使用pytorch实现线性回归

使用ASGD优化器图像:

使用pytorch实现线性回归

使用Adagrad优化器图像:                                                 

使用pytorch实现线性回归

使用Adamax优化器图像:

使用pytorch实现线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python编写基于DHT协议的BT资源爬虫
Mar 19 Python
python xml.etree.ElementTree遍历xml所有节点实例详解
Dec 04 Python
python 中的divmod数字处理函数浅析
Oct 17 Python
Python常见工厂函数用法示例
Mar 21 Python
numpy向空的二维数组中添加元素的方法
Nov 01 Python
Numpy截取指定范围内的数据方法
Nov 14 Python
Python3爬虫爬取英雄联盟高清桌面壁纸功能示例【基于Scrapy框架】
Dec 05 Python
Pytorch释放显存占用方式
Jan 13 Python
基于python监控程序是否关闭
Jan 14 Python
Python random模块制作简易的四位数验证码
Feb 01 Python
keras 回调函数Callbacks 断点ModelCheckpoint教程
Jun 18 Python
python爬取新闻门户网站的示例
Apr 25 Python
pytorch实现线性回归以及多元回归
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
You might like
PHP错误Warning: Cannot modify header information - headers already sent by解决方法
2014/09/27 PHP
ThinkPHP内置jsonRPC的缺陷分析
2014/12/18 PHP
php生成验证码函数
2015/10/20 PHP
Joomla简单判断用户是否登录的方法
2016/05/04 PHP
PHP中功能强大却很少使用的函数实例小结
2016/11/10 PHP
thinkPHP5框架设置404、403等http状态页面的方法
2018/06/05 PHP
PHP+jQuery实现即点即改功能示例
2019/02/21 PHP
PHP 数组操作详解【遍历、指针、函数等】
2020/05/13 PHP
了解jQuery技巧来提高你的代码
2010/01/08 Javascript
js中匿名函数的N种写法
2010/09/08 Javascript
js jquery分别实现动态的文件上传操作按钮的添加和删除
2014/01/13 Javascript
JS烟花背景效果实现方法
2015/03/03 Javascript
详解JavaScript中的构造器Constructor模式
2016/01/14 Javascript
JavaScript和jQuery制作光棒效果
2017/02/24 Javascript
jQuery结合jQuery.cookie.js插件实现换肤功能示例
2017/10/14 jQuery
vue路由跳转时判断用户是否登录功能的实现
2017/10/26 Javascript
基于Vuejs的搜索匹配功能实现方法
2018/03/03 Javascript
解决Vue使用swiper动态加载数据,动态轮播数据显示白屏的问题
2018/09/27 Javascript
微信小程序云开发实现增删改查功能
2019/05/17 Javascript
20个必会的JavaScript面试题(小结)
2019/07/02 Javascript
vue中的过滤器及其时间格式化问题
2020/04/09 Javascript
详解 javascript对象创建模式
2020/10/30 Javascript
Python os.rename() 重命名目录和文件的示例
2018/10/25 Python
Python使用Pandas对csv文件进行数据处理的方法
2019/08/01 Python
python实现多进程通信实例分析
2019/09/01 Python
python多进程并行代码实例
2019/09/30 Python
唤醒头发毛囊的秘密武器:Grow Gorgeous
2016/08/28 全球购物
详细的大学生创业计划书模板
2014/01/27 职场文书
联谊会主持词
2014/03/26 职场文书
负责人任命书范本
2014/06/04 职场文书
学校纪律作风整改措施思想汇报
2014/10/11 职场文书
民主评议党员工作总结
2014/10/20 职场文书
先进个人材料怎么写
2014/12/30 职场文书
导游词之广东佛山(南风古灶)
2019/09/24 职场文书
Python Django模型详解
2021/10/05 Python
「地球外少年少女」BD发售宣传CM公开
2022/03/21 日漫