使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python getopt 参数处理小示例
Jun 09 Python
django实现分页的方法
May 26 Python
用Python写飞机大战游戏之pygame入门(4):获取鼠标的位置及运动
Nov 05 Python
Python编程之gui程序实现简单文件浏览器代码
Dec 08 Python
python实现requests发送/上传多个文件的示例
Jun 04 Python
python超时重新请求解决方案
Oct 21 Python
Selenium向iframe富文本框输入内容过程图解
Apr 10 Python
基于Python第三方插件实现西游记章节标注汉语拼音的方法
May 22 Python
Python 实现简单的客户端认证
Jul 29 Python
flask开启多线程的具体方法
Aug 02 Python
如何用Python 加密文件
Sep 10 Python
Python中requests做接口测试的方法
May 30 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
php提示Warning:mysql_fetch_array() expects的解决方法
2014/12/16 PHP
PHP设计模式之原型模式定义与用法详解
2018/04/03 PHP
PHP7内核CGI与FastCGI详解
2019/04/14 PHP
jquery text,radio,checkbox,select操作实现代码
2009/07/09 Javascript
JavaScript接口实现代码 (Interfaces In JavaScript)
2010/06/11 Javascript
JavaScript实现鼠标滑过图片变换效果的方法
2015/04/16 Javascript
jQuery事件绑定on()、bind()与delegate() 方法详解
2015/06/03 Javascript
纯js实现无限空间大小的本地存储
2015/06/18 Javascript
JS实现点击按钮获取页面高度的方法
2015/11/02 Javascript
jQuery使用cookie与json简单实现购物车功能
2016/04/15 Javascript
jQuery+CSS3文字跑马灯特效的简单实现
2016/06/25 Javascript
分享jQuery封装好的一些常用操作
2016/07/28 Javascript
JS实现滑动门效果的方法详解
2016/12/19 Javascript
jQuery插件FusionCharts绘制2D环饼图效果示例【附demo源码】
2017/04/10 jQuery
vue实现点击当前标签高亮效果【推荐】
2018/06/22 Javascript
微信小程序导入Vant报错VM292:1 thirdScriptError的解决方法
2019/08/01 Javascript
详解Python迭代和迭代器
2016/03/28 Python
python中requests使用代理proxies方法介绍
2017/10/25 Python
python读取文本中的坐标方法
2018/10/14 Python
python使用socket 先读取长度,在读取报文内容示例
2019/09/26 Python
浅析Python 责任链设计模式
2020/09/11 Python
HTML中fieldset标签概述及使用方法
2013/02/01 HTML / CSS
HTML5 canvas绘制的玫瑰花效果
2014/05/29 HTML / CSS
Under Armour澳大利亚官网:美国知名的高端功能性运动品牌
2018/02/22 全球购物
美国高端牛仔品牌:Silver Jeans
2019/12/12 全球购物
Unix控制后台进程都有哪些进程
2016/09/22 面试题
中医药大学市场营销专业自荐信
2013/09/29 职场文书
岗位职责定义及内容
2013/11/08 职场文书
简历上的自我评价怎么写
2014/01/28 职场文书
三分钟自我介绍演讲稿
2014/08/21 职场文书
2014年药品销售工作总结
2014/12/16 职场文书
成绩单家长意见
2015/06/03 职场文书
党课主持词大全
2015/06/30 职场文书
2016年庆祝六一儿童节活动总结
2016/04/06 职场文书
承诺书应该怎么写?
2019/09/10 职场文书
DSP接收机前端设想
2022/04/05 无线电