使用pytorch实现线性回归


Posted in Python onApril 11, 2021

线性回归都是包括以下几个步骤:定义模型、选择损失函数、选择优化函数、 训练数据、测试

import torch
import matplotlib.pyplot as plt
# 构建数据集
x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]])
y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0]])
#定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear= torch.nn.Linear(1,1) #表示输入输出都只有一层,相当于前向传播中的函数模型,因为我们一般都不知道函数是什么形式的
 
    def forward(self, x):
        y_pred= self.linear(x)
        return y_pred
model= LinearModel()
# 使用均方误差作为损失函数
criterion= torch.nn.MSELoss(size_average= False)
#使用梯度下降作为优化SGD
# 从下面几种优化器的生成结果图像可以看出,SGD和ASGD效果最好,因为他们的图像收敛速度最快
optimizer= torch.optim.SGD(model.parameters(),lr=0.01)
# ASGD
# optimizer= torch.optim.ASGD(model.parameters(),lr=0.01)
# optimizer= torch.optim.Adagrad(model.parameters(), lr= 0.01)
# optimizer= torch.optim.RMSprop(model.parameters(), lr= 0.01)
# optimizer= torch.optim.Adamax(model.parameters(),lr= 0.01)
# 训练
epoch_list=[]
loss_list=[]
for epoch in range(100):
    y_pred= model(x_data)
    loss= criterion(y_pred, y_data)
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    print(epoch, loss.item())
 
    optimizer.zero_grad() #梯度归零
    loss.backward()  #反向传播
    optimizer.step() #更新参数
 
print("w= ", model.linear.weight.item())
print("b= ",model.linear.bias.item())
 
x_test= torch.Tensor([[7.0]])
y_test= model(x_test)
print("y_pred= ",y_test.data)
 
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()

使用SGD优化器图像:                                                      

使用pytorch实现线性回归

使用ASGD优化器图像:

使用pytorch实现线性回归

使用Adagrad优化器图像:                                                 

使用pytorch实现线性回归

使用Adamax优化器图像:

使用pytorch实现线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python打开url并按指定块读取网页内容的方法
Apr 29 Python
Python多层嵌套list的递归处理方法(推荐)
Jun 08 Python
Python的Flask框架标配模板引擎Jinja2的使用教程
Jul 12 Python
CentOS 6.5中安装Python 3.6.2的方法步骤
Dec 03 Python
python版百度语音识别功能
Jul 09 Python
Python解决pip install时出现的Could not fetch URL问题
Aug 01 Python
python实现简单图书管理系统
Nov 22 Python
Python for i in range ()用法详解
Sep 18 Python
浅谈Python中的异常和JSON读写数据的实现
Feb 27 Python
python实现猜拳游戏
Mar 04 Python
python cv2.resize函数high和width注意事项说明
Jul 05 Python
Restful_framework视图组件代码实例解析
Nov 17 Python
pytorch实现线性回归以及多元回归
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
You might like
php操作memcache缓存方法分享
2015/06/03 PHP
PHP消息队列用法实例分析
2016/02/12 PHP
php实现图片上传并利用ImageMagick生成缩略图
2016/03/14 PHP
laravel 判断查询数据库返回值的例子
2019/10/11 PHP
javascript下利用arguments实现string.format函数
2010/08/24 Javascript
jQuery中bind,live,delegate与one方法的用法及区别解析
2013/12/30 Javascript
JavaScript实现动态删除列表框值的方法
2015/08/12 Javascript
浅谈javascript的Array.prototype.slice.call
2015/08/31 Javascript
JSONObject使用方法详解
2015/12/17 Javascript
javascript+HTML5自定义元素播放焦点图动画
2016/02/21 Javascript
基于jQuery实现弹出可关闭遮罩提示框实例代码
2016/07/18 Javascript
浅谈js中test()函数在正则中的使用
2016/08/19 Javascript
jQuery与js实现颜色渐变的方法
2016/12/30 Javascript
微信小程序收藏功能的实现代码
2018/06/12 Javascript
vue子传父关于.sync与$emit的实现
2019/11/05 Javascript
vue中 v-for循环的用法详解
2020/02/19 Javascript
JavaScript实现瀑布流布局的3种方式
2020/12/27 Javascript
[01:58]2018DOTA2亚洲邀请赛趣味视频——交流
2018/04/03 DOTA
[01:03:41]完美世界DOTA2联赛PWL S3 DLG vs Phoenix 第一场 12.17
2020/12/19 DOTA
django输出html内容的实例
2018/05/27 Python
Python实现基于POS算法的区块链
2018/08/07 Python
Python实现Mysql数据统计及numpy统计函数
2019/07/15 Python
python海龟绘图之画国旗实例代码
2020/11/11 Python
暇步士官网:Hush Puppies
2016/09/22 全球购物
Stefania Mode美国:奢华设计师和时尚服装
2018/01/07 全球购物
Paul’s Boutique官网:英国时尚手袋品牌
2018/03/31 全球购物
C,C++的几个面试题小集
2013/07/13 面试题
技术人员面试提纲
2013/11/28 职场文书
大学生职业生涯规划书汇总
2014/03/20 职场文书
英文求职信范文
2014/05/23 职场文书
国际贸易系求职信
2014/08/09 职场文书
工作经常出错的检讨书
2014/09/13 职场文书
毕业生就业推荐表自我评价
2015/03/02 职场文书
2015年全国科普日活动总结
2015/03/23 职场文书
药品销售员2015年终工作总结
2015/10/22 职场文书
解决python绘图使用subplots出现标题重叠的问题
2021/04/30 Python