使用pytorch实现线性回归


Posted in Python onApril 11, 2021

线性回归都是包括以下几个步骤:定义模型、选择损失函数、选择优化函数、 训练数据、测试

import torch
import matplotlib.pyplot as plt
# 构建数据集
x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]])
y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0]])
#定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear= torch.nn.Linear(1,1) #表示输入输出都只有一层,相当于前向传播中的函数模型,因为我们一般都不知道函数是什么形式的
 
    def forward(self, x):
        y_pred= self.linear(x)
        return y_pred
model= LinearModel()
# 使用均方误差作为损失函数
criterion= torch.nn.MSELoss(size_average= False)
#使用梯度下降作为优化SGD
# 从下面几种优化器的生成结果图像可以看出,SGD和ASGD效果最好,因为他们的图像收敛速度最快
optimizer= torch.optim.SGD(model.parameters(),lr=0.01)
# ASGD
# optimizer= torch.optim.ASGD(model.parameters(),lr=0.01)
# optimizer= torch.optim.Adagrad(model.parameters(), lr= 0.01)
# optimizer= torch.optim.RMSprop(model.parameters(), lr= 0.01)
# optimizer= torch.optim.Adamax(model.parameters(),lr= 0.01)
# 训练
epoch_list=[]
loss_list=[]
for epoch in range(100):
    y_pred= model(x_data)
    loss= criterion(y_pred, y_data)
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    print(epoch, loss.item())
 
    optimizer.zero_grad() #梯度归零
    loss.backward()  #反向传播
    optimizer.step() #更新参数
 
print("w= ", model.linear.weight.item())
print("b= ",model.linear.bias.item())
 
x_test= torch.Tensor([[7.0]])
y_test= model(x_test)
print("y_pred= ",y_test.data)
 
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()

使用SGD优化器图像:                                                      

使用pytorch实现线性回归

使用ASGD优化器图像:

使用pytorch实现线性回归

使用Adagrad优化器图像:                                                 

使用pytorch实现线性回归

使用Adamax优化器图像:

使用pytorch实现线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3之微信文章爬虫实例讲解
Jul 12 Python
python使用opencv按一定间隔截取视频帧
Mar 06 Python
python实现跨excel的工作表sheet之间的复制方法
May 03 Python
python实现flappy bird小游戏
Dec 24 Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 Python
python实现代码统计程序
Sep 19 Python
Pytorch释放显存占用方式
Jan 13 Python
python3实现网页版raspberry pi(树莓派)小车控制
Feb 12 Python
Python 去除字符串中指定字符串
Mar 05 Python
Python通过递归函数输出嵌套列表元素
Oct 15 Python
python Tornado框架的使用示例
Oct 19 Python
tensorflow中的梯度求解及梯度裁剪操作
May 26 Python
pytorch实现线性回归以及多元回归
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
You might like
php 过滤危险html代码
2009/06/29 PHP
php实现每天自动变换随机问候语的方法
2015/05/12 PHP
EarthLiveSharp中cloudinary的CDN图片缓存自动清理python脚本
2017/04/04 PHP
JavaScript中的连字符详解
2013/11/28 Javascript
js取得html iframe中的元素和变量值
2014/06/30 Javascript
JS简单获取及显示当前时间的方法
2016/08/03 Javascript
浅谈js中startsWith 函数不能在任何浏览器兼容的问题
2017/03/01 Javascript
JS基于正则表达式的替换操作(replace)用法示例
2017/04/28 Javascript
AngularJS  ng-repeat遍历输出的用法
2017/06/19 Javascript
详解vue-cli快速构建vue应用并实现webpack打包
2017/12/13 Javascript
分析JavaScript数组操作难点
2017/12/18 Javascript
javascript连接mysql与php通过odbc连接任意数据库的实例
2017/12/27 Javascript
解决Vue-cli npm run build生产环境打包,本地不能打开的问题
2018/09/20 Javascript
JavaScript中如何对多维数组(矩阵)去重的实现
2019/12/04 Javascript
python回调函数用法实例分析
2015/05/09 Python
Django实现图片文字同时提交的方法
2015/05/26 Python
pycharm远程调试openstack的图文教程
2017/11/21 Python
python 爬虫 批量获取代理ip的实例代码
2018/05/22 Python
python 搜索大文件的实例代码
2019/07/08 Python
selenium中get_cookies()和add_cookie()的用法详解
2020/01/06 Python
python调用有道智云API实现文件批量翻译
2020/10/10 Python
python多线程爬取西刺代理的示例代码
2021/01/30 Python
用CSS3的box-reflect设置文字倒影效果的方法讲解
2016/03/07 HTML / CSS
美国办公用品购物网站:Quill.com
2016/09/01 全球购物
美国领先的男士和女士内衣购物网站:Freshpair
2019/02/25 全球购物
英国办公家具网站:Furniture At Work
2019/10/07 全球购物
是什么让J2EE适合用来开发多层的分布式的应用
2015/01/16 面试题
2013年保送生自荐信格式
2013/11/20 职场文书
毕业生的自我评价分享
2013/12/18 职场文书
区域销售经理职责
2013/12/22 职场文书
2014年团员学习十八大思想汇报
2014/09/13 职场文书
党的群众路线教育实践活动心得体会(医院)
2014/11/03 职场文书
编写python程序的90条建议
2021/04/14 Python
详细分析PHP7与PHP5区别
2021/06/26 PHP
详解nginx location指令
2022/01/18 Servers
Python中的 enumerate和zip详情
2022/05/30 Python