pytorch实现线性回归以及多元回归


Posted in Python onApril 11, 2021

本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下

最近在学习pytorch,现在把学习的代码放在这里,下面是github链接

直接附上github代码

# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
 
# 首先我们需要将array转化成tensor,因为pytorch处理的单元是Tensor
 
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
 
 
# def a simple network
 
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 2_dimension
    def forward(self, x):
        out = self.linear(x)
        return out
 
 
if torch.cuda.is_available():
    model = LinearRegression().cuda()
    #model = model.cuda()
else:
    model = LinearRegression()
    #model = model.cuda()
 
# 定义loss function 和 optimize func
criterion = nn.MSELoss()   # 均方误差作为优化函数
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)
num_epochs = 30000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)
 
    # forward
    out = model(inputs)
    loss = criterion(out,outputs)
 
    # backword
    optimizer.zero_grad()  # 每次做反向传播之前都要进行归零梯度。不然梯度会累加在一起,造成不收敛的结果
    loss.backward()
    optimizer.step()
 
    if (epoch +1)%20==0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data))
 
 
model.eval()  # 将模型变成测试模式
predict = model(Variable(x_train).cuda())
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting line')
plt.show()

结果如图所示:

pytorch实现线性回归以及多元回归

多元回归:

# _*_encoding=utf-8_*_
# pytorch 里面最基本的操作对象是Tensor,pytorch 的tensor可以和numpy的ndarray相互转化。
# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
 
# 实现 y = b + w1 *x + w2 *x**2 +w3*x**3
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nn
 
 
# pre_processing
def make_feature(x):
    x = x.unsqueeze(1)   # unsquenze 是为了添加维度1的,0表示第一维度,1表示第二维度,将tensor大小由3变为(3,1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
 
# 定义好真实的数据
 
 
def f(x):
    W_output = torch.Tensor([0.5, 3, 2.4]).unsqueeze(1)
    b_output = torch.Tensor([0.9])
    return x.mm(W_output)+b_output[0]  # 外积,矩阵乘法
 
 
# 批量处理数据
def get_batch(batch_size =32):
 
    random = torch.randn(batch_size)
    x = make_feature(random)
    y = f(x)
    if torch.cuda.is_available():
 
        return Variable(x).cuda(),Variable(y).cuda()
    else:
        return Variable(x),Variable(y)
 
 
 
# def model
class poly_model(nn.Module):
    def __init__(self):
        super(poly_model,self).__init__()
        self.poly = nn.Linear(3,1)
    def forward(self,input):
        output = self.poly(input)
        return output
 
if torch.cuda.is_available():
    print("sdf")
    model = poly_model().cuda()
else:
    model = poly_model()
 
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    #print(batch_x)
    output = model(batch_x)
    loss = criterion(output,batch_y)
    print_loss = loss.data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch = epoch +1
    if print_loss < 1e-3:
        print(print_loss)
        break
 
model.eval()
print("Epoch = {}".format(epoch))
 
batch_x, batch_y = get_batch()
predict = model(batch_x)
a = predict - batch_y
y = torch.sum(a)
print('y = ',y)
predict = predict.data.cpu().numpy()
plt.plot(batch_x.cpu().numpy(),batch_y.cpu().numpy(),'ro',label = 'Original data')
plt.plot(batch_x.cpu().numpy(),predict,'b', ls='--',label = 'Fitting line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自动化测试之从命令行运行测试用例with verbosity
Sep 28 Python
在Python3中初学者应会的一些基本的提升效率的小技巧
Mar 31 Python
Python的装饰器使用详解
Jun 26 Python
Python字符串和字典相关操作的实例详解
Sep 23 Python
Python数据类型之List列表实例详解
May 08 Python
python用for循环求和的方法总结
Jul 08 Python
python选取特定列 pandas iloc,loc,icol的使用详解(列切片及行切片)
Aug 06 Python
关于PyTorch 自动求导机制详解
Aug 18 Python
python爬虫 猫眼电影和电影天堂数据csv和mysql存储过程解析
Sep 05 Python
python计算导数并绘图的实例
Feb 29 Python
Python调用OpenCV实现图像平滑代码实例
Jun 19 Python
pip已经安装好第三方库但pycharm中import时还是标红的解决方案
Oct 09 Python
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
You might like
重置版游戏视频
2020/04/09 魔兽争霸
PHP 最大运行时间 max_execution_time修改方法
2010/03/08 PHP
如何用PHP实现插入排序?
2013/04/10 PHP
thinkphp3查询mssql数据库乱码解决方法分享
2014/02/11 PHP
PHP魔术引号所带来的安全问题分析
2014/07/15 PHP
Yii框架上传图片用法总结
2016/03/28 PHP
JS图片无缝滚动(简单利于使用)
2013/06/17 Javascript
原生javascript和jquery判断浏览器版本等信息
2013/07/04 Javascript
基于jQuery实现仿淘宝套餐选择插件
2015/03/04 Javascript
jQuery实现简单的抽奖游戏
2017/05/05 jQuery
解决vue里碰到 $refs 的问题的方法
2017/07/13 Javascript
Bootstrap table使用方法记录
2017/08/23 Javascript
vue.js分页中单击页码更换页面内容的方法(配合spring springmvc)
2018/02/10 Javascript
JS去除字符串最后的逗号实例分析【四种方法】
2019/06/20 Javascript
浅谈layui 数据表格前后台传值的问题
2019/09/12 Javascript
vue用BMap百度地图实现即时搜索功能
2019/09/26 Javascript
Vue+axios封装请求实现前后端分离
2020/10/23 Javascript
Vue使用CDN引用项目组件,减少项目体积的步骤
2020/10/30 Javascript
详解如何在vue+element-ui的项目中封装dialog组件
2020/12/11 Vue.js
python matplotlib库绘制散点图例题解析
2019/08/10 Python
Python3查找列表中重复元素的个数的3种方法详解
2020/02/13 Python
利用Python如何实时检测自身内存占用
2020/05/09 Python
KIKO MILANO西班牙官网:意大利领先的化妆品和护肤品品牌
2019/05/03 全球购物
Vuori官网:运动服装的终级表现
2021/01/27 全球购物
营业员个人总结的自我评价
2013/10/25 职场文书
美德少年事迹材料
2014/01/23 职场文书
银行服务明星推荐材料
2014/05/29 职场文书
2014第二批党员干部对照“四风”找差距检查材料思想汇报
2014/09/18 职场文书
财务工作检讨书
2014/10/29 职场文书
整改报告格式
2014/11/06 职场文书
门店店长岗位职责
2015/04/14 职场文书
工程项目合作意向书
2015/05/08 职场文书
《夸父追日》教学反思
2016/02/20 职场文书
使用python如何删除同一文件夹下相似的图片
2021/05/07 Python
详解Python常用的魔法方法
2021/06/03 Python
解决 Redis 秒杀超卖场景的高并发
2022/04/12 Redis