pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中处理字符串之isdecimal()方法的使用
May 20 Python
用Python的Flask框架结合MySQL写一个内存监控程序
Nov 07 Python
Python 类与元类的深度挖掘 II【经验】
May 06 Python
Python面向对象class类属性及子类用法分析
Feb 02 Python
Python元组拆包和具名元组解析实例详解
Mar 26 Python
python事件驱动event实现详解
Nov 21 Python
numpy.random模块用法总结
May 27 Python
django基于存储在前端的token用户认证解析
Aug 06 Python
pymysql的简单封装代码实例
Jan 08 Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 Python
Python unittest工作原理和使用过程解析
Feb 24 Python
教你如何用python操作摄像头以及对视频流的处理
Oct 12 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
使用gd库实现php服务端图片裁剪和生成缩略图功能分享
2013/12/25 PHP
js函数般调用正则
2008/04/08 Javascript
基于JQuery 的消息提示框效果代码
2011/07/31 Javascript
jquery多选项卡效果实例代码(附效果图)
2013/03/23 Javascript
js获取url参数值的两种方式
2013/09/10 Javascript
12种不宜使用的Javascript语法整理
2013/11/04 Javascript
鼠标经过tr时,改变tr当前背景颜色
2014/01/13 Javascript
一个CSS+jQuery实现的放大缩小动画效果
2014/02/19 Javascript
Node.js中的模块机制学习笔记
2014/11/04 Javascript
Node.js巧妙实现Web应用代码热更新
2015/10/22 Javascript
获取JS中网页各种高宽与位置的方法总结
2016/07/27 Javascript
浅谈JavaScript 函数参数传递到底是值传递还是引用传递
2016/08/23 Javascript
微信小程序 登录的简单实现
2017/04/19 Javascript
JS兼容所有浏览器的DOMContentLoaded事件
2018/01/12 Javascript
微信web端后退强制刷新功能的实现代码
2018/03/04 Javascript
页面点击小红心js实现代码
2018/05/26 Javascript
vue2.0实现音乐/视频播放进度条组件
2018/06/06 Javascript
基于mpvue小程序使用echarts画折线图的方法示例
2019/04/24 Javascript
Vue-cli3简单使用(图文步骤)
2019/04/30 Javascript
vue实现多个echarts根据屏幕大小变化而变化实例
2020/07/19 Javascript
微信小程序onShareTimeline()实现分享朋友圈
2021/01/07 Javascript
Python基于whois模块简单识别网站域名及所有者的方法
2018/04/23 Python
python代码实现逻辑回归logistic原理
2019/08/07 Python
Python使用gluon/mxnet模块实现的mnist手写数字识别功能完整示例
2019/12/18 Python
利用python控制Autocad:pyautocad方式
2020/06/01 Python
HTML5 embed 标签使用方法介绍
2013/08/13 HTML / CSS
瑞典最好的运动鞋专卖店:Sneakersnstuff
2016/08/29 全球购物
静态变量和实例变量的区别
2015/07/07 面试题
美术教师自我鉴定
2014/02/12 职场文书
销售人员求职信
2014/07/22 职场文书
社保委托书怎么写
2014/08/02 职场文书
2014年社区民政工作总结
2014/12/02 职场文书
实习推荐信格式模板
2015/03/27 职场文书
爱国电影观后感
2015/06/19 职场文书
2015秋季开学典礼演讲稿
2015/07/16 职场文书
交通安全学习心得体会
2016/01/18 职场文书