pytorch实现线性回归以及多元回归


Posted in Python onApril 11, 2021

本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下

最近在学习pytorch,现在把学习的代码放在这里,下面是github链接

直接附上github代码

# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
 
# 首先我们需要将array转化成tensor,因为pytorch处理的单元是Tensor
 
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
 
 
# def a simple network
 
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 2_dimension
    def forward(self, x):
        out = self.linear(x)
        return out
 
 
if torch.cuda.is_available():
    model = LinearRegression().cuda()
    #model = model.cuda()
else:
    model = LinearRegression()
    #model = model.cuda()
 
# 定义loss function 和 optimize func
criterion = nn.MSELoss()   # 均方误差作为优化函数
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)
num_epochs = 30000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)
 
    # forward
    out = model(inputs)
    loss = criterion(out,outputs)
 
    # backword
    optimizer.zero_grad()  # 每次做反向传播之前都要进行归零梯度。不然梯度会累加在一起,造成不收敛的结果
    loss.backward()
    optimizer.step()
 
    if (epoch +1)%20==0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data))
 
 
model.eval()  # 将模型变成测试模式
predict = model(Variable(x_train).cuda())
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting line')
plt.show()

结果如图所示:

pytorch实现线性回归以及多元回归

多元回归:

# _*_encoding=utf-8_*_
# pytorch 里面最基本的操作对象是Tensor,pytorch 的tensor可以和numpy的ndarray相互转化。
# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
 
# 实现 y = b + w1 *x + w2 *x**2 +w3*x**3
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nn
 
 
# pre_processing
def make_feature(x):
    x = x.unsqueeze(1)   # unsquenze 是为了添加维度1的,0表示第一维度,1表示第二维度,将tensor大小由3变为(3,1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
 
# 定义好真实的数据
 
 
def f(x):
    W_output = torch.Tensor([0.5, 3, 2.4]).unsqueeze(1)
    b_output = torch.Tensor([0.9])
    return x.mm(W_output)+b_output[0]  # 外积,矩阵乘法
 
 
# 批量处理数据
def get_batch(batch_size =32):
 
    random = torch.randn(batch_size)
    x = make_feature(random)
    y = f(x)
    if torch.cuda.is_available():
 
        return Variable(x).cuda(),Variable(y).cuda()
    else:
        return Variable(x),Variable(y)
 
 
 
# def model
class poly_model(nn.Module):
    def __init__(self):
        super(poly_model,self).__init__()
        self.poly = nn.Linear(3,1)
    def forward(self,input):
        output = self.poly(input)
        return output
 
if torch.cuda.is_available():
    print("sdf")
    model = poly_model().cuda()
else:
    model = poly_model()
 
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    #print(batch_x)
    output = model(batch_x)
    loss = criterion(output,batch_y)
    print_loss = loss.data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch = epoch +1
    if print_loss < 1e-3:
        print(print_loss)
        break
 
model.eval()
print("Epoch = {}".format(epoch))
 
batch_x, batch_y = get_batch()
predict = model(batch_x)
a = predict - batch_y
y = torch.sum(a)
print('y = ',y)
predict = predict.data.cpu().numpy()
plt.plot(batch_x.cpu().numpy(),batch_y.cpu().numpy(),'ro',label = 'Original data')
plt.plot(batch_x.cpu().numpy(),predict,'b', ls='--',label = 'Fitting line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量修改文件后缀示例代码分享
Dec 24 Python
跟老齐学Python之有容乃大的list(3)
Sep 15 Python
浅谈使用Python变量时要避免的3个错误
Oct 30 Python
python实现下载pop3邮件保存到本地
Jun 19 Python
替换python字典中的key值方法
Jul 06 Python
Python爬虫实现获取动态gif格式搞笑图片的方法示例
Dec 24 Python
Python字符串内置函数功能与用法总结
Apr 16 Python
python opencv实现证件照换底功能
Aug 19 Python
关于Python3 类方法、静态方法新解
Aug 30 Python
Django REST 异常处理详解
Jul 15 Python
python 利用opencv实现图像网络传输
Nov 12 Python
python爬虫beautifulsoup解析html方法
Dec 07 Python
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
You might like
精美漂亮的php分页类代码
2013/04/02 PHP
PHP Class&amp;Object -- PHP 自排序二叉树的深入解析
2013/06/25 PHP
php提供实现反射的方法和实例代码
2019/09/17 PHP
PHP全局使用Laravel辅助函数dd
2019/12/26 PHP
javascript判断ie浏览器6/7版本加载不同样式表的实现代码
2011/12/26 Javascript
推荐 21 款优秀的高性能 Node.js 开发框架
2014/08/18 Javascript
JQuery中DOM事件合成用法实例分析
2015/06/13 Javascript
JavaScript节点及列表操作实例小结
2015/08/05 Javascript
Vue.js每天必学之指令系统与自定义指令
2016/09/07 Javascript
WEB前端实现裁剪上传图片功能
2016/10/17 Javascript
EasyUI在Panel上动态添加LinkButton按钮
2017/08/11 Javascript
vue中的$emit 与$on父子组件与兄弟组件的之间通信方式
2018/05/13 Javascript
微信小程序下拉框组件使用方法详解
2018/12/28 Javascript
基于JavaScript实现猜数字游戏代码实例
2020/07/30 Javascript
vant-ui组件调用Dialog弹窗异步关闭操作
2020/11/04 Javascript
Python实现提取文章摘要的方法
2015/04/21 Python
python过滤字符串中不属于指定集合中字符的类实例
2015/06/30 Python
Linux 下 Python 实现按任意键退出的实现方法
2016/09/25 Python
理解Python中的绝对路径和相对路径
2017/08/30 Python
利用Python如何生成hash值示例详解
2017/12/20 Python
python pandas库中DataFrame对行和列的操作实例讲解
2018/06/09 Python
Python文件常见操作实例分析【读写、遍历】
2018/12/10 Python
Python实现DDos攻击实例详解
2019/02/02 Python
浅谈Python小波分析库Pywavelets的一点使用心得
2019/07/09 Python
关于python 的legend图例,参数使用说明
2020/04/17 Python
python3 kubernetes api的使用示例
2021/01/12 Python
微软澳洲官方网站:Microsoft Australia
2017/01/10 全球购物
Discard Protocol抛弃协议的作用是什么
2015/10/10 面试题
linux面试题参考答案(10)
2013/11/04 面试题
施惠特软件测试面试题以及笔试题
2015/05/13 面试题
好军嫂事迹材料
2014/01/15 职场文书
大学生未来职业生涯规划书
2014/02/15 职场文书
联谊会主持词
2014/03/26 职场文书
乒乓球比赛通知
2015/04/27 职场文书
初三英语教学反思
2016/02/15 职场文书
JavaWeb Servlet实现网页登录功能
2021/07/04 Java/Android