pytorch实现线性回归以及多元回归


Posted in Python onApril 11, 2021

本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下

最近在学习pytorch,现在把学习的代码放在这里,下面是github链接

直接附上github代码

# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
 
# 首先我们需要将array转化成tensor,因为pytorch处理的单元是Tensor
 
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
 
 
# def a simple network
 
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 2_dimension
    def forward(self, x):
        out = self.linear(x)
        return out
 
 
if torch.cuda.is_available():
    model = LinearRegression().cuda()
    #model = model.cuda()
else:
    model = LinearRegression()
    #model = model.cuda()
 
# 定义loss function 和 optimize func
criterion = nn.MSELoss()   # 均方误差作为优化函数
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)
num_epochs = 30000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)
 
    # forward
    out = model(inputs)
    loss = criterion(out,outputs)
 
    # backword
    optimizer.zero_grad()  # 每次做反向传播之前都要进行归零梯度。不然梯度会累加在一起,造成不收敛的结果
    loss.backward()
    optimizer.step()
 
    if (epoch +1)%20==0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data))
 
 
model.eval()  # 将模型变成测试模式
predict = model(Variable(x_train).cuda())
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting line')
plt.show()

结果如图所示:

pytorch实现线性回归以及多元回归

多元回归:

# _*_encoding=utf-8_*_
# pytorch 里面最基本的操作对象是Tensor,pytorch 的tensor可以和numpy的ndarray相互转化。
# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
 
# 实现 y = b + w1 *x + w2 *x**2 +w3*x**3
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nn
 
 
# pre_processing
def make_feature(x):
    x = x.unsqueeze(1)   # unsquenze 是为了添加维度1的,0表示第一维度,1表示第二维度,将tensor大小由3变为(3,1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
 
# 定义好真实的数据
 
 
def f(x):
    W_output = torch.Tensor([0.5, 3, 2.4]).unsqueeze(1)
    b_output = torch.Tensor([0.9])
    return x.mm(W_output)+b_output[0]  # 外积,矩阵乘法
 
 
# 批量处理数据
def get_batch(batch_size =32):
 
    random = torch.randn(batch_size)
    x = make_feature(random)
    y = f(x)
    if torch.cuda.is_available():
 
        return Variable(x).cuda(),Variable(y).cuda()
    else:
        return Variable(x),Variable(y)
 
 
 
# def model
class poly_model(nn.Module):
    def __init__(self):
        super(poly_model,self).__init__()
        self.poly = nn.Linear(3,1)
    def forward(self,input):
        output = self.poly(input)
        return output
 
if torch.cuda.is_available():
    print("sdf")
    model = poly_model().cuda()
else:
    model = poly_model()
 
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    #print(batch_x)
    output = model(batch_x)
    loss = criterion(output,batch_y)
    print_loss = loss.data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch = epoch +1
    if print_loss < 1e-3:
        print(print_loss)
        break
 
model.eval()
print("Epoch = {}".format(epoch))
 
batch_x, batch_y = get_batch()
predict = model(batch_x)
a = predict - batch_y
y = torch.sum(a)
print('y = ',y)
predict = predict.data.cpu().numpy()
plt.plot(batch_x.cpu().numpy(),batch_y.cpu().numpy(),'ro',label = 'Original data')
plt.plot(batch_x.cpu().numpy(),predict,'b', ls='--',label = 'Fitting line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python在linux中输出带颜色的文字的方法
Jun 19 Python
11个并不被常用但对开发非常有帮助的Python库
Mar 31 Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 Python
Python随机函数库random的使用方法详解
Aug 21 Python
python super的使用方法及实例详解
Sep 25 Python
Python collections中的双向队列deque简单介绍详解
Nov 04 Python
关于tensorflow的几种参数初始化方法小结
Jan 04 Python
Python Scrapy框架第一个入门程序示例
Feb 05 Python
Python3 shutil(高级文件操作模块)实例用法总结
Feb 19 Python
详解Flask前后端分离项目案例
Jul 24 Python
Django返回HTML文件的实现方法
Sep 17 Python
python套接字socket通信
Apr 01 Python
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
You might like
yii框架builder、update、delete使用方法
2014/04/30 PHP
深入分析PHP引用(&amp;)
2014/09/04 PHP
php中smarty区域循环的方法
2015/06/11 PHP
教大家制作简单的php日历
2015/11/17 PHP
轻松掌握php设计模式之访问者模式
2016/09/23 PHP
THinkPHP获取客户端IP与IP地址查询的方法
2016/11/14 PHP
PHP实现根据数组某个键值大小进行排序的方法
2018/03/13 PHP
php ajax confirm 删除实例详解
2019/03/06 PHP
HR vs ForZe BO3 第二场 2.13
2021/03/10 DOTA
javascript中最常用的继承模式 组合继承
2010/08/12 Javascript
javascript利用初始化数据装配模版的实现代码
2010/11/17 Javascript
js 利用className得到对象的实现代码
2011/11/15 Javascript
Javascript冒泡排序算法详解
2014/12/03 Javascript
JavaScript使用Range调色及透明度实例
2016/09/25 Javascript
扩展jquery easyui tree的搜索树节点方法(推荐)
2016/10/28 Javascript
微信小程序 wx:key详细介绍
2016/10/28 Javascript
Vue源码解读之Component组件注册的实现
2018/08/24 Javascript
Vue项目中使用better-scroll实现一个轮播图自动播放功能
2018/12/03 Javascript
JS实现的自定义map方法示例
2019/05/17 Javascript
简单了解小程序+node梳理登陆流程
2019/06/24 Javascript
layui问题之模拟table表格中的选中按钮选中事件的方法
2019/09/20 Javascript
windows下create-react-app 升级至3.3.1版本踩坑记
2020/02/17 Javascript
vue 动态生成拓扑图的示例
2021/01/03 Vue.js
[14:00]DOTA2国际邀请赛史上最长大战 赛后专访B神
2013/08/10 DOTA
Python3匿名函数lambda介绍与使用示例
2019/05/18 Python
基于keras输出中间层结果的2种实现方式
2020/01/24 Python
对python中各个response的使用说明
2020/03/28 Python
HTML5中Canvas与SVG的画图原理比较
2013/01/16 HTML / CSS
Html5基于canvas实现电子签名并生成PDF文档
2020/12/07 HTML / CSS
青年创业培训欢迎词
2014/01/08 职场文书
运动会邀请函范文
2014/01/31 职场文书
2014年化工厂工作总结
2014/11/25 职场文书
2014年新农村建设工作总结
2014/12/01 职场文书
撤诉书怎么写
2015/05/19 职场文书
详解python的内存分配机制
2021/05/10 Python
MySQL中一条SQL查询语句是如何执行的
2022/04/08 MySQL