pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 运算符 供重载参考
Jun 11 Python
Python循环语句中else的用法总结
Sep 11 Python
Python中装饰器兼容加括号和不加括号的写法详解
Jul 05 Python
python中sys.argv函数精简概括
Jul 08 Python
python代码过长的换行方法
Jul 19 Python
python对视频画框标记后保存的方法
Dec 07 Python
Python通过4种方式实现进程数据通信
Mar 12 Python
浅谈Django中的QueryDict元素为数组的坑
Mar 31 Python
Python selenium 加载并保存QQ群成员,去除其群主、管理员信息的示例代码
May 28 Python
Python如何避免文件同名产生覆盖
Jun 09 Python
解决Keras使用GPU资源耗尽的问题
Jun 22 Python
Python 实现微信自动回复的方法
Sep 11 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
PHP函数utf8转gb2312编码
2006/12/21 PHP
PHP定时执行计划任务的多种方法小结
2011/12/19 PHP
Windows下部署Apache+PHP+MySQL运行环境实战
2012/08/31 PHP
利用php递归实现无限分类 格式化数组的详解
2013/06/08 PHP
php广告加载类用法实例
2014/09/23 PHP
Linux(CentOS)下PHP扩展PDO编译安装的方法
2016/04/07 PHP
在页面上点击任一链接时触发一个事件的代码
2007/04/07 Javascript
jQuery 源码分析笔记(2) 变量列表
2011/05/28 Javascript
Jsonp 跨域的原理以及Jquery的解决方案
2011/06/27 Javascript
JsDom 编程小结
2011/08/09 Javascript
jquery.Jwin.js 基于jquery的弹出层插件代码
2012/05/23 Javascript
ExtJS实现文件下载的方法实例
2013/11/09 Javascript
Extjs4中的分页应用结合前后台
2013/12/13 Javascript
JS动态添加与删除select中的Option对象(示例代码)
2013/12/20 Javascript
css如何让浮动元素水平居中
2015/08/07 Javascript
JSON遍历方式实例总结
2015/12/07 Javascript
纯JavaScript实现实时反馈系统时间
2017/10/26 Javascript
vue实现双向绑定和依赖收集遇到的坑
2018/11/29 Javascript
[02:44]2014DOTA2 国际邀请赛中国区预选赛 大神红毯秀
2014/05/25 DOTA
python中的对象拷贝示例 python引用传递
2014/01/23 Python
python3.6 实现AES加密的示例(pyCryptodome)
2018/01/10 Python
Django框架使用mysql视图操作示例
2019/05/15 Python
python3在同一行内输入n个数并用列表保存的例子
2019/07/20 Python
Python实现中值滤波去噪方式
2019/12/18 Python
Python中logging日志记录到文件及自动分割的操作代码
2020/08/05 Python
python进行OpenCV实战之画图(直线、矩形、圆形)
2020/08/27 Python
Python pip install之SSL异常处理操作
2020/09/03 Python
Sunglasses Shop荷兰站:英国最大的太阳镜独立在线零售商和供应商
2017/01/08 全球购物
服装创业计划书范文
2014/02/05 职场文书
机工车间主任岗位职责
2014/03/05 职场文书
颁奖晚会主持词
2014/03/25 职场文书
《郑和远航》教学反思
2014/04/16 职场文书
物价局领导班子四风问题整改措施
2014/10/26 职场文书
2014年骨干教师工作总结
2014/12/19 职场文书
美丽的大脚观后感
2015/06/03 职场文书
ubuntu端向日葵键盘输入卡顿问题及解决
2022/12/24 Servers