使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
仅用50行代码实现一个Python编写的计算器的教程
Apr 17 Python
Python过滤列表用法实例分析
Apr 29 Python
windows下ipython的安装与使用详解
Oct 20 Python
python下如何查询CS反恐精英的服务器信息
Jan 17 Python
python实现简单点对点(p2p)聊天
Sep 13 Python
Python回文字符串及回文数字判定功能示例
Mar 20 Python
python爬虫之自动登录与验证码识别
Jun 15 Python
python读取图片任意范围区域
Jan 23 Python
Python产生一个数值范围内的不重复的随机数的实现方法
Aug 21 Python
python数据预处理 :数据抽样解析
Feb 24 Python
详解python tcp编程
Aug 24 Python
基于Python正确读取资源文件
Sep 14 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
php自定文件保存session的方法
2014/12/10 PHP
php经典算法集锦
2015/11/14 PHP
PHP实现的限制IP投票程序IP来源分析
2016/05/04 PHP
PHP ajax+jQuery 实现批量删除功能实例代码小结
2018/12/06 PHP
javascript算法学习(直接插入排序)
2011/04/12 Javascript
javascript倒计时功能实现代码
2012/06/07 Javascript
formvalidator验证插件中有关ajax验证问题
2013/01/04 Javascript
javascript 系统文件夹文件操作及参数介绍
2013/01/08 Javascript
ie下$.getJSON出现问题的解决方法
2014/02/12 Javascript
jQuery实现的Div窗口震动特效
2014/06/09 Javascript
Jquery通过JSON字符串创建JSON对象
2014/08/24 Javascript
微信小程序 switch组件详解及简单实例
2017/01/10 Javascript
JavaScript实现的原生态Tab标签页功能【兼容IE6】
2017/09/18 Javascript
用p5.js制作烟花特效的示例代码
2018/03/21 Javascript
基于vue cli 通过命令行传参实现多环境配置
2018/07/12 Javascript
Jquery遍历筛选数组的几种方法和遍历解析json对象,Map()方法详解以及数组中查询某值是否存在
2019/01/18 jQuery
bootstrap-treeview实现多级树形菜单 后台JSON格式如何组织?
2019/07/26 Javascript
javascript canvas时钟模拟器
2020/07/13 Javascript
[05:11]TI9战队采访——VIRTUSPRO
2019/08/22 DOTA
[46:44]DOTA2-DPC中国联赛 正赛 Ehome vs PSG.LGD BO3 第二场 3月7日
2021/03/11 DOTA
跟老齐学Python之赋值,简单也不简单
2014/09/24 Python
python实现简单socket程序在两台电脑之间传输消息的方法
2015/03/13 Python
Python中文字符串截取问题
2015/06/15 Python
基于随机梯度下降的矩阵分解推荐算法(python)
2018/08/31 Python
PyQT5 QTableView显示绑定数据的实例详解
2019/06/25 Python
CSS3属性box-shadow使用详细教程
2012/01/21 HTML / CSS
什么是smarty? Smarty的优点是什么?
2013/08/11 面试题
法学个人求职信范文
2014/01/27 职场文书
2013年军训通讯稿
2014/02/05 职场文书
应届生求职信范文
2014/05/26 职场文书
会计专业自荐书
2014/07/08 职场文书
干部作风整顿自我剖析材料和整改措施
2014/09/18 职场文书
开国大典观后感
2015/06/04 职场文书
工作收入证明范本
2015/06/12 职场文书
2015秋季运动会通讯稿
2015/07/18 职场文书
2019大学生预备党员转正思想汇报
2019/06/21 职场文书