关于pytorch中全连接神经网络搭建两种模式详解


Posted in Python onJanuary 14, 2020

pytorch搭建神经网络是很简单明了的,这里介绍两种自己常用的搭建模式:

import torch
import torch.nn as nn

first:

class NN(nn.Module):
 def __init__(self):
  super(NN,self).__init__()
  self.model=nn.Sequential(
   nn.Linear(30,40),
   nn.ReLU(),
   nn.Linear(40,60),
   nn.Tanh(),
   nn.Linear(60,10),
   nn.Softmax()
  )
  self.model[0].weight.data.uniform_(-3e-3, 3e-3)
  self.model[0].bias.data.uniform(-1,1)
 def forward(self,states):
  return self.model(states)

这一种是将整个网络写在一个Sequential中,网络参数设置可以在网络搭建好后单独设置:self.model[0].weight.data.uniform_(-3e-3,3e-3),这是设置第一个linear的权重是(-3e-3,3e-3)之间的均匀分布,bias是-1至1之间的均匀分布。

second:

class NN1(nn.Module):
 def __init__(self):
  super(NN1,self).__init__()
  self.Linear1=nn.Linear(30,40)
  self.Linear1.weight.data.fill_(-0.1)
  #self.Linear1.weight.data.uniform_(-3e-3,3e-3)
  self.Linear1.bias.data.fill_(-0.1)
  self.layer1=nn.Sequential(self.Linear1,nn.ReLU())

  self.Linear2=nn.Linear(40,60)
  self.layer2=nn.Sequential(self.Linear2,nn.Tanh())

  self.Linear3=nn.Linear(60,10)
  self.layer3=nn.Sequential(self.Linear3,nn.Softmax())


 def forward(self,states):
  return self.model(states)

网络参数的设置可以在定义完线性层之后直接设置如这里对于第一个线性层是这样设置:self.Linear1.weight.data.fill_(-0.1),self.Linear1.bias.data.fill_(-0.1)。

你可以看一下这样定义完的参数的效果:

Net=NN()
print("0:",Net.model[0])
print("weight:",type(Net.model[0].weight))
print("weight:",type(Net.model[0].weight.data))
print("bias",Net.model[0].bias.data)
print('1:',Net.model[1])
#print("weight:",Net.model[1].weight.data)
print('2:',Net.model[2])
print('3:',Net.model[3])
#print(Net.model[-1])

Net1=NN1()
print(Net1.Linear1.weight.data)

输出:

0: Linear (30 -> 40)
weight: <class 'torch.nn.parameter.Parameter'>
weight: <class 'torch.FloatTensor'>
bias 
-0.6287
-0.6573
-0.0452
 0.9594
-0.7477
 0.1363
-0.1594
-0.1586
 0.0360
 0.7375
 0.2501
-0.1371
 0.8359
-0.9684
-0.3886
 0.7200
-0.3906
 0.4911
 0.8081
-0.5449
 0.9872
 0.2004
 0.0969
-0.9712
 0.0873
 0.4562
-0.4857
-0.6013
 0.1651
 0.3315
-0.7033
-0.7440
 0.6487
 0.9802
-0.5977
 0.3245
 0.7563
 0.5596
 0.2303
-0.3836
[torch.FloatTensor of size 40]

1: ReLU ()
2: Linear (40 -> 60)
3: Tanh ()

-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
   ...    ⋱    ...   
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
[torch.FloatTensor of size 40x30]


Process finished with exit code 0

这里要注意self.Linear1.weight的类型是网络的parameter。而self.Linear1.weight.data是FloatTensor。

以上这篇关于pytorch中全连接神经网络搭建两种模式详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python构造函数及解构函数介绍
Feb 26 Python
详解Python编程中基本的数学计算使用
Feb 04 Python
浅谈Python爬取网页的编码处理
Nov 04 Python
Python通过Django实现用户注册和邮箱验证功能代码
Dec 11 Python
python实现外卖信息管理系统
Jan 11 Python
pandas数据框,统计某列数据对应的个数方法
Apr 11 Python
PyQt5每天必学之滑块控件QSlider
Apr 20 Python
python2.7实现邮件发送功能
Dec 12 Python
Python使用pandas对数据进行差分运算的方法
Dec 22 Python
Python 多个图同时在不同窗口显示的实现方法
Jul 07 Python
Python爬虫实现自动登录、签到功能的代码
Aug 20 Python
Python数据分析之绘图和可视化详解
Jun 02 Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
You might like
PHP strip_tags保留多个HTML标签的方法
2016/05/22 PHP
PHP实现针对日期,月数,天数,周数,小时,分,秒等的加减运算示例【基于strtotime】
2017/04/19 PHP
Yii框架的布局文件实例分析
2019/09/04 PHP
jQuery1.6 正式版发布并提供下载
2011/05/05 Javascript
JS调试必备的5个debug技巧
2014/03/07 Javascript
JS实现获取键盘按下的按键并显示在页面上的方法
2015/11/04 Javascript
js省市联动效果完整实例代码
2015/12/09 Javascript
Angular学习笔记之angular的$filter服务浅析
2016/11/12 Javascript
Vue.js手风琴菜单组件开发实例
2017/05/16 Javascript
详谈commonjs模块与es6模块的区别
2017/10/18 Javascript
js实现复制功能(多种方法集合)
2018/01/06 Javascript
setTimeout与setInterval的区别浅析
2019/03/23 Javascript
详解服务端预渲染之Nuxt(介绍篇)
2019/04/07 Javascript
vue点击页面空白处实现保存功能
2019/11/06 Javascript
[54:15]DOTA2-DPC中国联赛 正赛 DLG vs Dragon BO3 第二场2月1日
2021/03/11 DOTA
python读写csv文件并增加行列的实例代码
2019/08/01 Python
Python中six模块基础用法
2019/12/08 Python
Python常用模块sys,os,time,random功能与用法实例分析
2020/01/07 Python
tensorflow 实现数据类型转换
2020/02/17 Python
Python开发之身份证验证库id_validator验证身份证号合法性及根据身份证号返回住址年龄等信息
2020/03/20 Python
在pycharm中使用pipenv创建虚拟环境和安装django的详细教程
2020/11/30 Python
Clearly新西兰:购买眼镜、太阳镜和隐形眼镜
2018/04/26 全球购物
浙大毕业生自荐信
2014/01/26 职场文书
宝宝满月酒主持词和仪式流程
2014/03/27 职场文书
《大禹治水》教学反思
2014/04/27 职场文书
高一新生军训方案
2014/05/12 职场文书
房屋买卖委托书格式范本格式
2014/10/13 职场文书
中学生检讨书范文
2014/11/03 职场文书
2014年个人工作总结报告
2014/11/27 职场文书
2014保险公司内勤工作总结
2014/12/16 职场文书
民主评议党员个人总结
2015/02/13 职场文书
复试通知单模板
2015/04/24 职场文书
毕业生爱心捐书倡议书
2015/04/27 职场文书
王亚平太空授课观后感
2015/06/12 职场文书
2015年安全生产月工作总结
2015/07/27 职场文书
SQL实战演练之网上商城数据库商品类别数据操作
2021/10/24 MySQL