pytorch实现线性回归以及多元回归


Posted in Python onApril 11, 2021

本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下

最近在学习pytorch,现在把学习的代码放在这里,下面是github链接

直接附上github代码

# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
 
# 首先我们需要将array转化成tensor,因为pytorch处理的单元是Tensor
 
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
 
 
# def a simple network
 
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 2_dimension
    def forward(self, x):
        out = self.linear(x)
        return out
 
 
if torch.cuda.is_available():
    model = LinearRegression().cuda()
    #model = model.cuda()
else:
    model = LinearRegression()
    #model = model.cuda()
 
# 定义loss function 和 optimize func
criterion = nn.MSELoss()   # 均方误差作为优化函数
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)
num_epochs = 30000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)
 
    # forward
    out = model(inputs)
    loss = criterion(out,outputs)
 
    # backword
    optimizer.zero_grad()  # 每次做反向传播之前都要进行归零梯度。不然梯度会累加在一起,造成不收敛的结果
    loss.backward()
    optimizer.step()
 
    if (epoch +1)%20==0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data))
 
 
model.eval()  # 将模型变成测试模式
predict = model(Variable(x_train).cuda())
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting line')
plt.show()

结果如图所示:

pytorch实现线性回归以及多元回归

多元回归:

# _*_encoding=utf-8_*_
# pytorch 里面最基本的操作对象是Tensor,pytorch 的tensor可以和numpy的ndarray相互转化。
# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
 
# 实现 y = b + w1 *x + w2 *x**2 +w3*x**3
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nn
 
 
# pre_processing
def make_feature(x):
    x = x.unsqueeze(1)   # unsquenze 是为了添加维度1的,0表示第一维度,1表示第二维度,将tensor大小由3变为(3,1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
 
# 定义好真实的数据
 
 
def f(x):
    W_output = torch.Tensor([0.5, 3, 2.4]).unsqueeze(1)
    b_output = torch.Tensor([0.9])
    return x.mm(W_output)+b_output[0]  # 外积,矩阵乘法
 
 
# 批量处理数据
def get_batch(batch_size =32):
 
    random = torch.randn(batch_size)
    x = make_feature(random)
    y = f(x)
    if torch.cuda.is_available():
 
        return Variable(x).cuda(),Variable(y).cuda()
    else:
        return Variable(x),Variable(y)
 
 
 
# def model
class poly_model(nn.Module):
    def __init__(self):
        super(poly_model,self).__init__()
        self.poly = nn.Linear(3,1)
    def forward(self,input):
        output = self.poly(input)
        return output
 
if torch.cuda.is_available():
    print("sdf")
    model = poly_model().cuda()
else:
    model = poly_model()
 
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    #print(batch_x)
    output = model(batch_x)
    loss = criterion(output,batch_y)
    print_loss = loss.data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch = epoch +1
    if print_loss < 1e-3:
        print(print_loss)
        break
 
model.eval()
print("Epoch = {}".format(epoch))
 
batch_x, batch_y = get_batch()
predict = model(batch_x)
a = predict - batch_y
y = torch.sum(a)
print('y = ',y)
predict = predict.data.cpu().numpy()
plt.plot(batch_x.cpu().numpy(),batch_y.cpu().numpy(),'ro',label = 'Original data')
plt.plot(batch_x.cpu().numpy(),predict,'b', ls='--',label = 'Fitting line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用PYTHON创建XML文档
Mar 01 Python
python 网络爬虫初级实现代码
Feb 27 Python
使用Python实现博客上进行自动翻页
Aug 23 Python
django 常用orm操作详解
Sep 13 Python
python3写爬取B站视频弹幕功能
Dec 22 Python
python对于requests的封装方法详解
Jan 03 Python
Python+OpenCV图片局部区域像素值处理详解
Jan 23 Python
PyTorch搭建一维线性回归模型(二)
May 22 Python
Django框架 Pagination分页实现代码实例
Sep 04 Python
python自动脚本的pyautogui入门学习
Apr 01 Python
浅析Python打包时包含静态文件处理方法
Jan 15 Python
Python+OpenCV实现在图像上绘制矩形
Mar 21 Python
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
You might like
php轻松实现中英文混排字符串截取
2014/05/28 PHP
微信公众平台天气预报功能开发
2014/07/06 PHP
PHP封装的分页类与简单用法示例
2019/02/25 PHP
延时重复执行函数 lLoopRun.js
2007/05/08 Javascript
javascript css在IE和Firefox中区别分析
2009/02/18 Javascript
用js替换除数字与逗号以外的所有字符的代码
2014/06/07 Javascript
jquery用data方法获取某个元素上的事件
2014/06/23 Javascript
js中函数调用的两种常用方法使用介绍
2014/07/17 Javascript
JavaScript中使用Object.create()创建对象介绍
2014/12/30 Javascript
js验证框架实现代码分享
2016/05/18 Javascript
Bootstrap中的Dropdown下拉菜单更改为悬停(hover)触发
2016/08/31 Javascript
jQuery的extend方法【三种】
2016/12/14 Javascript
AngularJS指令与控制器之间的交互功能示例
2016/12/14 Javascript
JS完成画圆圈的小球
2017/03/07 Javascript
JQuery判断正整数整理小结
2017/08/21 jQuery
vue2.0页面前进刷新回退不刷新的实现方法
2018/07/31 Javascript
vue+element-ui实现表格编辑的三种实现方式
2018/10/31 Javascript
微信小程序自定义模态弹窗组件详解
2019/12/24 Javascript
在Python中通过threading模块定义和调用线程的方法
2016/07/12 Python
一个基于flask的web应用诞生 记录用户账户登录状态(6)
2017/04/11 Python
Python加密模块的hashlib,hmac模块使用解析
2020/01/02 Python
在 Linux/Mac 下为Python函数添加超时时间的方法
2020/02/20 Python
关于python 的legend图例,参数使用说明
2020/04/17 Python
python函数map()和partial()的知识点总结
2020/05/26 Python
Python建造者模式案例运行原理解析
2020/06/29 Python
python实现图片转换成素描和漫画格式
2020/08/19 Python
python用tkinter实现一个gui的翻译工具
2020/10/26 Python
怀俄明州飞钓:Platte River Fly Shop
2017/12/28 全球购物
澳大利亚领先的美容护肤品零售商之一:SkincareStore
2018/01/22 全球购物
RUIFIER官网:英国奢侈高级珠宝品牌
2020/06/12 全球购物
计算机实训报告范文
2014/11/05 职场文书
高中班主任寄语
2019/06/21 职场文书
准备去美国留学,那么大学申请文书应该怎么写?
2019/08/12 职场文书
试用1103暨1103、1101同门大比武 [ DAIWEI ]
2022/04/05 无线电
SpringCloud Function SpEL注入漏洞分析及环境搭建
2022/04/08 Java/Android
MySQL远程无法连接的一些常见原因总结
2022/09/23 MySQL