pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3.6 +tkinter GUI编程 实现界面化的文本处理工具(推荐)
Dec 20 Python
java中两个byte数组实现合并的示例
May 09 Python
python 读取摄像头数据并保存的实例
Aug 03 Python
Django框架搭建的简易图书信息网站案例
May 25 Python
pandas.cut具体使用总结
Jun 24 Python
解决Django中调用keras的模型出现的问题
Aug 07 Python
浅析Python数字类型和字符串类型的内置方法
Dec 22 Python
Pandas缺失值2种处理方式代码实例
Jun 13 Python
Python环境管理virtualenv&amp;virtualenvwrapper的配置详解
Jul 01 Python
基于Python pyecharts实现多种图例代码解析
Aug 10 Python
python 动态绘制爱心的示例
Sep 27 Python
python基础之//、/与%的区别详解
Jun 10 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
全国FM电台频率大全 - 11 浙江省
2020/03/11 无线电
php 检查电子邮件函数(自写)
2014/01/16 PHP
yii2.0实现验证用户名与邮箱功能
2015/12/22 PHP
使用php实现从身份证中提取生日
2016/05/09 PHP
Laravel5.5 数据库迁移:创建表与修改表示例
2019/10/23 PHP
XAMPP升级PHP版本实现步骤解析
2020/09/04 PHP
Javascript客户端脚本的设计和应用
2006/08/21 Javascript
flexigrid 类似ext grid的JS表格代码
2010/07/17 Javascript
javascript委托(Delegate)blur和focus用法实例分析
2015/05/26 Javascript
JS实现iframe编辑器光标位置插入内容的方法(兼容IE和Firefox)
2016/06/24 Javascript
JavaScript自定义分页样式
2017/01/17 Javascript
Bootstrap DateTime Picker日历控件简单应用
2017/03/25 Javascript
一个因@click.stop引发的bug的解决
2019/01/08 Javascript
详解基于electron制作一个node压缩图片的桌面应用
2019/01/29 Javascript
小程序实现订单倒计时功能
2019/04/23 Javascript
JavaScript的console命令使用实例
2019/12/03 Javascript
javascript实现数字时钟效果
2021/02/06 Javascript
[00:09]DOTA2全国高校联赛 精彩活动引爆全场
2018/05/30 DOTA
[47:03]完美世界DOTA2联赛PWL S3 Galaxy Racer vs Phoenix 第二场 12.10
2020/12/13 DOTA
kNN算法python实现和简单数字识别的方法
2014/11/18 Python
django3.02模板中的超链接配置实例代码
2020/02/04 Python
基于tensorflow for循环 while循环案例
2020/06/30 Python
Python如何设置指定窗口为前台活动窗口
2020/08/12 Python
python实现发送邮件
2021/03/02 Python
澳大利亚领先的美容护肤品零售商之一:SkincareStore
2018/01/22 全球购物
介绍一下linux文件系统分配策略
2013/02/25 面试题
信息专业大学生自我评价分享
2014/01/17 职场文书
群众路线剖析材料
2014/02/02 职场文书
宣传保护环境的公益广告词
2014/03/13 职场文书
保密工作责任书
2014/04/16 职场文书
临床护理求职信
2014/04/26 职场文书
六一儿童节开幕词
2015/01/29 职场文书
会议通知
2015/04/15 职场文书
辛德勒的名单观后感
2015/06/03 职场文书
导游词之太原天龙山
2020/01/02 职场文书
Nginx性能优化之Gzip压缩设置详解(最大程度提高页面打开速度)
2022/02/12 Servers