pytorch实现线性回归以及多元回归


Posted in Python onApril 11, 2021

本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下

最近在学习pytorch,现在把学习的代码放在这里,下面是github链接

直接附上github代码

# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
 
# 首先我们需要将array转化成tensor,因为pytorch处理的单元是Tensor
 
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
 
 
# def a simple network
 
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 2_dimension
    def forward(self, x):
        out = self.linear(x)
        return out
 
 
if torch.cuda.is_available():
    model = LinearRegression().cuda()
    #model = model.cuda()
else:
    model = LinearRegression()
    #model = model.cuda()
 
# 定义loss function 和 optimize func
criterion = nn.MSELoss()   # 均方误差作为优化函数
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)
num_epochs = 30000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)
 
    # forward
    out = model(inputs)
    loss = criterion(out,outputs)
 
    # backword
    optimizer.zero_grad()  # 每次做反向传播之前都要进行归零梯度。不然梯度会累加在一起,造成不收敛的结果
    loss.backward()
    optimizer.step()
 
    if (epoch +1)%20==0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data))
 
 
model.eval()  # 将模型变成测试模式
predict = model(Variable(x_train).cuda())
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting line')
plt.show()

结果如图所示:

pytorch实现线性回归以及多元回归

多元回归:

# _*_encoding=utf-8_*_
# pytorch 里面最基本的操作对象是Tensor,pytorch 的tensor可以和numpy的ndarray相互转化。
# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
 
# 实现 y = b + w1 *x + w2 *x**2 +w3*x**3
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nn
 
 
# pre_processing
def make_feature(x):
    x = x.unsqueeze(1)   # unsquenze 是为了添加维度1的,0表示第一维度,1表示第二维度,将tensor大小由3变为(3,1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
 
# 定义好真实的数据
 
 
def f(x):
    W_output = torch.Tensor([0.5, 3, 2.4]).unsqueeze(1)
    b_output = torch.Tensor([0.9])
    return x.mm(W_output)+b_output[0]  # 外积,矩阵乘法
 
 
# 批量处理数据
def get_batch(batch_size =32):
 
    random = torch.randn(batch_size)
    x = make_feature(random)
    y = f(x)
    if torch.cuda.is_available():
 
        return Variable(x).cuda(),Variable(y).cuda()
    else:
        return Variable(x),Variable(y)
 
 
 
# def model
class poly_model(nn.Module):
    def __init__(self):
        super(poly_model,self).__init__()
        self.poly = nn.Linear(3,1)
    def forward(self,input):
        output = self.poly(input)
        return output
 
if torch.cuda.is_available():
    print("sdf")
    model = poly_model().cuda()
else:
    model = poly_model()
 
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    #print(batch_x)
    output = model(batch_x)
    loss = criterion(output,batch_y)
    print_loss = loss.data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch = epoch +1
    if print_loss < 1e-3:
        print(print_loss)
        break
 
model.eval()
print("Epoch = {}".format(epoch))
 
batch_x, batch_y = get_batch()
predict = model(batch_x)
a = predict - batch_y
y = torch.sum(a)
print('y = ',y)
predict = predict.data.cpu().numpy()
plt.plot(batch_x.cpu().numpy(),batch_y.cpu().numpy(),'ro',label = 'Original data')
plt.plot(batch_x.cpu().numpy(),predict,'b', ls='--',label = 'Fitting line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中的__new__()方法的使用
Apr 09 Python
Python中方法链的使用方法
Feb 23 Python
利用python生成一个导出数据库的bat脚本文件的方法
Dec 30 Python
pyqt5实现俄罗斯方块游戏
Jan 11 Python
在macOS上搭建python环境的实现方法
Aug 13 Python
将python依赖包打包成window下可执行文件bat方式
Dec 26 Python
OpenCV中VideoCapture类的使用详解
Feb 14 Python
Python socket处理client连接过程解析
Mar 18 Python
Python函数基本使用原理详解
Mar 19 Python
解决django FileFIELD的编码问题
Mar 30 Python
jupyter notebook中美观显示矩阵实例
Apr 17 Python
python字符串拼接.join()和拆分.split()详解
Nov 23 Python
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
You might like
thinkphp模板继承实例简述
2014/11/26 PHP
php中让人头疼的浮点数运算分析
2016/10/10 PHP
THINKPHP-Apache服务器中使用Alias虚拟目录URL重写 隐藏index.php
2021/03/09 PHP
JavaScript 空位补零实现代码
2010/02/26 Javascript
jquery DOM操作 基于命令改变页面
2010/05/06 Javascript
js支持键盘控制的左右切换立体式图片轮播效果代码分享
2015/08/26 Javascript
JS实现仿新浪黄色经典滑动门效果代码
2015/09/27 Javascript
AngularJS入门教程之迭代器过滤详解
2016/08/18 Javascript
用file标签实现多图文件上传预览
2017/02/14 Javascript
jq checkbox 的全选并ajax传参的实例
2017/04/01 Javascript
jquery 禁止鼠标右键并监听右键事件
2017/04/27 jQuery
node.js + socket.io 实现点对点随机匹配聊天
2017/06/30 Javascript
swiper 自动图片无限轮播实现代码
2018/05/21 Javascript
Angular5集成eventbus的示例代码
2018/07/19 Javascript
Vue中android4.4不兼容问题的解决方法
2018/09/04 Javascript
Vue.extend实现挂载到实例上的方法
2019/05/01 Javascript
JavaScript实现图片上传并预览并提交ajax
2019/09/30 Javascript
vue data引入本地图片的两种方式小结
2019/11/13 Javascript
详解Vue.js 可拖放文本框组件的使用
2021/03/03 Vue.js
跟老齐学Python之模块的加载
2014/10/24 Python
Python中设置变量访问权限的方法
2015/04/27 Python
初学python的操作难点总结(新手必看篇)
2017/08/03 Python
对Python生成汉字字库文字,以及转换为文字图片的实例详解
2019/01/29 Python
python 已知一个字符,在一个list中找出近似值或相似值实现模糊匹配
2020/02/29 Python
Django 删除upload_to文件的步骤
2020/03/30 Python
莫斯科隐形眼镜网上商店:Linzi
2019/07/22 全球购物
美国在线购买空气净化器、除湿器、加湿器网站:AllergyBuyersClub
2021/03/16 全球购物
2015年元旦晚会活动总结(学生会)
2014/11/28 职场文书
升职感谢信
2015/01/22 职场文书
大学推普周活动总结
2015/05/07 职场文书
2015年幼儿园中班下学期工作总结
2015/05/22 职场文书
地道战观后感2000字
2015/06/04 职场文书
现实表现证明材料
2015/06/19 职场文书
《水浒传》读后感3篇(范文)
2019/09/19 职场文书
关于Nginx中虚拟主机的一些冷门知识小结
2022/03/03 Servers
SpringBoot详解自定义Stater的应用
2022/07/15 Java/Android