使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python进阶教程之动态类型详解
Aug 30 Python
Python实现对PPT文件进行截图操作的方法
Apr 28 Python
python版opencv摄像头人脸实时检测方法
Aug 03 Python
Python3按一定数据位数格式处理bin文件的方法
Jan 24 Python
pandas去重复行并分类汇总的实现方法
Jan 29 Python
如何通过python画loss曲线的方法
Jun 26 Python
Python input函数使用实例解析
Nov 22 Python
Python如何读写二进制数组数据
Aug 01 Python
python学习笔记之多进程
Aug 06 Python
解决阿里云邮件发送不能使用25端口问题
Aug 07 Python
OpenCV实现机器人对物体进行移动跟随的方法实例
Nov 09 Python
python 调用js的四种方式
Apr 11 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
PHP 采集获取指定网址的内容
2010/01/05 PHP
关于url地址传参数时字符串有回车造成页面脚本赋值失败的解决方法
2013/06/28 PHP
使用php语句将数据库*.sql文件导入数据库
2014/05/05 PHP
PHP获取youku视频真实flv文件地址的方法
2014/12/23 PHP
php邮箱地址正则表达式验证
2015/11/13 PHP
PHP获取二维数组中某一列的值集合
2015/12/25 PHP
PHP+MySql+jQuery实现的"顶"和"踩"投票功能
2016/05/21 PHP
使用tp框架和SQL语句查询数据表中的某字段包含某值
2019/10/18 PHP
js全屏显示显示代码的三种方法
2013/11/11 Javascript
jQuery实现信息提示框(带有圆角框与动画)效果
2015/08/07 Javascript
js判断是否为空和typeof的用法(详解)
2016/10/07 Javascript
Angular.js中用ng-repeat-start实现自定义显示
2016/10/18 Javascript
react开发中如何使用require.ensure加载es6风格的组件
2017/05/09 Javascript
JavaScript门面模式详解
2017/10/19 Javascript
实战node静态文件服务器的示例代码
2018/03/08 Javascript
angular2 ng2-file-upload上传示例代码
2018/08/23 Javascript
vue项目刷新当前页面的三种方法
2018/12/04 Javascript
vue模仿网易云音乐的单页面应用
2019/04/24 Javascript
vue深度监听(监听对象和数组的改变)与立即执行监听实例
2020/09/04 Javascript
Vue2.0 ES6语法降级ES5的操作
2020/10/30 Javascript
vant-ui AddressEdit地址编辑和van-area的用法说明
2020/11/03 Javascript
Python模块学习 filecmp 文件比较
2012/08/27 Python
Python 使用os.remove删除文件夹时报错的解决方法
2017/01/13 Python
详解pyqt5 动画在QThread线程中无法运行问题
2018/05/05 Python
python3实现字符串的全排列的方法(无重复字符)
2018/07/07 Python
代码详解django中数据库设置
2019/01/28 Python
python自定义函数实现最大值的输出方法
2019/07/09 Python
python3实现mysql导出excel的方法
2019/07/31 Python
详解通过变换矩阵实现canvas的缩放功能
2019/01/14 HTML / CSS
iframe在移动端的缩放的示例代码
2018/10/12 HTML / CSS
日本索尼音乐商店:Sony Music Shop
2018/07/17 全球购物
客服端调用EJB对象的几个基本步骤
2012/01/15 面试题
网络优化专员求职信
2014/05/04 职场文书
小学优秀教师事迹材料
2014/12/16 职场文书
协议书格式模板
2016/03/24 职场文书
Python实现列表拼接和去重的三种方式
2021/07/02 Python