pytorch实现线性回归以及多元回归


Posted in Python onApril 11, 2021

本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下

最近在学习pytorch,现在把学习的代码放在这里,下面是github链接

直接附上github代码

# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
 
# 首先我们需要将array转化成tensor,因为pytorch处理的单元是Tensor
 
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
 
 
# def a simple network
 
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 2_dimension
    def forward(self, x):
        out = self.linear(x)
        return out
 
 
if torch.cuda.is_available():
    model = LinearRegression().cuda()
    #model = model.cuda()
else:
    model = LinearRegression()
    #model = model.cuda()
 
# 定义loss function 和 optimize func
criterion = nn.MSELoss()   # 均方误差作为优化函数
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)
num_epochs = 30000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)
 
    # forward
    out = model(inputs)
    loss = criterion(out,outputs)
 
    # backword
    optimizer.zero_grad()  # 每次做反向传播之前都要进行归零梯度。不然梯度会累加在一起,造成不收敛的结果
    loss.backward()
    optimizer.step()
 
    if (epoch +1)%20==0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data))
 
 
model.eval()  # 将模型变成测试模式
predict = model(Variable(x_train).cuda())
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting line')
plt.show()

结果如图所示:

pytorch实现线性回归以及多元回归

多元回归:

# _*_encoding=utf-8_*_
# pytorch 里面最基本的操作对象是Tensor,pytorch 的tensor可以和numpy的ndarray相互转化。
# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
 
# 实现 y = b + w1 *x + w2 *x**2 +w3*x**3
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nn
 
 
# pre_processing
def make_feature(x):
    x = x.unsqueeze(1)   # unsquenze 是为了添加维度1的,0表示第一维度,1表示第二维度,将tensor大小由3变为(3,1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
 
# 定义好真实的数据
 
 
def f(x):
    W_output = torch.Tensor([0.5, 3, 2.4]).unsqueeze(1)
    b_output = torch.Tensor([0.9])
    return x.mm(W_output)+b_output[0]  # 外积,矩阵乘法
 
 
# 批量处理数据
def get_batch(batch_size =32):
 
    random = torch.randn(batch_size)
    x = make_feature(random)
    y = f(x)
    if torch.cuda.is_available():
 
        return Variable(x).cuda(),Variable(y).cuda()
    else:
        return Variable(x),Variable(y)
 
 
 
# def model
class poly_model(nn.Module):
    def __init__(self):
        super(poly_model,self).__init__()
        self.poly = nn.Linear(3,1)
    def forward(self,input):
        output = self.poly(input)
        return output
 
if torch.cuda.is_available():
    print("sdf")
    model = poly_model().cuda()
else:
    model = poly_model()
 
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    #print(batch_x)
    output = model(batch_x)
    loss = criterion(output,batch_y)
    print_loss = loss.data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch = epoch +1
    if print_loss < 1e-3:
        print(print_loss)
        break
 
model.eval()
print("Epoch = {}".format(epoch))
 
batch_x, batch_y = get_batch()
predict = model(batch_x)
a = predict - batch_y
y = torch.sum(a)
print('y = ',y)
predict = predict.data.cpu().numpy()
plt.plot(batch_x.cpu().numpy(),batch_y.cpu().numpy(),'ro',label = 'Original data')
plt.plot(batch_x.cpu().numpy(),predict,'b', ls='--',label = 'Fitting line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取当前日期和时间的方法
Apr 30 Python
python编写爬虫小程序
May 14 Python
听歌识曲--用python实现一个音乐检索器的功能
Nov 15 Python
django框架如何集成celery进行开发
May 24 Python
用python实现百度翻译的示例代码
Mar 09 Python
python 为什么说eval要慎用
Mar 26 Python
详解如何从TensorFlow的mnist数据集导出手写体数字图片
Aug 05 Python
Python FFT合成波形的实例
Dec 04 Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 Python
Django数据库操作之save与update的使用
Apr 01 Python
150行Python代码实现带界面的数独游戏
Apr 04 Python
Python学习笔记之装饰器
Aug 06 Python
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
You might like
使用php来实现网络服务
2009/09/15 PHP
Ajax PHP 边学边练 之三 数据库
2009/11/26 PHP
PHP创建桌面快捷方式的实例代码
2014/02/17 PHP
Thinkphp框架开发移动端接口(2)
2016/08/18 PHP
Js+XML 操作
2006/09/20 Javascript
数组Array进行原型prototype扩展后带来的for in遍历问题
2010/02/07 Javascript
Javascript 实现的数独解题算法网页实例
2013/10/15 Javascript
解决用jquery load加载页面到div时,不执行页面js的问题
2014/02/22 Javascript
node.js中的socket.io入门实例
2014/04/26 Javascript
流量统计器如何鉴别C#:WebBrowser中伪造referer
2015/01/07 Javascript
jQuery实现表格行上移下移和置顶的方法
2015/05/22 Javascript
JavaScript实现把rgb颜色转换成16进制颜色的方法
2015/06/01 Javascript
jQuery根据name属性进行查找的用法分析
2016/06/23 Javascript
Angularjs中的ui-bootstrap的使用教程
2017/02/19 Javascript
让webpack+vue-cil项目不再自动打开浏览器的方法
2018/09/27 Javascript
JavaScript剩余操作符Rest Operator详解
2019/07/20 Javascript
python求pi的方法
2014/10/08 Python
Python中的下划线详解
2015/06/24 Python
Python2.7环境Flask框架安装简明教程【已测试】
2018/07/13 Python
Python 实现域名解析为ip的方法
2019/02/14 Python
Python hexstring-list-str之间的转换方法
2019/06/12 Python
使用python制作一个为hex文件增加版本号的脚本实例
2019/06/12 Python
django 2.2和mysql使用的常见问题
2019/07/18 Python
Django框架获取form表单数据方式总结
2020/04/22 Python
Python如何根据时间序列数据作图
2020/05/12 Python
Python中logging日志记录到文件及自动分割的操作代码
2020/08/05 Python
如何快速一次性卸载所有python包(第三方库)呢
2020/10/20 Python
纯css3实现照片墙效果
2014/12/26 HTML / CSS
htmlentities() 和 htmlspecialchars()有什么区别
2015/07/01 面试题
日语专业推荐信
2013/11/12 职场文书
上班离岗检讨书
2014/01/27 职场文书
九一八事变演讲稿范文
2014/09/14 职场文书
公民授权委托书
2014/10/15 职场文书
小英雄雨来观后感
2015/06/09 职场文书
MySQL千万级数据表的优化实战记录
2021/08/04 MySQL
javascript对象3个属性特征
2021/11/17 Javascript