使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现哈希表
Feb 07 Python
使用IronPython把Python脚本集成到.NET程序中的教程
Mar 31 Python
python 上下文管理器使用方法小结
Oct 10 Python
用Python写王者荣耀刷金币脚本
Dec 21 Python
Python使用装饰器进行django开发实例代码
Feb 06 Python
Python字符串通过'+'和join函数拼接新字符串的性能测试比较
Mar 05 Python
Python实现最大子序和的方法示例
Jul 05 Python
Python3从零开始搭建一个语音对话机器人的实现
Aug 23 Python
Python能做什么
Jun 02 Python
python3.7调试的实例方法
Jul 21 Python
基于 Python 实践感知器分类算法
Jan 07 Python
教你用Python爬取英雄联盟皮肤原画
Jun 13 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
定义php常量的详解
2013/06/09 PHP
dedecms集成财付通支付接口
2014/12/28 PHP
js实现的动画导航菜单效果代码
2015/09/10 Javascript
JQuery ztree 异步加载实例讲解
2016/02/25 Javascript
Node.js文件操作方法汇总
2016/03/22 Javascript
jQuery实现移动端手机商城购物车功能
2016/09/24 Javascript
JavaScript自定义浏览器滚动条兼容IE、 火狐和chrome
2017/01/05 Javascript
微信小程序 列表的上拉加载和下拉刷新的实现
2017/04/01 Javascript
JavaScript简单计算人的年龄示例
2017/04/15 Javascript
详解如何在Vue里建立长按指令
2018/08/20 Javascript
Vue CLI 3.x 自动部署项目至服务器的方法
2019/04/02 Javascript
微信小程序实现注册登录功能(表单校验、错误提示)
2019/12/10 Javascript
JS+HTML实现自定义上传图片按钮并显示图片功能的方法分析
2020/02/12 Javascript
Vue Object.defineProperty及ProxyVue实现双向数据绑定
2020/09/02 Javascript
vue的$http的get请求要加上params操作
2020/11/12 Javascript
Python中用于转换字母为小写的lower()方法使用简介
2015/05/19 Python
Python3读取文件常用方法实例分析
2015/05/22 Python
Python函数中的函数(闭包)用法实例
2016/03/15 Python
Python对数据库操作
2016/03/28 Python
python获取多线程及子线程的返回值
2017/11/15 Python
python实现二叉树的遍历
2017/12/11 Python
python+os根据文件名自动生成文本
2019/03/21 Python
python增加图像对比度的方法
2019/07/12 Python
win10系统Anaconda和Pycharm的Tensorflow2.0之CPU和GPU版本安装教程
2019/12/03 Python
python实现从尾到头打印单链表操作示例
2020/02/22 Python
2021年的Python 时间轴和即将推出的功能详解
2020/07/27 Python
HTML5实现一个能够移动的小坦克示例代码
2013/09/02 HTML / CSS
Lancome兰蔻官方旗舰店:来自法国的世界知名美妆品牌
2018/06/14 全球购物
化学专业毕业生求职信
2014/07/28 职场文书
2015年医院创卫工作总结
2015/04/22 职场文书
2016年大学校运会广播稿件
2015/12/21 职场文书
CSS的class与id常用的命名规则
2021/05/18 HTML / CSS
PyTorch 如何自动计算梯度
2021/05/23 Python
mysql自增长id用完了该怎么办
2022/02/12 MySQL
JVM的类加载器和双亲委派模式你了解吗
2022/03/13 Java/Android
Spring Boot项目如何优雅实现Excel导入与导出功能
2022/06/10 Java/Android