pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python脚本实现xls(xlsx)转成csv
Apr 10 Python
Python的Django中将文件上传至七牛云存储的代码分享
Jun 03 Python
Windows中使用wxPython和py2exe开发Python的GUI程序的实例教程
Jul 11 Python
python 类对象和实例对象动态添加方法(分享)
Dec 31 Python
Python实现的径向基(RBF)神经网络示例
Feb 06 Python
详解如何为eclipse安装合适版本的python插件pydev
Nov 04 Python
python根据url地址下载小文件的实例
Dec 18 Python
如何使用Python自动控制windows桌面
Jul 11 Python
Django 使用easy_thumbnails压缩上传的图片方法
Jul 26 Python
python通过SSH登陆linux并操作的实现
Oct 10 Python
Python求凸包及多边形面积教程
Apr 12 Python
使用python绘制分组对比柱状图
Apr 21 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
php下HTTP Response中的Chunked编码实现方法
2008/11/19 PHP
PHP 解决session死锁的方法
2013/06/20 PHP
PHP列出MySQL中所有数据库的方法
2015/03/12 PHP
PHP调试的强悍利器之PHPDBG
2016/02/22 PHP
PHP实现的简单在线计算器功能示例
2017/08/02 PHP
在模板页面的js使用办法
2010/04/01 Javascript
JQuery浮动DIV提示信息并自动隐藏的代码
2010/08/29 Javascript
javascript 二进制运算技巧解析
2012/11/27 Javascript
javascript根据时间生成m位随机数最大13位
2014/10/30 Javascript
Angularjs中UI Router的使用方法
2016/05/14 Javascript
BootStrap的alert提示框的关闭后再显示怎么解决
2016/05/17 Javascript
javascript汉字拼音互转的简单实例
2016/10/09 Javascript
微信小程序  网络请求API详解
2016/10/25 Javascript
nodejs个人博客开发第四步 数据模型
2017/04/12 NodeJs
vue中倒计时组件的实例代码
2018/07/06 Javascript
jQuery判断自定义属性data-val用法示例
2019/01/07 jQuery
最简单的vue消息提示全局组件的方法
2019/06/16 Javascript
vue移动端使用canvas签名的实现
2020/01/15 Javascript
[01:02:09]Liquid vs TNC 2019国际邀请赛淘汰赛 胜者组 BO3 第二场 8.21
2020/07/19 DOTA
一则python3的简单爬虫代码
2014/05/26 Python
Python实现备份文件实例
2014/09/16 Python
Python字典操作简明总结
2015/04/13 Python
Python设置在shell脚本中自动补全功能的方法
2018/06/25 Python
Python简单读写Xls格式文档的方法示例
2018/08/17 Python
Python正则匹配判断手机号是否合法的方法
2020/12/09 Python
使用python批量化音乐文件格式转换的实例
2019/01/09 Python
原生python实现knn分类算法
2019/10/24 Python
windows环境中利用celery实现简单任务队列过程解析
2019/11/29 Python
python 实现两个npy档案合并
2020/07/01 Python
使用python对excel表格处理的一些小功能
2021/01/25 Python
String这个类型的class为何定义成final?
2012/11/13 面试题
大学校园生活自我鉴定
2014/01/13 职场文书
军训自我鉴定100字
2014/02/13 职场文书
清明节网上祭英烈活动总结
2014/04/30 职场文书
战马观后感
2015/06/08 职场文书
学校中层领导培训心得体会
2016/01/11 职场文书