pytorch实现线性回归以及多元回归


Posted in Python onApril 11, 2021

本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下

最近在学习pytorch,现在把学习的代码放在这里,下面是github链接

直接附上github代码

# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
 
# 首先我们需要将array转化成tensor,因为pytorch处理的单元是Tensor
 
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
 
 
# def a simple network
 
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 2_dimension
    def forward(self, x):
        out = self.linear(x)
        return out
 
 
if torch.cuda.is_available():
    model = LinearRegression().cuda()
    #model = model.cuda()
else:
    model = LinearRegression()
    #model = model.cuda()
 
# 定义loss function 和 optimize func
criterion = nn.MSELoss()   # 均方误差作为优化函数
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)
num_epochs = 30000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)
 
    # forward
    out = model(inputs)
    loss = criterion(out,outputs)
 
    # backword
    optimizer.zero_grad()  # 每次做反向传播之前都要进行归零梯度。不然梯度会累加在一起,造成不收敛的结果
    loss.backward()
    optimizer.step()
 
    if (epoch +1)%20==0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data))
 
 
model.eval()  # 将模型变成测试模式
predict = model(Variable(x_train).cuda())
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting line')
plt.show()

结果如图所示:

pytorch实现线性回归以及多元回归

多元回归:

# _*_encoding=utf-8_*_
# pytorch 里面最基本的操作对象是Tensor,pytorch 的tensor可以和numpy的ndarray相互转化。
# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
 
# 实现 y = b + w1 *x + w2 *x**2 +w3*x**3
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nn
 
 
# pre_processing
def make_feature(x):
    x = x.unsqueeze(1)   # unsquenze 是为了添加维度1的,0表示第一维度,1表示第二维度,将tensor大小由3变为(3,1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
 
# 定义好真实的数据
 
 
def f(x):
    W_output = torch.Tensor([0.5, 3, 2.4]).unsqueeze(1)
    b_output = torch.Tensor([0.9])
    return x.mm(W_output)+b_output[0]  # 外积,矩阵乘法
 
 
# 批量处理数据
def get_batch(batch_size =32):
 
    random = torch.randn(batch_size)
    x = make_feature(random)
    y = f(x)
    if torch.cuda.is_available():
 
        return Variable(x).cuda(),Variable(y).cuda()
    else:
        return Variable(x),Variable(y)
 
 
 
# def model
class poly_model(nn.Module):
    def __init__(self):
        super(poly_model,self).__init__()
        self.poly = nn.Linear(3,1)
    def forward(self,input):
        output = self.poly(input)
        return output
 
if torch.cuda.is_available():
    print("sdf")
    model = poly_model().cuda()
else:
    model = poly_model()
 
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    #print(batch_x)
    output = model(batch_x)
    loss = criterion(output,batch_y)
    print_loss = loss.data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch = epoch +1
    if print_loss < 1e-3:
        print(print_loss)
        break
 
model.eval()
print("Epoch = {}".format(epoch))
 
batch_x, batch_y = get_batch()
predict = model(batch_x)
a = predict - batch_y
y = torch.sum(a)
print('y = ',y)
predict = predict.data.cpu().numpy()
plt.plot(batch_x.cpu().numpy(),batch_y.cpu().numpy(),'ro',label = 'Original data')
plt.plot(batch_x.cpu().numpy(),predict,'b', ls='--',label = 'Fitting line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅析Python中MySQLdb的事务处理功能
Sep 21 Python
深入理解NumPy简明教程---数组1
Dec 17 Python
python虚拟环境virualenv的安装与使用
Dec 18 Python
python计算auc指标实例
Jul 13 Python
Windows系统下PhantomJS的安装和基本用法
Oct 21 Python
超简单使用Python换脸实例
Mar 27 Python
Python 点击指定位置验证码破解的实现代码
Sep 11 Python
Pytorch 实现sobel算子的卷积操作详解
Jan 10 Python
在django中form的label和verbose name的区别说明
May 20 Python
Python Mock模块原理及使用方法详解
Jul 07 Python
Python如何获取文件路径/目录
Sep 22 Python
Python 中数组和数字相乘时的注意事项说明
May 10 Python
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
You might like
全国FM电台频率大全 - 30 宁夏回族自治区
2020/03/11 无线电
用PHP实现ODBC数据分页显示一例
2006/10/09 PHP
php xml文件操作实现代码(二)
2009/03/20 PHP
PHP投票系统防刷票判断流程分析
2012/02/04 PHP
yii去掉必填项中星号的方法
2015/12/28 PHP
PHP正则+Snoopy抓取框架实现的抓取淘宝店信誉功能实例
2017/05/17 PHP
元素的内联事件处理函数的特殊作用域在各浏览器中存在差异
2011/01/12 Javascript
JS日期和时间选择控件升级版(自写)
2013/08/02 Javascript
代码触发js事件(click、change)示例应用
2013/12/13 Javascript
浅谈Javascript 数组与字典
2015/01/29 Javascript
JS模态窗口返回值兼容问题的完美解决方法
2016/05/28 Javascript
jQuery简单实现title提示效果示例
2016/08/01 Javascript
javascript简单进制转换实现方法
2016/11/24 Javascript
基于node.js express mvc轻量级框架实践
2017/09/14 Javascript
JavaScript实现表单注册、表单验证、运算符功能
2018/10/15 Javascript
React 实现拖拽功能的示例代码
2019/01/06 Javascript
javascript二维数组和对象的深拷贝与浅拷贝实例分析
2019/10/26 Javascript
JS自定义对象创建与简单使用方法示例
2020/01/15 Javascript
如何在微信小程序中使用骨架屏的步骤
2020/06/12 Javascript
[01:11:02]Secret vs Newbee 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/17 DOTA
python通过apply使用元祖和列表调用函数实例
2015/05/26 Python
Python函数返回值实例分析
2015/06/08 Python
Python2/3中urllib库的一些常见用法
2017/12/19 Python
基于python神经卷积网络的人脸识别
2018/05/24 Python
Python matplotlib 画图窗口显示到gui或者控制台的实例
2018/05/24 Python
Flask模拟实现CSRF攻击的方法
2018/07/24 Python
在Python中增加和插入元素的示例
2018/11/01 Python
python如何制作英文字典
2019/06/25 Python
利用Python脚本批量生成SQL语句
2020/03/04 Python
软件设计的目标是什么
2016/12/04 面试题
工业设计专业个人求职信范文
2013/12/28 职场文书
工作表扬信范文
2015/01/17 职场文书
利用ajax+php实现商品价格计算
2021/03/31 PHP
python 爬取华为应用市场评论
2021/05/29 Python
RPM包方式安装Oracle21c的方法详解
2021/08/23 Oracle
vue3使用vuedraggable实现拖拽功能
2022/04/06 Vue.js