使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python中os操作文件及文件路径实例汇总
Jan 15 Python
python求解水仙花数的方法
May 11 Python
wxPython中listbox用法实例详解
Jun 01 Python
详解Python之数据序列化(json、pickle、shelve)
Mar 30 Python
python3实现点餐系统
Jan 24 Python
dataframe 按条件替换某一列中的值方法
Jan 29 Python
Django forms表单 select下拉框的传值实例
Jul 19 Python
python fuzzywuzzy模块模糊字符串匹配详细用法
Aug 29 Python
python快速排序的实现及运行时间比较
Nov 22 Python
python 读取.nii格式图像实例
Jul 01 Python
Idea安装python显示无SDK问题解决方案
Aug 12 Python
Python 读写 Matlab Mat 格式数据的操作
May 19 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
php接口技术实例详解
2016/12/07 PHP
PHP实现链表的定义与反转功能示例
2018/06/09 PHP
laravel 验证错误信息到 blade模板的方法
2019/09/29 PHP
PHP基于openssl实现非对称加密代码实例
2020/06/19 PHP
js获取TreeView控件选中节点的Text和Value值的方法
2012/11/24 Javascript
jquery时间下拉框小例子
2013/04/15 Javascript
JS Pro-深入面向对象的程序设计之继承的详解
2013/05/07 Javascript
jQuery表格排序组件-tablesorter使用示例
2014/05/26 Javascript
JavaScript基本语法学习教程
2016/01/14 Javascript
mui框架移动开发初体验详解
2017/10/11 Javascript
vue-cli3.0配置及使用注意事项详解
2018/09/05 Javascript
vue实现div拖拽互换位置
2020/07/29 Javascript
JavaScript数据结构与算法之二叉树遍历算法详解【先序、中序、后序】
2019/02/21 Javascript
vue+高德地图写地图选址组件的方法
2019/05/18 Javascript
Node.js+ELK日志规范的实现
2019/05/23 Javascript
Vue组件间通信 Vuex的用法解析
2019/08/05 Javascript
python下paramiko模块实现ssh连接登录Linux服务器
2015/06/03 Python
Python优化技巧之利用ctypes提高执行速度
2016/09/11 Python
详解Python3 基本数据类型
2019/04/19 Python
调试Django时打印SQL语句的日志代码实例
2019/09/12 Python
简单了解Pandas缺失值处理方法
2019/11/16 Python
TFRecord格式存储数据与队列读取实例
2020/01/21 Python
python对一个数向上取整的实例方法
2020/06/18 Python
html5+css3之动画在webapp中的应用
2014/11/21 HTML / CSS
HTML5如何使用SVG的方法示例
2019/01/11 HTML / CSS
英国著名书店:Foyles
2018/12/01 全球购物
奥地利时尚、美容、玩具和家居之家:Kastner & Öhler
2020/04/26 全球购物
美国椅子和沙发制造商:La-Z-Boy
2020/10/25 全球购物
PHP如何调用MYSQL存储过程
2014/05/30 面试题
华美博弈C/VC工程师笔试试题
2012/07/16 面试题
大专生自荐信
2013/10/04 职场文书
学校先进集体事迹材料
2014/05/31 职场文书
书法兴趣小组活动总结
2014/07/07 职场文书
小学庆六一活动总结
2014/08/28 职场文书
2016年度师德标兵先进事迹材料
2016/02/26 职场文书
使用CSS连接数据库的方式
2022/02/28 HTML / CSS