pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3使用urllib模块制作网络爬虫
Apr 08 Python
Python常用算法学习基础教程
Apr 13 Python
python实现聚类算法原理
Feb 12 Python
Python对List中的元素排序的方法
Apr 01 Python
Python基本数据结构与用法详解【列表、元组、集合、字典】
Mar 23 Python
Python tkinter和exe打包的方法
Feb 05 Python
python爬虫学习笔记之pyquery模块基本用法详解
Apr 09 Python
python3访问字典里的值实例方法
Nov 18 Python
python try...finally...的实现方法
Nov 25 Python
Python调用高德API实现批量地址转经纬度并写入表格的功能
Jan 12 Python
python办公自动化之excel的操作
May 23 Python
python 进阶学习之python装饰器小结
Sep 04 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
php完全过滤HTML,JS,CSS等标签
2009/01/16 PHP
调整优化您的LAMP应用程序的5种简单方法
2011/06/26 PHP
php后门URL的防范
2013/11/12 PHP
php中in_array函数用法探究
2014/11/25 PHP
jQuery ui 1.7更新小结
2009/08/15 Javascript
锋利的jQuery 要点归纳(一) jQuery选择器
2010/03/21 Javascript
加载 Javascript 最佳实践
2011/10/30 Javascript
利用webqq协议使用python登录qq发消息源码参考
2013/04/08 Javascript
PHP 数组current和next用法分享
2015/03/05 Javascript
jQuery选择器源码解读(一):Sizzle方法
2015/03/31 Javascript
js以及jquery实现手风琴效果
2020/04/17 Javascript
js操作二进制数据方法
2018/03/03 Javascript
详解ES6 Symbol 的用途
2018/10/14 Javascript
如何使用 vue + d3 画一棵树
2018/12/03 Javascript
Angular value与ngValue区别详解
2019/11/27 Javascript
Javascript ParentNode和ChildNode接口原理解析
2020/03/16 Javascript
Python实现的一个自动售饮料程序代码分享
2014/08/25 Python
详解Python的Django框架中的Cookie相关处理
2015/07/22 Python
详解python的数字类型变量与其方法
2016/11/20 Python
numpy中的高维数组转置实例
2018/04/17 Python
python-序列解包(对可迭代元素的快速取值方法)
2019/08/24 Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
2020/01/20 Python
美国婴儿服装购物网站:Gerber Childrenswear
2020/05/06 全球购物
药学专业大专生的自我评价
2013/12/12 职场文书
大学生创业计划书的格式要求
2013/12/29 职场文书
员工拾金不昧表扬信
2014/01/09 职场文书
大学班级计划书
2014/04/29 职场文书
中班幼儿评语大全
2014/04/30 职场文书
公司收款委托书范本
2014/09/20 职场文书
新闻稿件写作范文
2015/07/18 职场文书
大队委员竞选演讲稿
2015/11/20 职场文书
2016元旦主持人开场白
2015/12/03 职场文书
那些美到让人窒息的诗句,值得你收藏!
2019/08/20 职场文书
python实现腾讯滑块验证码识别
2021/04/27 Python
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
2021/05/28 Python
Python包argparse模块常用方法
2021/06/04 Python