使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python简单实现基于SSL的IRC bot实例
Jun 15 Python
Python实现快速多线程ping的方法
Jul 15 Python
利用Python实现颜色色值转换的小工具
Oct 27 Python
对python PLT中的image和skimage处理图片方法详解
Jan 10 Python
python实现中文文本分句的例子
Jul 15 Python
django admin.py 外键,反向查询的实例
Jul 26 Python
详解python破解zip文件密码的方法
Jan 13 Python
django实现将修改好的新模型写入数据库
Mar 31 Python
django 连接数据库出现1045错误的解决方式
May 14 Python
pycharm无法导入lxml的解决办法
Mar 31 Python
如何获取numpy array前N个最大值
May 14 Python
PyTorch中permute的使用方法
Apr 26 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
PHP 数组入门教程小结
2009/05/20 PHP
PHP 基于文件头的文件类型验证类函数
2012/05/01 PHP
PHP中Header使用的HTTP协议及常用方法小结
2014/11/04 PHP
PHP依赖注入(DI)和控制反转(IoC)详解
2017/06/12 PHP
基于ThinkPHP5.0实现图片上传插件
2017/09/25 PHP
Laravel如何使用Redis共享Session
2018/02/23 PHP
限制文本字节数js代码
2007/03/06 Javascript
JavaScript高级程序设计 阅读笔记(二十一) JavaScript中的XML
2012/09/14 Javascript
php显示当前文件所在的文件以及文件夹所有文件以树形展开
2013/12/13 Javascript
jquery自动切换tabs选项卡的具体实现
2013/12/24 Javascript
利用jquery写的左右轮播图特效
2014/02/12 Javascript
表单提交前触发函数返回true表单才会提交
2014/03/11 Javascript
JavaScript中的slice()方法使用详解
2015/06/06 Javascript
Angular中$compile源码分析
2016/01/28 Javascript
jquery编写Tab选项卡滚动导航切换特效
2020/07/17 Javascript
jQuery回到顶部的代码
2016/07/09 Javascript
javascript实现的左右无缝滚动效果
2016/09/19 Javascript
浅述节点的创建及常见功能的实现
2016/12/15 Javascript
JavaScript 动态三角函数实例详解
2017/01/08 Javascript
AngularJS框架中的双向数据绑定机制详解【减少需要重复的开发代码量】
2017/01/19 Javascript
JS写XSS cookie stealer来窃取密码的步骤详解
2017/11/20 Javascript
用WebStorm进行Angularjs 2开发(环境篇:Windows 10,Angular-cli方式)
2018/12/05 Javascript
vue轻量级框架无法获取到vue对象解决方法
2019/05/12 Javascript
深入理解python中的浅拷贝和深拷贝
2016/05/30 Python
老生常谈Python序列化和反序列化
2017/06/28 Python
python去除字符串中的换行符
2017/10/11 Python
python3 实现的对象与json相互转换操作示例
2019/08/17 Python
python针对mysql数据库的连接、查询、更新、删除操作示例
2019/09/11 Python
Laravel框架表单验证格式化输出的方法
2019/09/25 Python
python使用Matplotlib改变坐标轴的默认位置
2019/10/18 Python
Django如何重置migration的几种情景
2021/02/24 Python
英国健身超市:Fitness Superstore
2019/06/17 全球购物
小学生节水倡议书
2015/04/29 职场文书
详解MySQL连接挂死的原因
2021/05/18 MySQL
Django rest framework如何自定义用户表
2021/06/09 Python
CSS中理解层叠性及权重如何分配
2022/12/24 HTML / CSS