pytorch实现线性回归以及多元回归


Posted in Python onApril 11, 2021

本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下

最近在学习pytorch,现在把学习的代码放在这里,下面是github链接

直接附上github代码

# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
 
# 首先我们需要将array转化成tensor,因为pytorch处理的单元是Tensor
 
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
 
 
# def a simple network
 
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 2_dimension
    def forward(self, x):
        out = self.linear(x)
        return out
 
 
if torch.cuda.is_available():
    model = LinearRegression().cuda()
    #model = model.cuda()
else:
    model = LinearRegression()
    #model = model.cuda()
 
# 定义loss function 和 optimize func
criterion = nn.MSELoss()   # 均方误差作为优化函数
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)
num_epochs = 30000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)
 
    # forward
    out = model(inputs)
    loss = criterion(out,outputs)
 
    # backword
    optimizer.zero_grad()  # 每次做反向传播之前都要进行归零梯度。不然梯度会累加在一起,造成不收敛的结果
    loss.backward()
    optimizer.step()
 
    if (epoch +1)%20==0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data))
 
 
model.eval()  # 将模型变成测试模式
predict = model(Variable(x_train).cuda())
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting line')
plt.show()

结果如图所示:

pytorch实现线性回归以及多元回归

多元回归:

# _*_encoding=utf-8_*_
# pytorch 里面最基本的操作对象是Tensor,pytorch 的tensor可以和numpy的ndarray相互转化。
# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
 
# 实现 y = b + w1 *x + w2 *x**2 +w3*x**3
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nn
 
 
# pre_processing
def make_feature(x):
    x = x.unsqueeze(1)   # unsquenze 是为了添加维度1的,0表示第一维度,1表示第二维度,将tensor大小由3变为(3,1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
 
# 定义好真实的数据
 
 
def f(x):
    W_output = torch.Tensor([0.5, 3, 2.4]).unsqueeze(1)
    b_output = torch.Tensor([0.9])
    return x.mm(W_output)+b_output[0]  # 外积,矩阵乘法
 
 
# 批量处理数据
def get_batch(batch_size =32):
 
    random = torch.randn(batch_size)
    x = make_feature(random)
    y = f(x)
    if torch.cuda.is_available():
 
        return Variable(x).cuda(),Variable(y).cuda()
    else:
        return Variable(x),Variable(y)
 
 
 
# def model
class poly_model(nn.Module):
    def __init__(self):
        super(poly_model,self).__init__()
        self.poly = nn.Linear(3,1)
    def forward(self,input):
        output = self.poly(input)
        return output
 
if torch.cuda.is_available():
    print("sdf")
    model = poly_model().cuda()
else:
    model = poly_model()
 
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    #print(batch_x)
    output = model(batch_x)
    loss = criterion(output,batch_y)
    print_loss = loss.data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch = epoch +1
    if print_loss < 1e-3:
        print(print_loss)
        break
 
model.eval()
print("Epoch = {}".format(epoch))
 
batch_x, batch_y = get_batch()
predict = model(batch_x)
a = predict - batch_y
y = torch.sum(a)
print('y = ',y)
predict = predict.data.cpu().numpy()
plt.plot(batch_x.cpu().numpy(),batch_y.cpu().numpy(),'ro',label = 'Original data')
plt.plot(batch_x.cpu().numpy(),predict,'b', ls='--',label = 'Fitting line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
本地文件上传到七牛云服务器示例(七牛云存储)
Jan 11 Python
wxPython中文教程入门实例
Jun 09 Python
Python中的fileinput模块的简单实用示例
Jul 09 Python
全面理解Python中self的用法
Jun 04 Python
Python配置mysql的教程(推荐)
Oct 13 Python
Python实现PS滤镜的万花筒效果示例
Jan 23 Python
对pandas进行数据预处理的实例讲解
Apr 20 Python
django允许外部访问的实例讲解
May 14 Python
Python利用ORM控制MongoDB(MongoEngine)的步骤全纪录
Sep 13 Python
python 在某.py文件中调用其他.py内的函数的方法
Jun 25 Python
python函数调用,循环,列表复制实例
May 03 Python
python blinker 信号库
May 04 Python
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
You might like
Mysql中limit的用法方法详解与注意事项
2008/04/19 PHP
PHP文件缓存内容保存格式实例分析
2014/08/20 PHP
php生成图片验证码-附五种验证码
2015/08/19 PHP
php冒泡排序与快速排序实例详解
2015/12/07 PHP
yii2 resetful 授权验证详解
2017/05/18 PHP
PHP实现的解汉诺塔问题算法示例
2018/08/06 PHP
PHP耦合设计模式实例分析
2018/08/08 PHP
Yii框架的redis命令使用方法简单示例
2019/10/15 PHP
js substr、substring和slice使用说明小记
2011/09/15 Javascript
浅谈JavaScript数据类型
2015/03/03 Javascript
详解JavaScript数组的操作大全
2015/10/19 Javascript
jQuery实现获取绑定自定义事件元素的方法
2015/12/02 Javascript
JS JSOP跨域请求实例详解
2016/07/04 Javascript
前端弹出对话框 js实现ajax交互
2016/09/09 Javascript
Vue如何从1.0迁移到2.0
2017/10/19 Javascript
原生JS实现 MUI导航栏透明渐变效果
2017/11/07 Javascript
加载 vue 远程代码的组件实例详解
2017/11/20 Javascript
JS匿名函数内部this指向问题详析
2019/05/10 Javascript
vue如何自动化打包测试环境和正式环境的dist/test文件
2019/06/06 Javascript
[04:32]DOTA2著名解说配音敌法师 现场专访海涛怒切假腿
2013/12/20 DOTA
解决pycharm无法识别本地site-packages的问题
2018/10/13 Python
python对于requests的封装方法详解
2019/01/03 Python
对Python 多线程统计所有csv文件的行数方法详解
2019/02/12 Python
PyQt5 在label显示的图片中绘制矩形的方法
2019/06/17 Python
详解程序意外中断自动重启shell脚本(以Python为例)
2019/07/26 Python
Python中注释(多行注释和单行注释)的用法实例
2019/08/28 Python
python+Django+pycharm+mysql 搭建首个web项目详解
2019/11/29 Python
python 实现目录复制的三种小结
2019/12/04 Python
TensorFlow tf.nn.conv2d实现卷积的方式
2020/01/03 Python
使用TFRecord存取多个数据案例
2020/02/17 Python
浅谈tensorflow 中的图片读取和裁剪方式
2020/06/30 Python
Smashbox英国官网:美国知名彩妆品牌
2017/11/13 全球购物
机电一体化大学生求职信
2013/11/08 职场文书
购房意向书范本
2014/04/01 职场文书
python某漫画app逆向
2021/03/31 Python
详解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法
2021/04/25 Python