使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Using Django with GAE Python 后台抓取多个网站的页面全文
Feb 17 Python
Python 加密的实例详解
Oct 09 Python
解决python3 Pycharm上连接数据库时报错的问题
Dec 03 Python
python pandas时序处理相关功能详解
Jul 03 Python
在Windows上安装和配置 Jupyter Lab 作为桌面级应用程序教程
Apr 22 Python
python使用Thread的setDaemon启动后台线程教程
Apr 25 Python
基于python SMTP实现自动发送邮件教程解析
Jun 02 Python
Django QuerySet查询集原理及代码实例
Jun 13 Python
Pycharm调试程序技巧小结
Aug 08 Python
python 实现Harris角点检测算法
Dec 11 Python
python 模拟登录B站的示例代码
Dec 15 Python
Python实现石头剪刀布游戏
Jan 20 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
根德YB400的电路分析
2021/03/02 无线电
PHP读取MySQL数据代码
2008/06/05 PHP
php去除换行符的方法小结(PHP_EOL变量的使用)
2013/02/16 PHP
php获取表单中多个同名input元素的值
2014/03/20 PHP
php简单随机字符串生成方法示例
2017/04/19 PHP
漂亮的widgets,支持换肤和后期开发新皮肤(2007-4-27已更新1.7alpha)
2007/04/27 Javascript
关于使用 jBox 对话框的提交不能弹出问题解决方法
2012/11/07 Javascript
javascript从image转换为base64位编码的String
2014/07/29 Javascript
jQuery中:gt选择器用法实例
2014/12/29 Javascript
jQuery中insertBefore()方法用法实例
2015/01/08 Javascript
jQuery左侧大图右侧小图焦点图幻灯切换代码分享
2015/08/19 Javascript
JavaScript 常见安全漏洞和自动化检测技术
2015/08/21 Javascript
jQuery实现的超酷苹果风格图标滑出菜单效果代码
2015/09/16 Javascript
jQuery中判断对象是否存在的方法汇总
2016/02/24 Javascript
详述JavaScript实现继承的几种方式(推荐)
2016/03/22 Javascript
JavaScript面试题(指针、帽子和女朋友)
2016/11/23 Javascript
学好js,这些js函数概念一定要知道【推荐】
2017/01/19 Javascript
JavaScript如何使用插值实现图像渐变
2020/06/28 Javascript
手把手带你搭建一个node cli的方法示例
2020/08/07 Javascript
js canvas实现俄罗斯方块
2020/10/11 Javascript
javascript实现时钟动画
2020/12/03 Javascript
Vue中inheritAttrs的使用实例详解
2020/12/31 Vue.js
react-native 实现购物车滑动删除效果的示例代码
2021/01/15 Javascript
Python之eval()函数危险性浅析
2014/07/03 Python
解析Python中while true的使用
2015/10/13 Python
Python Dataframe 指定多列去重、求差集的方法
2018/07/10 Python
Python函数参数匹配模型通用规则keyword-only参数详解
2019/06/10 Python
如何修复使用 Python ORM 工具 SQLAlchemy 时的常见陷阱
2019/11/19 Python
基于CSS3实现的黑色个性导航菜单效果
2015/09/14 HTML / CSS
德国化妆品和天然化妆品网上商店:kosmetikfuchs.de
2017/06/09 全球购物
Lacoste澳大利亚官网:服装、鞋类及配饰
2018/11/14 全球购物
软件测试工程师结构化面试题库
2016/11/23 面试题
咖啡书吧创业计划书
2014/01/13 职场文书
缓刑人员思想汇报
2014/10/11 职场文书
社区国庆节活动总结
2015/03/23 职场文书
Python包argparse模块常用方法
2021/06/04 Python