使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python判断操作系统类型代码分享
Nov 22 Python
Python 3中print函数的使用方法总结
Aug 08 Python
浅谈numpy库的常用基本操作方法
Jan 09 Python
分析Python中解析构建数据知识
Jan 20 Python
python 处理数字,把大于上限的数字置零实现方法
Jan 28 Python
Python后台管理员管理前台会员信息的讲解
Jan 28 Python
OpenCV中VideoCapture类的使用详解
Feb 14 Python
python判断两个序列的成员是否一样的实例代码
Mar 01 Python
Python调用C语言程序方法解析
Jul 07 Python
Python如何将模块打包并发布
Aug 30 Python
Python导入父文件夹中模块并读取当前文件夹内的资源
Nov 19 Python
Python+Xlwings 删除Excel的行和列
Dec 19 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
Session服务器配置指南与使用经验的深入解析
2013/06/17 PHP
解析:通过php socket并借助telnet实现简单的聊天程序
2013/06/18 PHP
php创建、获取cookie及基础要点分析
2015/01/26 PHP
PHP的Socket网络编程入门指引
2015/08/11 PHP
PHP GD库相关图像生成和处理函数小结
2016/09/30 PHP
PHP中类的自动加载的方法
2017/03/17 PHP
jQuery 各种浏览器下获得日期区别
2008/12/22 Javascript
jquery checkbox全选、取消全选实现代码
2010/03/05 Javascript
jQuery的attr与prop使用介绍
2013/10/10 Javascript
如何判断鼠标是否在DIV的区域内
2013/11/13 Javascript
javascript实现任务栏消息提示的简单实例
2016/05/31 Javascript
bootstrap组件之导航组件使用方法
2017/01/19 Javascript
Vue响应式添加、修改数组和对象的值
2017/03/20 Javascript
详解Node.js中exports和module.exports的区别
2017/04/19 Javascript
vue用Object.defineProperty手写一个简单的双向绑定的示例
2018/07/09 Javascript
图片文字识别(OCR)插件Ocrad.js教程
2018/11/26 Javascript
Python continue语句用法实例
2014/03/11 Python
Python不规范的日期字符串处理类
2014/06/10 Python
Python科学计算之NumPy入门教程
2017/01/15 Python
PyQt4实现下拉菜单可供选择并打印出来
2018/04/20 Python
和孩子一起学习python之变量命名规则
2018/05/27 Python
flask中的wtforms使用方法
2018/07/21 Python
使用k8s部署Django项目的方法步骤
2019/01/14 Python
pygame编写音乐播放器的实现代码示例
2019/11/19 Python
python dataframe NaN处理方式
2019/12/26 Python
M1芯片安装python3.9.1的实现
2021/02/02 Python
IE滤镜与CSS3效果(详细整理分享)
2013/01/25 HTML / CSS
荷兰街头时尚之家:Funkie House
2019/03/18 全球购物
主持人演讲稿
2014/05/13 职场文书
校长竞聘演讲稿
2014/05/16 职场文书
青奥会口号
2014/06/12 职场文书
2014年司机工作总结
2014/11/21 职场文书
云冈石窟导游词
2015/02/04 职场文书
婚庆答谢词大全
2015/09/29 职场文书
遇事可以测出您的见识与格局
2019/09/16 职场文书
python批量创建变量并赋值操作
2021/06/03 Python