pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python调用浏览器并打开一个网址的例子
Jun 05 Python
Python中文编码那些事
Jun 25 Python
Python实现的递归神经网络简单示例
Aug 11 Python
Python给你的头像加上圣诞帽
Jan 04 Python
pandas 条件搜索返回列表的方法
Oct 30 Python
Python快速转换numpy数组中Nan和Inf的方法实例说明
Feb 21 Python
Python爬取数据保存为Json格式的代码示例
Apr 09 Python
python实现用类读取文件数据并计算矩形面积
Jan 18 Python
python matplotlib中的subplot函数使用详解
Jan 19 Python
Python开发企业微信机器人每天定时发消息实例
Mar 17 Python
keras多显卡训练方式
Jun 10 Python
python中upper是做什么用的
Jul 20 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
php获取后台Job管理的实现代码
2011/06/10 PHP
PHP_Cooikes不同页面无法传递的解决方法
2014/03/07 PHP
destoon设置自定义搜索的方法
2014/06/21 PHP
ThinkPHP实现支付宝接口功能实例
2014/12/02 PHP
PHP中Trait及其应用详解
2017/02/14 PHP
yii2使用GridView实现数据全选及批量删除按钮示例
2017/03/01 PHP
JavaScript 给汉字排序实例代码
2008/06/28 Javascript
js 单引号 传递方法
2009/06/22 Javascript
用JavaScript修改CSS属性的代码
2013/05/06 Javascript
NodeJS学习笔记之网络编程
2014/08/03 NodeJs
jQuery中hasClass()方法用法实例
2015/01/06 Javascript
jQuery平滑旋转幻灯片特效代码分享
2015/09/07 Javascript
js判断日期时间有效性的方法
2015/10/24 Javascript
jQuery实现页面滚动时智能浮动定位
2017/01/08 Javascript
JavaScript函数参数的传递方式详解
2017/03/06 Javascript
Angular.JS去掉访问路径URL中的#号详解
2017/03/30 Javascript
基于vue打包后字体和图片资源失效问题的解决方法
2018/03/06 Javascript
微信小程序实现即时通信聊天功能的实例代码
2018/08/17 Javascript
JavaScript学习笔记之图片库案例分析
2019/01/08 Javascript
vue里的data要用return返回的原因浅析
2019/05/28 Javascript
[03:32]2014DOTA2西雅图邀请赛 CIS外卡赛赛前black专访
2014/07/09 DOTA
Python操作sqlite3快速、安全插入数据(防注入)的实例
2014/04/26 Python
python登录豆瓣并发帖的方法
2015/07/08 Python
Django自定义分页效果
2017/06/27 Python
Python发送http请求解析返回json的实例
2018/03/26 Python
完美解决在oj中Python的循环输入问题
2018/06/25 Python
python得到电脑的开机时间方法
2018/10/15 Python
python判断文件夹内是否存在指定后缀文件的实例
2019/06/10 Python
英国灯具和灯泡网上商店:Lights.co.uk
2018/02/02 全球购物
牵手50新加坡:专为黄金岁月的单身人士而设的交友网站
2020/08/16 全球购物
STP协议的主要用途是什么?为什么要用STP
2012/12/20 面试题
介绍一下JNDI的基本概念
2013/07/26 面试题
外贸员简历中的自我评价
2014/03/04 职场文书
私人会所最新创业计划书范文
2014/03/24 职场文书
幼儿园家长寄语
2014/04/02 职场文书
入股协议书范本
2014/11/01 职场文书