pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python多线程http下载实现示例
Dec 30 Python
python计数排序和基数排序算法实例
Apr 25 Python
Python3中多线程编程的队列运作示例
Apr 16 Python
pymongo实现多结果进行多列排序的方法
May 16 Python
Django基于ORM操作数据库的方法详解
Mar 27 Python
python 实现一次性在文件中写入多行的方法
Jan 28 Python
基于Python实现用户管理系统
Feb 26 Python
python3实现钉钉消息推送的方法示例
Mar 14 Python
Python实现的统计文章单词次数功能示例
Jul 08 Python
Python split() 函数拆分字符串将字符串转化为列的方法
Jul 16 Python
Django生成PDF文档显示网页上以及PDF中文显示乱码的解决方法
Dec 17 Python
Python实现序列化及csv文件读取
Jan 19 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
php 强制下载文件实现代码
2013/10/28 PHP
PHP动态编译出现Cannot find autoconf的解决方法
2014/11/05 PHP
php递归遍历删除文件的方法
2015/04/17 PHP
Zend Framework教程之MVC框架的Controller用法分析
2016/03/07 PHP
PHP之十六个魔术方法详细介绍
2016/11/01 PHP
javascript 必知必会之closure
2009/09/21 Javascript
JS 用6N±1法求素数 实例教程
2009/10/20 Javascript
js 点击按钮弹出另一页,选择值后,返回到当前页
2010/05/26 Javascript
jquery动态调整div大小使其宽度始终为浏览器宽度
2014/06/06 Javascript
freemarker判断对象是否为空的方法
2015/08/13 Javascript
JavaScript代码轻松实现网页内容禁止复制(代码简单)
2015/10/23 Javascript
js实现自动轮换选项卡
2017/01/13 Javascript
bootstrap中添加额外的图标实例代码
2017/02/15 Javascript
移动端网页开发调试神器Eruda的介绍与使用技巧
2017/10/30 Javascript
JS表单传值和URL编码转换
2018/03/03 Javascript
jQuery仿移动端支付宝键盘的实现代码
2018/08/15 jQuery
解决包含在label标签下的checkbox在ie8及以下版本点击事件无效果兼容的问题
2019/10/27 Javascript
js构造函数constructor和原型prototype原理与用法实例分析
2020/03/02 Javascript
[02:40]2018年度DOTA2最佳新人-完美盛典
2018/12/16 DOTA
python实现在字符串中查找子字符串的方法
2015/07/11 Python
Python编程中NotImplementedError的使用方法
2018/04/21 Python
对python制作自己的数据集实例讲解
2018/12/12 Python
10个Python面试常问的问题(小结)
2019/11/20 Python
教你使用Canvas处理图片的方法
2017/11/28 HTML / CSS
简洁自适应404页面HTML好看的404源码
2020/12/16 HTML / CSS
美国最大的香水连锁店官网:Perfumania
2016/08/15 全球购物
Vertbaudet西班牙网上商店:婴儿服装、童装、母婴用品和儿童家具
2019/10/16 全球购物
个人能力自我鉴赏
2014/01/25 职场文书
大学生党性分析材料
2014/12/19 职场文书
2016秋季幼儿园开学寄语
2015/12/03 职场文书
《浅水洼里的小鱼》教学反思
2016/02/16 职场文书
建房合同协议书
2016/03/21 职场文书
如何在CSS中绘制曲线图形及展示动画
2021/05/24 HTML / CSS
golang定时器
2022/04/14 Golang
python如何将mat文件转为png
2022/07/15 Python
鸿蒙3.0体验感怎么样? 鸿蒙3.0系统评测向
2022/08/14 数码科技