关于pytorch中全连接神经网络搭建两种模式详解


Posted in Python onJanuary 14, 2020

pytorch搭建神经网络是很简单明了的,这里介绍两种自己常用的搭建模式:

import torch
import torch.nn as nn

first:

class NN(nn.Module):
 def __init__(self):
  super(NN,self).__init__()
  self.model=nn.Sequential(
   nn.Linear(30,40),
   nn.ReLU(),
   nn.Linear(40,60),
   nn.Tanh(),
   nn.Linear(60,10),
   nn.Softmax()
  )
  self.model[0].weight.data.uniform_(-3e-3, 3e-3)
  self.model[0].bias.data.uniform(-1,1)
 def forward(self,states):
  return self.model(states)

这一种是将整个网络写在一个Sequential中,网络参数设置可以在网络搭建好后单独设置:self.model[0].weight.data.uniform_(-3e-3,3e-3),这是设置第一个linear的权重是(-3e-3,3e-3)之间的均匀分布,bias是-1至1之间的均匀分布。

second:

class NN1(nn.Module):
 def __init__(self):
  super(NN1,self).__init__()
  self.Linear1=nn.Linear(30,40)
  self.Linear1.weight.data.fill_(-0.1)
  #self.Linear1.weight.data.uniform_(-3e-3,3e-3)
  self.Linear1.bias.data.fill_(-0.1)
  self.layer1=nn.Sequential(self.Linear1,nn.ReLU())

  self.Linear2=nn.Linear(40,60)
  self.layer2=nn.Sequential(self.Linear2,nn.Tanh())

  self.Linear3=nn.Linear(60,10)
  self.layer3=nn.Sequential(self.Linear3,nn.Softmax())


 def forward(self,states):
  return self.model(states)

网络参数的设置可以在定义完线性层之后直接设置如这里对于第一个线性层是这样设置:self.Linear1.weight.data.fill_(-0.1),self.Linear1.bias.data.fill_(-0.1)。

你可以看一下这样定义完的参数的效果:

Net=NN()
print("0:",Net.model[0])
print("weight:",type(Net.model[0].weight))
print("weight:",type(Net.model[0].weight.data))
print("bias",Net.model[0].bias.data)
print('1:',Net.model[1])
#print("weight:",Net.model[1].weight.data)
print('2:',Net.model[2])
print('3:',Net.model[3])
#print(Net.model[-1])

Net1=NN1()
print(Net1.Linear1.weight.data)

输出:

0: Linear (30 -> 40)
weight: <class 'torch.nn.parameter.Parameter'>
weight: <class 'torch.FloatTensor'>
bias 
-0.6287
-0.6573
-0.0452
 0.9594
-0.7477
 0.1363
-0.1594
-0.1586
 0.0360
 0.7375
 0.2501
-0.1371
 0.8359
-0.9684
-0.3886
 0.7200
-0.3906
 0.4911
 0.8081
-0.5449
 0.9872
 0.2004
 0.0969
-0.9712
 0.0873
 0.4562
-0.4857
-0.6013
 0.1651
 0.3315
-0.7033
-0.7440
 0.6487
 0.9802
-0.5977
 0.3245
 0.7563
 0.5596
 0.2303
-0.3836
[torch.FloatTensor of size 40]

1: ReLU ()
2: Linear (40 -> 60)
3: Tanh ()

-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
   ...    ⋱    ...   
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
[torch.FloatTensor of size 40x30]


Process finished with exit code 0

这里要注意self.Linear1.weight的类型是网络的parameter。而self.Linear1.weight.data是FloatTensor。

以上这篇关于pytorch中全连接神经网络搭建两种模式详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
分享15个最受欢迎的Python开源框架
Jul 13 Python
python简单实现旋转图片的方法
May 30 Python
Python聚类算法之凝聚层次聚类实例分析
Nov 20 Python
Python实现复杂对象转JSON的方法示例
Jun 22 Python
Python实现的本地文件搜索功能示例【测试可用】
May 30 Python
设置python3为默认python的方法
Oct 31 Python
微信公众号token验证失败解决方案
Jul 22 Python
PyCharm如何导入python项目的方法
Feb 06 Python
使用python实现名片管理系统
Jun 18 Python
Python容器类型公共方法总结
Aug 19 Python
使用Python Tkinter实现剪刀石头布小游戏功能
Oct 23 Python
解决pycharm下载库时出现Failed to install package的问题
Sep 04 Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
You might like
php dirname(__FILE__) 获取当前文件的绝对路径
2011/06/28 PHP
PHP扩展开发教程(总结)
2015/11/04 PHP
PHP+Ajax 检测网络是否正常实例详解
2016/12/16 PHP
jQuery中获取Radio元素值的方法
2013/07/02 Javascript
javascript中取前n天日期的两种方法分享
2014/01/26 Javascript
解析javascript中鼠标滚轮事件
2015/05/26 Javascript
js跨域请求数据的3种常用的方法
2015/12/01 Javascript
jQuery移动web开发中的页面初始化与加载事件
2015/12/03 Javascript
浅析jQuery 3.0中的Data
2016/06/14 Javascript
WebSocket+node.js创建即时通信的Web聊天服务器
2016/08/08 Javascript
JavaScript里 ==与===区别详解
2016/08/16 Javascript
Bootstrap轮播插件使用代码
2016/10/11 Javascript
Node.js中路径处理模块path详解
2016/11/14 Javascript
微信小程序实现皮肤功能(夜间模式)
2017/06/18 Javascript
微信小程序 scroll-view实现锚点滑动的示例
2017/12/06 Javascript
vue脚手架及vue-router基本使用
2018/04/09 Javascript
vue的全局变量和全局拦截请求器的示例代码
2018/09/13 Javascript
Vue数据绑定简析小结
2019/05/07 Javascript
JavaScript Image对象实现原理实例解析
2020/08/26 Javascript
微信小程序实现弹幕墙(祝福墙)
2020/11/18 Javascript
Python 开发Activex组件方法
2009/11/08 Python
python查询sqlite数据表的方法
2015/05/08 Python
Python基于csv模块实现读取与写入csv数据的方法
2018/01/18 Python
对python中的argv和argc使用详解
2018/12/15 Python
在VS2017中用C#调用python脚本的实现
2019/07/31 Python
python elasticsearch环境搭建详解
2019/09/02 Python
Yahoo的PHP面试题
2014/05/26 面试题
机械专业应届生求职信
2013/09/21 职场文书
创业计划书如何吸引他人眼球
2014/01/10 职场文书
护理职业生涯规划书
2014/01/24 职场文书
HR求职自荐信范文
2014/06/21 职场文书
汽车4S店销售经理岗位职责
2015/04/02 职场文书
2015年乡镇人大工作总结
2015/04/22 职场文书
初二英语教学反思
2016/02/15 职场文书
springboot 多数据源配置不生效遇到的坑及解决
2021/11/17 Java/Android
java实现自定义时钟并实现走时功能
2022/06/21 Java/Android