使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python中global与nonlocal比较
Nov 21 Python
Python cx_freeze打包工具处理问题思路及解决办法
Feb 13 Python
python实现基于SVM手写数字识别功能
May 27 Python
Python获取CPU、内存使用率以及网络使用状态代码
Feb 08 Python
tensorflow 恢复指定层与不同层指定不同学习率的方法
Jul 26 Python
对Python闭包与延迟绑定的方法详解
Jan 07 Python
python处理excel绘制雷达图
Oct 18 Python
Python socket模块ftp传输文件过程解析
Nov 05 Python
Pytoch之torchvision.transforms图像变换实例
Dec 30 Python
Windows下Anaconda和PyCharm的安装与使用详解
Apr 23 Python
Pytorch可视化的几种实现方法
Jun 10 Python
Python加密技术之RSA加密解密的实现
Apr 08 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
我的论坛源代码(八)
2006/10/09 PHP
php array_pop()数组函数将数组最后一个单元弹出(出栈)
2011/07/12 PHP
解析linux下安装memcacheq(mcq)全过程笔记
2013/06/27 PHP
浅谈discuz密码加密的方式
2014/05/22 PHP
PHP中的Iterator迭代对象属性详解
2019/04/12 PHP
比较搞笑的js陷阱题
2010/02/07 Javascript
初识javascript 文档碎片
2010/07/13 Javascript
javascript面向对象之二 命名空间
2011/02/08 Javascript
jquery时间下拉框小例子
2013/04/15 Javascript
JavaScript:new 一个函数和直接调用函数的区别分析
2013/07/10 Javascript
Javascript WebSocket使用实例介绍(简明入门教程)
2014/04/16 Javascript
关于List.ToArray()方法的效率测试
2016/09/30 Javascript
基于jquery实现的银行卡号每隔4位自动插入空格的实现代码
2016/11/22 Javascript
javascript  数组排序与对象排序的实例
2017/07/17 Javascript
jquery获取链接地址和跳转详解(推荐)
2017/08/15 jQuery
微信小程序实现城市列表选择
2018/06/05 Javascript
vue中v-for循环给标签属性赋值的方法
2018/10/18 Javascript
基于Vue 实现一个中规中矩loading组件
2019/04/03 Javascript
ionic+html5+API实现双击返回键退出应用
2019/09/17 Javascript
微信小程序背景音乐开发详解
2019/12/12 Javascript
[51:26]DOTA2上海特级锦标赛主赛事日 - 2 胜者组第一轮#3Secret VS OG第二局
2016/03/03 DOTA
Python实现从url中提取域名的几种方法
2014/09/26 Python
举例详解Python中循环语句的嵌套使用
2015/05/14 Python
python计算圆周率pi的方法
2015/07/11 Python
举例讲解Python的lambda语句声明匿名函数的用法
2016/07/01 Python
深入理解Python中的*重复运算符
2017/10/28 Python
python版本的仿windows计划任务工具
2018/04/30 Python
python IP地址转整数
2020/11/20 Python
HTML中fieldset标签概述及使用方法
2013/02/01 HTML / CSS
经销商会议欢迎词
2014/01/11 职场文书
银行办理业务介绍信
2014/01/18 职场文书
简单租房协议书
2014/04/09 职场文书
公务员群众路线心得体会
2014/11/03 职场文书
2015年办公室人员工作总结
2015/05/15 职场文书
2015年中秋晚会主持词
2015/07/01 职场文书
车辆安全隐患排查制度
2015/08/05 职场文书