pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python有序查找算法之二分法实例分析
Dec 11 Python
ubuntu17.4下为python和python3装上pip的方法
Jun 12 Python
Pycharm配置远程调试的方法步骤
Dec 17 Python
Python面向对象程序设计之私有属性及私有方法示例
Apr 08 Python
Python实现将HTML转成PDF的方法分析
May 04 Python
python中seaborn包常用图形使用详解
Nov 25 Python
python实现回旋矩阵方式(旋转矩阵)
Dec 04 Python
Python 面向对象部分知识点小结
Mar 09 Python
python不同系统中打开方法
Jun 23 Python
python中如何写类
Jun 29 Python
Python extract及contains方法代码实例
Sep 11 Python
python数据分析之单因素分析线性拟合及地理编码
Jun 25 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
Yii中srbac权限扩展模块工作原理与用法分析
2016/07/14 PHP
PHP 中常量的知识整理
2017/04/14 PHP
laravel 5.3中自定义加密服务的方案详解
2017/05/09 PHP
js调用浏览器打印模块实现点击按钮触发自定义函数
2014/03/21 Javascript
教你如何在 Javascript 文件里使用 .Net MVC Razor 语法
2014/07/23 Javascript
node.js开发中使用Node Supervisor实现监测文件修改并自动重启应用
2014/11/04 Javascript
JavaScript基于Dom操作实现查找、修改HTML元素的内容及属性的方法
2017/01/20 Javascript
从零开始做一个pagination分页组件
2017/03/15 Javascript
AngularJS  ng-repeat遍历输出的用法
2017/06/19 Javascript
JS 判断某变量是否为某数组中的一个值的3种方法(总结)
2017/07/10 Javascript
angular指令笔记ng-options的使用方法
2017/09/18 Javascript
浅谈vue项目4rs vue-router上线后history模式遇到的坑
2018/09/27 Javascript
微信小程序和百度的语音识别接口详解
2019/05/06 Javascript
Vue实现base64编码图片间的切换功能
2019/12/04 Javascript
Python和perl实现批量对目录下电子书文件重命名的代码分享
2014/11/21 Python
Python编程中运用闭包时所需要注意的一些地方
2015/05/02 Python
python读写ini配置文件方法实例分析
2015/06/30 Python
实例讲解Python爬取网页数据
2018/07/08 Python
解决tensorflow模型参数保存和加载的问题
2018/07/26 Python
如何安装多版本python python2和python3共存以及pip共存
2018/09/18 Python
Python GUI布局尺寸适配方法
2018/10/11 Python
Python Selenium 之关闭窗口close与quit的方法
2019/02/13 Python
人工神经网络算法知识点总结
2019/06/11 Python
Python中包的用法及安装
2020/02/11 Python
关于多元线性回归分析——Python&amp;SPSS
2020/02/24 Python
Myholidays美国:在线旅游网站
2019/08/16 全球购物
单位实习证明怎么写
2014/01/17 职场文书
庆元旦活动总结
2014/07/09 职场文书
现场活动策划方案
2014/08/22 职场文书
个人工作作风整改措施思想汇报
2014/10/13 职场文书
先进事迹材料范文
2014/12/29 职场文书
2015年敬老院工作总结
2015/05/18 职场文书
「偶像大师 MILLION LIVE!」七尾百合子手办开订
2022/03/21 日漫
分析MySQL优化 index merge 后引起的死锁
2022/04/19 MySQL
安装Windows Server 2012 R2企业版操作系统并设置好相关参数
2022/04/29 Servers
Python使用openpyxl模块处理Excel文件
2022/06/05 Python