使用pytorch实现线性回归


Posted in Python onApril 11, 2021

线性回归都是包括以下几个步骤:定义模型、选择损失函数、选择优化函数、 训练数据、测试

import torch
import matplotlib.pyplot as plt
# 构建数据集
x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]])
y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0]])
#定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear= torch.nn.Linear(1,1) #表示输入输出都只有一层,相当于前向传播中的函数模型,因为我们一般都不知道函数是什么形式的
 
    def forward(self, x):
        y_pred= self.linear(x)
        return y_pred
model= LinearModel()
# 使用均方误差作为损失函数
criterion= torch.nn.MSELoss(size_average= False)
#使用梯度下降作为优化SGD
# 从下面几种优化器的生成结果图像可以看出,SGD和ASGD效果最好,因为他们的图像收敛速度最快
optimizer= torch.optim.SGD(model.parameters(),lr=0.01)
# ASGD
# optimizer= torch.optim.ASGD(model.parameters(),lr=0.01)
# optimizer= torch.optim.Adagrad(model.parameters(), lr= 0.01)
# optimizer= torch.optim.RMSprop(model.parameters(), lr= 0.01)
# optimizer= torch.optim.Adamax(model.parameters(),lr= 0.01)
# 训练
epoch_list=[]
loss_list=[]
for epoch in range(100):
    y_pred= model(x_data)
    loss= criterion(y_pred, y_data)
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    print(epoch, loss.item())
 
    optimizer.zero_grad() #梯度归零
    loss.backward()  #反向传播
    optimizer.step() #更新参数
 
print("w= ", model.linear.weight.item())
print("b= ",model.linear.bias.item())
 
x_test= torch.Tensor([[7.0]])
y_test= model(x_test)
print("y_pred= ",y_test.data)
 
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()

使用SGD优化器图像:                                                      

使用pytorch实现线性回归

使用ASGD优化器图像:

使用pytorch实现线性回归

使用Adagrad优化器图像:                                                 

使用pytorch实现线性回归

使用Adamax优化器图像:

使用pytorch实现线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之网站的结构
Oct 24 Python
python验证码识别的实例详解
Sep 09 Python
python 循环读取txt文档 并转换成csv的方法
Oct 26 Python
python面试题小结附答案实例代码
Apr 11 Python
Pyinstaller 打包exe教程及问题解决
Aug 16 Python
Django项目之Elasticsearch搜索引擎的实例
Aug 21 Python
python字符串下标与切片及使用方法
Feb 13 Python
Python print不能立即打印的解决方式
Feb 19 Python
python使用PIL剪切和拼接图片
Mar 23 Python
Python并发请求下限制QPS(每秒查询率)的实现代码
Jun 05 Python
pytorch交叉熵损失函数的weight参数的使用
May 24 Python
Python加密与解密模块hashlib与hmac
Jun 05 Python
pytorch实现线性回归以及多元回归
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
You might like
php从右向左/从左向右截取字符串的实现方法
2011/11/28 PHP
thinkphp中多表查询中防止数据重复的sql语句(必看)
2016/09/22 PHP
CakePHP框架Session设置方法分析
2017/02/23 PHP
window.location.hash 使用说明
2010/11/08 Javascript
基于jQuery的Tab选项框效果代码(插件)
2011/03/01 Javascript
使用JavaScript获取电池状态的方法
2014/05/03 Javascript
浅谈javascript运算符——条件,逗号,赋值,()和void运算符
2016/07/15 Javascript
ES6新数据结构Map功能与用法示例
2017/03/31 Javascript
解决jQuery ajax动态新增节点无法触发点击事件的问题
2017/05/24 jQuery
vue-cli整合vuex的时候,修改actions和mutations,实现热部署的方法
2018/09/19 Javascript
3分钟读懂移动端rem使用方法(推荐)
2019/05/06 Javascript
Vue中通过vue-router实现命名视图的问题
2020/04/23 Javascript
vue + el-form 实现的多层循环表单验证
2020/11/25 Vue.js
JavaScript 获取滚动条位置并将页面滑动到锚点
2021/02/08 Javascript
Django实现图片文字同时提交的方法
2015/05/26 Python
Python实现LRU算法的2种方法
2015/06/24 Python
Python制作数据导入导出工具
2015/07/31 Python
Python爬取网页中的图片(搜狗图片)详解
2017/03/23 Python
Python使用Selenium+BeautifulSoup爬取淘宝搜索页
2018/02/24 Python
Python简单获取网卡名称及其IP地址的方法【基于psutil模块】
2018/05/24 Python
Python3单行定义多个变量或赋值方法
2018/07/12 Python
基于sklearn实现Bagging算法(python)
2019/07/11 Python
使用Keras中的ImageDataGenerator进行批次读图方式
2020/06/17 Python
python怎么判断模块安装完成
2020/06/19 Python
Django Admin 上传文件到七牛云的示例代码
2020/06/20 Python
怎么解决pycharm license Acti的方法
2020/10/28 Python
html5 更新图片颜色示例代码
2014/07/29 HTML / CSS
HTML5 canvas基本绘图之图形变换
2016/06/27 HTML / CSS
ECCO爱步官方旗舰店:丹麦鞋履品牌
2018/01/02 全球购物
超市端午节活动方案
2014/01/23 职场文书
初中班主任评语
2014/04/24 职场文书
一份没有按时交货失信于客户的检讨书
2014/09/19 职场文书
房产协议书范本
2014/10/18 职场文书
学校食品安全责任书
2015/01/29 职场文书
2015年学校图书室工作总结
2015/05/19 职场文书
Python装饰器的练习题
2021/11/23 Python