关于pytorch中全连接神经网络搭建两种模式详解


Posted in Python onJanuary 14, 2020

pytorch搭建神经网络是很简单明了的,这里介绍两种自己常用的搭建模式:

import torch
import torch.nn as nn

first:

class NN(nn.Module):
 def __init__(self):
  super(NN,self).__init__()
  self.model=nn.Sequential(
   nn.Linear(30,40),
   nn.ReLU(),
   nn.Linear(40,60),
   nn.Tanh(),
   nn.Linear(60,10),
   nn.Softmax()
  )
  self.model[0].weight.data.uniform_(-3e-3, 3e-3)
  self.model[0].bias.data.uniform(-1,1)
 def forward(self,states):
  return self.model(states)

这一种是将整个网络写在一个Sequential中,网络参数设置可以在网络搭建好后单独设置:self.model[0].weight.data.uniform_(-3e-3,3e-3),这是设置第一个linear的权重是(-3e-3,3e-3)之间的均匀分布,bias是-1至1之间的均匀分布。

second:

class NN1(nn.Module):
 def __init__(self):
  super(NN1,self).__init__()
  self.Linear1=nn.Linear(30,40)
  self.Linear1.weight.data.fill_(-0.1)
  #self.Linear1.weight.data.uniform_(-3e-3,3e-3)
  self.Linear1.bias.data.fill_(-0.1)
  self.layer1=nn.Sequential(self.Linear1,nn.ReLU())

  self.Linear2=nn.Linear(40,60)
  self.layer2=nn.Sequential(self.Linear2,nn.Tanh())

  self.Linear3=nn.Linear(60,10)
  self.layer3=nn.Sequential(self.Linear3,nn.Softmax())


 def forward(self,states):
  return self.model(states)

网络参数的设置可以在定义完线性层之后直接设置如这里对于第一个线性层是这样设置:self.Linear1.weight.data.fill_(-0.1),self.Linear1.bias.data.fill_(-0.1)。

你可以看一下这样定义完的参数的效果:

Net=NN()
print("0:",Net.model[0])
print("weight:",type(Net.model[0].weight))
print("weight:",type(Net.model[0].weight.data))
print("bias",Net.model[0].bias.data)
print('1:',Net.model[1])
#print("weight:",Net.model[1].weight.data)
print('2:',Net.model[2])
print('3:',Net.model[3])
#print(Net.model[-1])

Net1=NN1()
print(Net1.Linear1.weight.data)

输出:

0: Linear (30 -> 40)
weight: <class 'torch.nn.parameter.Parameter'>
weight: <class 'torch.FloatTensor'>
bias 
-0.6287
-0.6573
-0.0452
 0.9594
-0.7477
 0.1363
-0.1594
-0.1586
 0.0360
 0.7375
 0.2501
-0.1371
 0.8359
-0.9684
-0.3886
 0.7200
-0.3906
 0.4911
 0.8081
-0.5449
 0.9872
 0.2004
 0.0969
-0.9712
 0.0873
 0.4562
-0.4857
-0.6013
 0.1651
 0.3315
-0.7033
-0.7440
 0.6487
 0.9802
-0.5977
 0.3245
 0.7563
 0.5596
 0.2303
-0.3836
[torch.FloatTensor of size 40]

1: ReLU ()
2: Linear (40 -> 60)
3: Tanh ()

-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
   ...    ⋱    ...   
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
[torch.FloatTensor of size 40x30]


Process finished with exit code 0

这里要注意self.Linear1.weight的类型是网络的parameter。而self.Linear1.weight.data是FloatTensor。

以上这篇关于pytorch中全连接神经网络搭建两种模式详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现系统状态监测和故障转移实例方法
Nov 18 Python
基于python 爬虫爬到含空格的url的处理方法
May 11 Python
Python应用库大全总结
May 30 Python
解决Python2.7中IDLE启动没有反应的问题
Nov 30 Python
Python模块的加载讲解
Jan 15 Python
python全栈要学什么 python全栈学习路线
Jun 28 Python
python getpass模块用法及实例详解
Oct 07 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 Python
浅谈优化Django ORM中的性能问题
Jul 09 Python
python修改微信和支付宝步数的示例代码
Oct 12 Python
Python&Matlab实现樱花的绘制
Apr 07 Python
使用python将HTML转换为PDF pdfkit包(wkhtmltopdf) 的使用方法
Apr 21 Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
You might like
SONY SRF-22W(33W)的电路分析和维修案例
2021/03/02 无线电
PHP正则表达式 /i, /is, /s, /isU等介绍
2014/10/23 PHP
必须收藏的23个php实用代码片段
2016/02/02 PHP
PHP magento后台无法登录问题解决方法
2016/11/24 PHP
ASP.NET jQuery 实例5 (显示CheckBoxList成员选中的内容)
2012/01/13 Javascript
json数据的列循环示例
2013/09/06 Javascript
jQuery实现列表的全选功能
2015/03/18 Javascript
如何用jQuery实现ASP.NET GridView折叠伸展效果
2015/09/26 Javascript
jQuery Ajax页面局部加载方法汇总
2016/06/02 Javascript
微信小程序开发(一) 微信登录流程详解
2017/01/11 Javascript
使用openSpeDiv方法实现Ecshop登录弹窗框效果
2017/03/13 Javascript
Vue中父组件向子组件通信的方法
2017/07/11 Javascript
vue使用keep-alive实现数据缓存不刷新
2017/10/21 Javascript
微信小程序图片选择区域裁剪实现方法
2017/12/02 Javascript
JavaScript数组基于交换的排序示例【冒泡排序】
2018/07/21 Javascript
使用jQuery给Table动态增加行、清空table的方法
2018/09/05 jQuery
JS实现百度搜索框
2021/02/25 Javascript
python的staticmethod与classmethod实现实例代码
2018/02/11 Python
Python实现针对给定字符串寻找最长非重复子串的方法
2018/04/21 Python
Python实现微信消息防撤回功能的实例代码
2019/04/29 Python
Python实现使用request模块下载图片demo示例
2019/05/24 Python
Python单元测试与测试用例简析
2019/11/09 Python
Pytorch 使用opnecv读入图像由HWC转为BCHW格式方式
2020/06/02 Python
Python使用Selenium模拟浏览器自动操作功能
2020/09/08 Python
python基于selenium爬取斗鱼弹幕
2021/02/20 Python
英国景点门票网站:attractiontix
2019/08/27 全球购物
PyQt 如何创建自定义QWidget
2021/03/24 Python
2013英文求职信模板范文
2013/11/15 职场文书
大学生村官事迹材料
2014/01/21 职场文书
洗手间标语
2014/06/23 职场文书
期末考试复习计划
2015/01/19 职场文书
护士自荐信范文
2015/03/25 职场文书
山楂树之恋观后感
2015/06/11 职场文书
红灯733-1型14管5波段半导体收音机
2021/04/22 无线电
Spring Data JPA使用JPQL与原生SQL进行查询的操作
2021/06/15 Java/Android
Ajax异步刷新功能及简单案例
2021/11/20 Javascript