pytorch实现线性回归以及多元回归


Posted in Python onApril 11, 2021

本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下

最近在学习pytorch,现在把学习的代码放在这里,下面是github链接

直接附上github代码

# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
 
# 首先我们需要将array转化成tensor,因为pytorch处理的单元是Tensor
 
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
 
 
# def a simple network
 
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 2_dimension
    def forward(self, x):
        out = self.linear(x)
        return out
 
 
if torch.cuda.is_available():
    model = LinearRegression().cuda()
    #model = model.cuda()
else:
    model = LinearRegression()
    #model = model.cuda()
 
# 定义loss function 和 optimize func
criterion = nn.MSELoss()   # 均方误差作为优化函数
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)
num_epochs = 30000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)
 
    # forward
    out = model(inputs)
    loss = criterion(out,outputs)
 
    # backword
    optimizer.zero_grad()  # 每次做反向传播之前都要进行归零梯度。不然梯度会累加在一起,造成不收敛的结果
    loss.backward()
    optimizer.step()
 
    if (epoch +1)%20==0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data))
 
 
model.eval()  # 将模型变成测试模式
predict = model(Variable(x_train).cuda())
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting line')
plt.show()

结果如图所示:

pytorch实现线性回归以及多元回归

多元回归:

# _*_encoding=utf-8_*_
# pytorch 里面最基本的操作对象是Tensor,pytorch 的tensor可以和numpy的ndarray相互转化。
# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
 
# 实现 y = b + w1 *x + w2 *x**2 +w3*x**3
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nn
 
 
# pre_processing
def make_feature(x):
    x = x.unsqueeze(1)   # unsquenze 是为了添加维度1的,0表示第一维度,1表示第二维度,将tensor大小由3变为(3,1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
 
# 定义好真实的数据
 
 
def f(x):
    W_output = torch.Tensor([0.5, 3, 2.4]).unsqueeze(1)
    b_output = torch.Tensor([0.9])
    return x.mm(W_output)+b_output[0]  # 外积,矩阵乘法
 
 
# 批量处理数据
def get_batch(batch_size =32):
 
    random = torch.randn(batch_size)
    x = make_feature(random)
    y = f(x)
    if torch.cuda.is_available():
 
        return Variable(x).cuda(),Variable(y).cuda()
    else:
        return Variable(x),Variable(y)
 
 
 
# def model
class poly_model(nn.Module):
    def __init__(self):
        super(poly_model,self).__init__()
        self.poly = nn.Linear(3,1)
    def forward(self,input):
        output = self.poly(input)
        return output
 
if torch.cuda.is_available():
    print("sdf")
    model = poly_model().cuda()
else:
    model = poly_model()
 
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    #print(batch_x)
    output = model(batch_x)
    loss = criterion(output,batch_y)
    print_loss = loss.data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch = epoch +1
    if print_loss < 1e-3:
        print(print_loss)
        break
 
model.eval()
print("Epoch = {}".format(epoch))
 
batch_x, batch_y = get_batch()
predict = model(batch_x)
a = predict - batch_y
y = torch.sum(a)
print('y = ',y)
predict = predict.data.cpu().numpy()
plt.plot(batch_x.cpu().numpy(),batch_y.cpu().numpy(),'ro',label = 'Original data')
plt.plot(batch_x.cpu().numpy(),predict,'b', ls='--',label = 'Fitting line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python编写一个简单的tic-tac-toe游戏的教程
Apr 16 Python
学习python之编写简单简单连接数据库并执行查询操作
Feb 27 Python
Python使用自带的ConfigParser模块读写ini配置文件
Jun 26 Python
通过源码分析Python中的切片赋值
May 08 Python
Python实现输出程序执行进度百分比的方法
Sep 16 Python
网红编程语言Python将纳入高考你怎么看?
Jun 07 Python
对python中的iter()函数与next()函数详解
Oct 18 Python
Python 利用pydub库操作音频文件的方法
Jan 09 Python
Python3实现取图片中特定的像素替换指定的颜色示例
Jan 24 Python
Python matplotlib绘制饼状图功能示例
Sep 10 Python
Selenium webdriver添加cookie实现过程详解
Aug 12 Python
利用Python判断整数是否是回文数的3种方法总结
Jul 07 Python
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
You might like
非常精妙的PHP递归调用与静态变量使用
2012/12/16 PHP
php IP转换整形(ip2long)的详解
2013/06/06 PHP
php中将数组转成字符串并保存到数据库中的函数代码
2013/09/29 PHP
PHP中addslashes与mysql_escape_string的区别分析
2016/04/25 PHP
thinkPHP多域名情况下使用memcache方式共享session数据的实现方法
2016/07/21 PHP
php实现数组重复数字统计实例
2018/09/30 PHP
PHP进阶学习之类的自动加载机制原理分析
2019/06/18 PHP
web性能优化之javascript性能调优
2012/12/28 Javascript
jquery滚动组件(vticker.js)实现页面动态数据的滚动效果
2013/07/03 Javascript
javascript读写XML实现广告轮换(兼容IE、FF)
2013/08/09 Javascript
jQuery实现为图片添加镜头放大效果的方法
2015/06/25 Javascript
js实现文字在按钮上滚动的方法
2015/08/20 Javascript
分享使用AngularJS创建应用的5个框架
2015/12/05 Javascript
js倒计时抢购实例
2015/12/20 Javascript
javascript 数组的定义和数组的长度
2016/06/07 Javascript
nodejs加密Crypto的实例代码
2016/07/07 NodeJs
jquery中ajax请求后台数据成功后既不执行success也不执行error的完美解决方法
2017/12/24 jQuery
JS实现进度条动态加载特效
2020/03/25 Javascript
[05:31]DOTA2英雄梦之声_第08期_莉娜
2014/06/23 DOTA
布同自制Python函数帮助查询小工具
2011/03/13 Python
Python中的urllib模块使用详解
2015/07/07 Python
探究python中open函数的使用
2016/03/01 Python
python编程实现希尔排序
2017/04/13 Python
python使用super()出现错误解决办法
2017/08/14 Python
在python 中split()使用多符号分割的例子
2019/07/15 Python
python 爬取疫情数据的源码
2020/02/09 Python
python 将视频 通过视频帧转换成时间实例
2020/04/23 Python
如何使用Python调整图像大小
2020/09/26 Python
Python 找出英文单词列表(list)中最长单词链
2020/12/14 Python
Speedo美国:澳大利亚顶尖泳衣制造商
2016/08/03 全球购物
英国奢华护肤、美容和Spa品牌:Temple Spa
2019/11/02 全球购物
个人求职简历中英文自我评价
2013/12/16 职场文书
会计系中文个人求职信
2013/12/24 职场文书
大学生优秀自荐信范文
2014/02/25 职场文书
小学校本培训方案
2014/06/06 职场文书
详解java如何集成swagger组件
2021/06/21 Java/Android