关于pytorch中全连接神经网络搭建两种模式详解


Posted in Python onJanuary 14, 2020

pytorch搭建神经网络是很简单明了的,这里介绍两种自己常用的搭建模式:

import torch
import torch.nn as nn

first:

class NN(nn.Module):
 def __init__(self):
  super(NN,self).__init__()
  self.model=nn.Sequential(
   nn.Linear(30,40),
   nn.ReLU(),
   nn.Linear(40,60),
   nn.Tanh(),
   nn.Linear(60,10),
   nn.Softmax()
  )
  self.model[0].weight.data.uniform_(-3e-3, 3e-3)
  self.model[0].bias.data.uniform(-1,1)
 def forward(self,states):
  return self.model(states)

这一种是将整个网络写在一个Sequential中,网络参数设置可以在网络搭建好后单独设置:self.model[0].weight.data.uniform_(-3e-3,3e-3),这是设置第一个linear的权重是(-3e-3,3e-3)之间的均匀分布,bias是-1至1之间的均匀分布。

second:

class NN1(nn.Module):
 def __init__(self):
  super(NN1,self).__init__()
  self.Linear1=nn.Linear(30,40)
  self.Linear1.weight.data.fill_(-0.1)
  #self.Linear1.weight.data.uniform_(-3e-3,3e-3)
  self.Linear1.bias.data.fill_(-0.1)
  self.layer1=nn.Sequential(self.Linear1,nn.ReLU())

  self.Linear2=nn.Linear(40,60)
  self.layer2=nn.Sequential(self.Linear2,nn.Tanh())

  self.Linear3=nn.Linear(60,10)
  self.layer3=nn.Sequential(self.Linear3,nn.Softmax())


 def forward(self,states):
  return self.model(states)

网络参数的设置可以在定义完线性层之后直接设置如这里对于第一个线性层是这样设置:self.Linear1.weight.data.fill_(-0.1),self.Linear1.bias.data.fill_(-0.1)。

你可以看一下这样定义完的参数的效果:

Net=NN()
print("0:",Net.model[0])
print("weight:",type(Net.model[0].weight))
print("weight:",type(Net.model[0].weight.data))
print("bias",Net.model[0].bias.data)
print('1:',Net.model[1])
#print("weight:",Net.model[1].weight.data)
print('2:',Net.model[2])
print('3:',Net.model[3])
#print(Net.model[-1])

Net1=NN1()
print(Net1.Linear1.weight.data)

输出:

0: Linear (30 -> 40)
weight: <class 'torch.nn.parameter.Parameter'>
weight: <class 'torch.FloatTensor'>
bias 
-0.6287
-0.6573
-0.0452
 0.9594
-0.7477
 0.1363
-0.1594
-0.1586
 0.0360
 0.7375
 0.2501
-0.1371
 0.8359
-0.9684
-0.3886
 0.7200
-0.3906
 0.4911
 0.8081
-0.5449
 0.9872
 0.2004
 0.0969
-0.9712
 0.0873
 0.4562
-0.4857
-0.6013
 0.1651
 0.3315
-0.7033
-0.7440
 0.6487
 0.9802
-0.5977
 0.3245
 0.7563
 0.5596
 0.2303
-0.3836
[torch.FloatTensor of size 40]

1: ReLU ()
2: Linear (40 -> 60)
3: Tanh ()

-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
   ...    ⋱    ...   
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
[torch.FloatTensor of size 40x30]


Process finished with exit code 0

这里要注意self.Linear1.weight的类型是网络的parameter。而self.Linear1.weight.data是FloatTensor。

以上这篇关于pytorch中全连接神经网络搭建两种模式详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之for循环语句
Oct 02 Python
解读Python编程中的命名空间与作用域
Oct 16 Python
Python实现简单http服务器
Apr 12 Python
详解python3中的真值测试
Aug 13 Python
python与C、C++混编的四种方式(小结)
Jul 15 Python
关于Numpy数据类型对象(dtype)使用详解
Nov 27 Python
Python自动采集微信联系人的实现示例
Feb 28 Python
Python线程协作threading.Condition实现过程解析
Mar 12 Python
利用keras使用神经网络预测销量操作
Jul 07 Python
Python利用matplotlib绘制折线图的新手教程
Nov 05 Python
详解MindSpore自定义模型损失函数
Jun 30 Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
You might like
PHP新手入门学习方法
2011/05/08 PHP
解析php多线程下载远程多个文件
2013/06/25 PHP
PHP封装返回Ajax字符串和JSON数组的方法
2017/02/17 PHP
基于win2003虚拟机中apache服务器的访问
2017/08/01 PHP
PHP应用跨时区功能的实现方法
2019/03/21 PHP
php+websocket 实现的聊天室功能详解
2020/05/27 PHP
js 替换
2008/02/19 Javascript
jQuery 使用手册(七)
2009/09/23 Javascript
Javascript学习笔记6 prototype的提出
2010/01/11 Javascript
Javascript实现DIV滚动自动滚动到底部的代码
2012/03/01 Javascript
借助JavaScript脚本判断浏览器Flash Player信息的方法
2014/07/09 Javascript
JavaScript中的ArrayBuffer详细介绍
2014/12/08 Javascript
Jquery解析json字符串及json数组的方法
2015/05/29 Javascript
浅谈javascript的call()、apply()、bind()的用法
2016/02/21 Javascript
jQuery使用$.each遍历json数组的简单实现方法
2016/04/18 Javascript
jQuery的事件预绑定
2016/12/05 Javascript
vue双向绑定及观察者模式详解
2019/03/19 Javascript
详解Angular Karma测试的持续集成实践
2019/11/15 Javascript
vue 使用外部JS与调用原生API操作示例
2019/12/02 Javascript
在Vue中实现随hash改变响应菜单高亮
2020/03/09 Javascript
JavaScript实现手机号码 3-4-4格式并控制新增和删除时光标的位置
2020/06/02 Javascript
[02:32]DOTA2亚洲邀请赛 C9战队出场宣传片
2015/02/07 DOTA
[41:52]DOTA2-DPC中国联赛 正赛 CDEC vs Dynasty BO3 第二场 2月22日
2021/03/11 DOTA
Python中生成Epoch的方法
2017/04/26 Python
Python实现使用卷积提取图片轮廓功能示例
2018/05/12 Python
python实现Dijkstra静态寻路算法
2019/01/17 Python
应用OpenCV和Python进行SIFT算法的实现详解
2019/08/21 Python
解决pandas展示数据输出时列名不能对齐的问题
2019/11/18 Python
python pptx复制指定页的ppt教程
2020/02/14 Python
详解Python中string模块除去Str还剩下什么
2020/11/30 Python
一套软件测试笔试题
2014/07/25 面试题
客服实习的个人自我鉴定
2013/10/20 职场文书
小松树教学反思
2014/02/11 职场文书
80后职场人的职业生涯规划
2014/03/08 职场文书
数学教师个人总结
2015/02/06 职场文书
win10系统计算机图标怎么调出来?win10调出计算机图标的方法
2022/08/14 数码科技