pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
django1.8使用表单上传文件的实现方法
Nov 04 Python
Python之reload流程实例代码解析
Jan 29 Python
python构建基础的爬虫教学
Dec 23 Python
Python元组常见操作示例
Feb 19 Python
Python猴子补丁知识点总结
Jan 05 Python
Pytorch使用MNIST数据集实现基础GAN和DCGAN详解
Jan 10 Python
python如何使用代码运行助手
Jul 03 Python
Python如何给函数库增加日志功能
Aug 04 Python
Python如何解除一个装饰器
Aug 07 Python
python的launcher用法知识点总结
Aug 07 Python
浅析python 通⽤爬⾍和聚焦爬⾍
Sep 28 Python
python 利用百度API识别图片文字(多线程版)
Dec 14 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
PHP 批量删除 sql语句
2009/06/05 PHP
php制作基于xml的RSS订阅源功能示例
2017/02/08 PHP
Laravel接收前端ajax传来的数据的实例代码
2017/07/20 PHP
Javascript在IE下设置innerHTML时出现未知的运行时错误的解决方法
2011/01/12 Javascript
firefox下input type=&quot;file&quot;的size是多大
2011/10/24 Javascript
jquery ajax属性async(同步异步)示例
2013/11/05 Javascript
javascript中expression的用法整理
2014/05/13 Javascript
Bootstrap每天必学之标签与徽章
2015/11/27 Javascript
Javascript实现的SHA-256加密算法完整实例
2016/02/02 Javascript
Node+Express+MongoDB实现登录注册功能实例
2017/04/23 Javascript
Angular 容器部署的方法
2018/04/17 Javascript
VueX模块的具体使用(小白教程)
2020/06/05 Javascript
进一步探究Python中的正则表达式
2015/04/28 Python
python获取当前日期和时间的方法
2015/04/30 Python
python获得一个月有多少天的方法
2015/06/04 Python
Python随机生成带特殊字符的密码
2016/03/02 Python
Python利用IPython提高开发效率
2016/08/10 Python
Python将list中的string批量转化成int/float的方法
2018/06/26 Python
详解python里的命名规范
2018/07/16 Python
对Python中内置异常层次结构详解
2018/10/18 Python
Python进阶之@property动态属性的实现
2019/04/01 Python
python七夕浪漫表白源码
2019/04/05 Python
pyqt5 获取显示器的分辨率的方法
2019/06/18 Python
python 实现多维数组转向量
2019/11/30 Python
详解如何使用Pytest进行自动化测试
2021/01/14 Python
实例教程 利用html5和css3打造一款创意404页面
2014/10/20 HTML / CSS
Manduka官网:瑜伽垫、瑜伽毛巾和服装
2018/07/02 全球购物
销售找工作求职信
2013/12/20 职场文书
团日活动总结范文
2014/04/25 职场文书
合唱兴趣小组活动总结
2014/07/10 职场文书
2014公司党员自我评价范文
2014/09/11 职场文书
中职毕业生自我鉴定
2014/09/13 职场文书
2015年前台文员工作总结
2015/05/18 职场文书
国情备忘录观后感
2015/06/04 职场文书
小学教育见习总结
2015/06/23 职场文书
CSS作用域(样式分割)的使用汇总
2021/11/07 HTML / CSS