pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 输出一个两行字符的变量
Feb 05 Python
Android应用开发中Action bar编写的入门教程
Feb 26 Python
Python基于回溯法子集树模板实现图的遍历功能示例
Sep 05 Python
Python 私有函数的实例详解
Sep 11 Python
对python3标准库httpclient的使用详解
Dec 18 Python
python+mysql实现教务管理系统
Feb 20 Python
python使用Plotly绘图工具绘制气泡图
Apr 01 Python
Python一键查找iOS项目中未使用的图片、音频、视频资源
Aug 12 Python
Python 50行爬虫抓取并处理图灵书目过程详解
Sep 20 Python
基于Python中isfile函数和isdir函数使用详解
Nov 29 Python
详解pycharm连接不上mysql数据库的解决办法
Jan 10 Python
只用40行Python代码就能写出pdf转word小工具
May 31 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
在 PHP 中使用随机数的三个步骤
2006/10/09 PHP
不用数据库的多用户文件自由上传投票系统(2)
2006/10/09 PHP
使用php实现快钱支付功能(涉及到接口)
2013/07/01 PHP
PHP开发工具ZendStudio下Xdebug工具使用说明详解
2013/11/11 PHP
PHP代码优化之成员变量获取速度对比
2014/02/28 PHP
php数组中删除元素之重新索引的方法
2014/09/16 PHP
PHP基于简单递归函数求一个数阶乘的方法示例
2017/04/26 PHP
php apache开启跨域模式过程详解
2019/07/08 PHP
浅谈Laravel中的三种中间件的作用
2019/10/13 PHP
php 多个变量指向同一个引用($b = &amp;$a)用法分析
2019/11/13 PHP
Javascript select控件操作大全(新增、修改、删除、选中、清空、判断存在等)
2008/12/19 Javascript
javascript自定义的addClass()方法
2014/05/28 Javascript
初步认识JavaScript函数库jQuery
2015/06/18 Javascript
jquery实现点击展开列表同时隐藏其他列表
2015/08/10 Javascript
jQuery增加与删除table列的方法
2016/03/01 Javascript
jQuery仿京东商城楼梯式导航定位菜单
2016/07/25 Javascript
JS中split()用法(将字符串按指定符号分割成数组)
2016/10/24 Javascript
NodeJS测试框架mocha入门教程
2017/03/28 NodeJs
JS使用tween.js动画库实现轮播图并且有切换功能
2018/07/17 Javascript
JavaScript函数、闭包、原型、面向对象学习笔记
2018/09/06 Javascript
JQuery样式与属性设置方法分析
2019/12/07 jQuery
Vue实现简单购物车功能
2020/12/13 Vue.js
跨平台python异步回调机制实现和使用方法
2013/11/26 Python
python sys模块sys.path使用方法示例
2013/12/04 Python
Python中变量交换的例子
2014/08/25 Python
Fabric 应用案例
2016/08/28 Python
Python中的sort()方法使用基础教程
2017/01/08 Python
分享8个非常流行的 Python 可视化工具包
2019/06/05 Python
python pandas时序处理相关功能详解
2019/07/03 Python
基于python判断目录或者文件代码实例
2019/11/29 Python
Python函数的迭代器与生成器的示例代码
2020/06/18 Python
Python 使用生成器代替线程的方法
2020/08/04 Python
css3背景_动力节点Java学院整理
2017/07/11 HTML / CSS
柒牌官方商城:中国男装优秀品牌
2017/06/30 全球购物
毕业生自荐书模版
2014/01/04 职场文书
党员剖析材料范文
2014/09/30 职场文书