关于pytorch中全连接神经网络搭建两种模式详解


Posted in Python onJanuary 14, 2020

pytorch搭建神经网络是很简单明了的,这里介绍两种自己常用的搭建模式:

import torch
import torch.nn as nn

first:

class NN(nn.Module):
 def __init__(self):
  super(NN,self).__init__()
  self.model=nn.Sequential(
   nn.Linear(30,40),
   nn.ReLU(),
   nn.Linear(40,60),
   nn.Tanh(),
   nn.Linear(60,10),
   nn.Softmax()
  )
  self.model[0].weight.data.uniform_(-3e-3, 3e-3)
  self.model[0].bias.data.uniform(-1,1)
 def forward(self,states):
  return self.model(states)

这一种是将整个网络写在一个Sequential中,网络参数设置可以在网络搭建好后单独设置:self.model[0].weight.data.uniform_(-3e-3,3e-3),这是设置第一个linear的权重是(-3e-3,3e-3)之间的均匀分布,bias是-1至1之间的均匀分布。

second:

class NN1(nn.Module):
 def __init__(self):
  super(NN1,self).__init__()
  self.Linear1=nn.Linear(30,40)
  self.Linear1.weight.data.fill_(-0.1)
  #self.Linear1.weight.data.uniform_(-3e-3,3e-3)
  self.Linear1.bias.data.fill_(-0.1)
  self.layer1=nn.Sequential(self.Linear1,nn.ReLU())

  self.Linear2=nn.Linear(40,60)
  self.layer2=nn.Sequential(self.Linear2,nn.Tanh())

  self.Linear3=nn.Linear(60,10)
  self.layer3=nn.Sequential(self.Linear3,nn.Softmax())


 def forward(self,states):
  return self.model(states)

网络参数的设置可以在定义完线性层之后直接设置如这里对于第一个线性层是这样设置:self.Linear1.weight.data.fill_(-0.1),self.Linear1.bias.data.fill_(-0.1)。

你可以看一下这样定义完的参数的效果:

Net=NN()
print("0:",Net.model[0])
print("weight:",type(Net.model[0].weight))
print("weight:",type(Net.model[0].weight.data))
print("bias",Net.model[0].bias.data)
print('1:',Net.model[1])
#print("weight:",Net.model[1].weight.data)
print('2:',Net.model[2])
print('3:',Net.model[3])
#print(Net.model[-1])

Net1=NN1()
print(Net1.Linear1.weight.data)

输出:

0: Linear (30 -> 40)
weight: <class 'torch.nn.parameter.Parameter'>
weight: <class 'torch.FloatTensor'>
bias 
-0.6287
-0.6573
-0.0452
 0.9594
-0.7477
 0.1363
-0.1594
-0.1586
 0.0360
 0.7375
 0.2501
-0.1371
 0.8359
-0.9684
-0.3886
 0.7200
-0.3906
 0.4911
 0.8081
-0.5449
 0.9872
 0.2004
 0.0969
-0.9712
 0.0873
 0.4562
-0.4857
-0.6013
 0.1651
 0.3315
-0.7033
-0.7440
 0.6487
 0.9802
-0.5977
 0.3245
 0.7563
 0.5596
 0.2303
-0.3836
[torch.FloatTensor of size 40]

1: ReLU ()
2: Linear (40 -> 60)
3: Tanh ()

-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
   ...    ⋱    ...   
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
[torch.FloatTensor of size 40x30]


Process finished with exit code 0

这里要注意self.Linear1.weight的类型是网络的parameter。而self.Linear1.weight.data是FloatTensor。

以上这篇关于pytorch中全连接神经网络搭建两种模式详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进阶教程之异常处理
Aug 30 Python
如何使用七牛Python SDK写一个同步脚本及使用教程
Aug 23 Python
使用Python编写简单的画图板程序的示例教程
Dec 08 Python
Python 爬虫爬取指定博客的所有文章
Feb 17 Python
Python编码爬坑指南(必看)
Jun 10 Python
python 生成器协程运算实例
Sep 04 Python
Python日期时间对象转换为字符串的实例
Jun 22 Python
Python GUI库PyQt5样式QSS子控件介绍
Feb 25 Python
Django media static外部访问Django中的图片设置教程
Apr 07 Python
python里的单引号和双引号的有什么作用
Jun 17 Python
用Python实现一个打字速度测试工具来测试你的手速
May 28 Python
python状态机transitions库详解
Jun 02 Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
You might like
如何使用Strace调试工具
2013/06/03 PHP
php简单获取文件扩展名的方法
2015/03/24 PHP
浅析PHP echo 和 print 语句
2020/06/30 PHP
JavaScript 笔记二 Array和Date对象方法
2010/05/22 Javascript
在jQuery 1.5中使用deferred对象的代码(翻译)
2011/03/10 Javascript
jquery键盘事件使用介绍
2011/11/01 Javascript
JavaScript中圆括号()和方括号[]的特殊用法疑问解答
2013/08/06 Javascript
JQuery触发事件例如click
2013/09/11 Javascript
JQuery插件ajaxfileupload.js异步上传文件实例
2015/05/19 Javascript
js实现圆盘记速表
2015/08/03 Javascript
jQuery实现带延迟效果的滑动菜单代码
2015/09/02 Javascript
jQuery UI库中dialog对话框功能使用全解析
2016/04/23 Javascript
Angular2  NgModule 模块详解
2016/10/19 Javascript
基于casperjs和resemble.js实现一个像素对比服务详解
2018/01/10 Javascript
angular5 子组件监听父组件传入值的变化方法
2018/09/30 Javascript
小程序测试后台服务的方法(ngrok)
2019/03/08 Javascript
vue基础之v-bind属性、class和style用法分析
2019/03/11 Javascript
JS基础之逻辑结构与循环操作示例
2020/01/19 Javascript
基于Electron实现桌面应用开发代码实例
2020/07/07 Javascript
使用element-ui +Vue 解决 table 里包含表单验证的问题
2020/07/17 Javascript
编写Python脚本抓取网络小说来制作自己的阅读器
2015/08/20 Python
Python文件与文件夹常见基本操作总结
2016/09/19 Python
django站点管理详解
2017/12/12 Python
django模板加载静态文件的方法步骤
2019/03/01 Python
超简单使用Python换脸实例
2019/03/27 Python
基于python-opencv3的图像显示和保存操作
2019/06/27 Python
python 命令行传入参数实现解析
2019/08/30 Python
Flask和pyecharts实现动态数据可视化
2020/02/26 Python
简单介绍一下pyinstaller打包以及安全性的实现
2020/06/02 Python
Python实现区域填充的示例代码
2021/02/03 Python
俄罗斯运动鞋商店:Sneakerhead
2018/05/10 全球购物
十八届三中全会报告学习材料
2014/02/17 职场文书
优秀电子工程系毕业生求职信
2014/05/24 职场文书
责任书格式
2015/01/29 职场文书
法院个人总结
2015/03/03 职场文书
2015年教师节贺卡寄语
2015/03/24 职场文书