pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django1.7+python 2.78+pycharm配置mysql数据库
Oct 09 Python
python3实现ftp服务功能(客户端)
Mar 24 Python
利用Python爬取微博数据生成词云图片实例代码
Aug 31 Python
彻底理解Python list切片原理
Oct 27 Python
Flask模拟实现CSRF攻击的方法
Jul 24 Python
Python多线程应用于自动化测试操作示例
Dec 06 Python
ActiveMQ:使用Python访问ActiveMQ的方法
Jan 30 Python
python3实现mysql导出excel的方法
Jul 31 Python
python操作openpyxl导出Excel 设置单元格格式及合并处理代码实例
Aug 27 Python
基于Python3.6中的OpenCV实现图片色彩空间的转换
Feb 03 Python
Pytorch如何切换 cpu和gpu的使用详解
Mar 01 Python
Python pyecharts绘制条形图详解
Apr 02 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
收藏的PHP常用函数 推荐收藏保存
2010/02/21 PHP
CodeIgniter输出中文乱码的两种解决办法
2014/06/12 PHP
浅谈ThinkPHP的URL重写
2014/11/25 PHP
CodeIgniter删除和设置Cookie的方法
2015/04/07 PHP
js类中获取外部函数名的方法
2007/08/19 Javascript
javascript 常用方法总结
2009/06/03 Javascript
JS简单的图片放大缩小的两种方法
2013/11/11 Javascript
javascript打印html内容功能的方法示例
2013/11/28 Javascript
Node.js安装教程和NPM包管理器使用详解
2014/08/16 Javascript
JavaScript中使用Math.floor()方法对数字取整
2015/06/15 Javascript
js调用百度地图及调用百度地图的搜索功能
2015/09/07 Javascript
JavaScript实现的简单烟花特效代码
2015/10/20 Javascript
利用node.js搭建简单web服务器的方法教程
2017/02/20 Javascript
Vue.js实现移动端短信验证码功能
2017/03/29 Javascript
JS排序算法之希尔排序与快速排序实现方法
2017/12/12 Javascript
基于mpvue的小程序项目搭建的步骤
2018/05/22 Javascript
Vue前后端不同端口的实现方法
2018/09/19 Javascript
vue根据值给予不同class的实例
2018/09/29 Javascript
如何用Node写页面爬虫的工具集
2018/10/26 Javascript
修改Vue打包后的默认文件名操作
2020/08/12 Javascript
vue实现标签云效果的示例
2020/11/09 Javascript
详解在Python中处理异常的教程
2015/05/24 Python
利用Python+Java调用Shell脚本时的死锁陷阱详解
2018/01/24 Python
Pandas 按索引合并数据集的方法
2018/11/15 Python
Python根据指定文件生成XML的方法
2020/06/29 Python
python Protobuf定义消息类型知识点讲解
2021/03/02 Python
纽约服装和生活方式品牌:Saturdays NYC
2017/08/13 全球购物
欧洲最大的品牌水上运动服装和设备在线零售商:Wuituit Outlet
2018/05/05 全球购物
Kipling意大利官网:世界著名的时尚休闲包袋品牌
2019/06/05 全球购物
热能动力工程毕业生自荐信
2013/11/07 职场文书
酒店总经理助理岗位职责
2014/02/01 职场文书
一位农村小子的自荐信
2014/04/07 职场文书
工伤事故赔偿协议书(标准)
2014/09/29 职场文书
2015初中团委工作总结
2015/07/28 职场文书
初中政教处工作总结
2015/08/12 职场文书
DE1103使用报告
2022/04/05 无线电