使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python BeautifulSoup中文乱码问题的2种解决方法
Apr 22 Python
Python操作SQLite简明教程
Jul 10 Python
Python的__builtin__模块中的一些要点知识
May 02 Python
理解python正则表达式
Jan 15 Python
python 队列详解及实例代码
Oct 18 Python
Python编程使用tkinter模块实现计算器软件完整代码示例
Nov 29 Python
Python实现读取Properties配置文件的方法
Mar 29 Python
python 编写简单网页服务器的实例
Jun 01 Python
python使用百度文字识别功能方法详解
Jul 23 Python
Win10下配置tensorflow-gpu的详细教程(无VS2015/2017)
Jul 14 Python
python 密码学示例——凯撒密码的实现
Sep 21 Python
Python实现异步IO的示例
Nov 05 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
使用 PHPMAILER 发送邮件实例应用
2012/11/07 PHP
深入解读php中关于抽象(abstract)类和抽象方法的问题分析
2014/01/03 PHP
PHP四舍五入精确小数位及取整
2014/01/14 PHP
PHP采用超长(超大)数字运算防止数字以科学计数法显示的方法
2016/04/01 PHP
线路分流自动跳转代码;希望对大家有用!
2006/12/02 Javascript
javascript网页关键字高亮代码
2008/07/30 Javascript
javascript 写类方式之十
2009/07/05 Javascript
jquery改变disabled的boolean状态的三种方法
2013/12/13 Javascript
情人节单身的我是如何在敲完代码之后收到12束玫瑰的(javascript)
2015/08/21 Javascript
使用jQuery.form.js/springmvc框架实现文件上传功能
2016/05/12 Javascript
原生Javascript插件开发实践
2017/01/09 Javascript
使用vue-route 的 beforeEach 实现导航守卫(路由跳转前验证登录)功能
2018/03/22 Javascript
vue实现滑动到底部加载更多效果
2020/10/27 Javascript
vue自定义标签和单页面多路由的实现代码
2020/05/03 Javascript
[46:12]完美世界DOTA2联赛循环赛 DM vs Matador BO2第一场 11.04
2020/11/04 DOTA
linux平台使用Python制作BT种子并获取BT种子信息的方法
2017/01/20 Python
Python简单的制作图片验证码实例
2017/05/31 Python
python matplotlib坐标轴设置的方法
2017/12/05 Python
python Pandas 读取txt表格的实例
2018/04/29 Python
Python 字符串与数字输出方法
2018/07/16 Python
python实现汽车管理系统
2018/11/30 Python
浅谈pycharm的xmx和xms设置方法
2018/12/03 Python
Python3.7黑帽编程之病毒篇(基础篇)
2020/02/04 Python
Python 如何定义匿名或内联函数
2020/08/01 Python
PyCharm2019.3永久激活破解详细图文教程,亲测可用(不定期更新)
2020/10/29 Python
纯CSS3大转盘抽奖示例代码(响应式、可配置)
2017/01/13 HTML / CSS
意大利奢侈品网站:Italist
2016/08/23 全球购物
database面试题
2013/03/28 面试题
大众服装店创业计划书范文
2014/01/01 职场文书
社区安全检查制度
2014/02/03 职场文书
学校教研活动总结
2014/07/02 职场文书
开服装店计划书
2014/08/15 职场文书
公务员检讨书
2014/11/01 职场文书
关于运动会的宣传稿
2015/07/23 职场文书
Python制作动态字符画的源码
2021/08/04 Python
Nginx隐藏式跳转(浏览器URL跳转后保持不变)
2022/04/07 Servers