用pytorch的nn.Module构造简单全链接层实例


Posted in Python onJanuary 14, 2020

python版本3.7,用的是虚拟环境安装的pytorch,这样随便折腾,不怕影响其他的python框架

1、先定义一个类Linear,继承nn.Module

import torch as t
from torch import nn
from torch.autograd import Variable as V
 
class Linear(nn.Module):

  '''因为Variable自动求导,所以不需要实现backward()'''
  def __init__(self, in_features, out_features):
    super().__init__()
    self.w = nn.Parameter( t.randn( in_features, out_features ) ) #权重w 注意Parameter是一个特殊的Variable
    self.b = nn.Parameter( t.randn( out_features ) )   #偏值b
  
  def forward( self, x ): #参数 x 是一个Variable对象
    x = x.mm( self.w )
    return x + self.b.expand_as( x ) #让b的形状符合 输出的x的形状

2、验证一下

layer = Linear( 4,3 )
input = V ( t.randn( 2 ,4 ) )#包装一个Variable作为输入
out = layer( input )
out

#成功运行,结果如下:

tensor([[-2.1934, 2.5590, 4.0233], [ 1.1098, -3.8182, 0.1848]], grad_fn=<AddBackward0>)

下面利用Linear构造一个多层网络

class Perceptron( nn.Module ):
  def __init__( self,in_features, hidden_features, out_features ):
    super().__init__()
    self.layer1 = Linear( in_features , hidden_features )
    self.layer2 = Linear( hidden_features, out_features )
  def forward ( self ,x ):
    x = self.layer1( x )
    x = t.sigmoid( x ) #用sigmoid()激活函数
    return self.layer2( x )

测试一下

perceptron = Perceptron ( 5,3 ,1 )
 
for name,param in perceptron.named_parameters(): 
  print( name, param.size() )

输出如预期:

layer1.w torch.Size([5, 3])
layer1.b torch.Size([3])
layer2.w torch.Size([3, 1])
layer2.b torch.Size([1])

以上这篇用pytorch的nn.Module构造简单全链接层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python环境下搭建属于自己的pip源的教程
May 05 Python
关于numpy中np.nonzero()函数用法的详解
Feb 07 Python
简单了解OpenCV是个什么东西
Nov 10 Python
Django之Mode的外键自关联和引用未定义的Model方法
Dec 15 Python
浅谈python3.x pool.map()方法的实质
Jan 16 Python
python中update的基本使用方法详解
Jul 17 Python
python openvc 裁剪、剪切图片 提取图片的行和列
Sep 19 Python
Python 类,property属性(简化属性的操作),@property,property()用法示例
Oct 12 Python
Pytest参数化parametrize使用代码实例
Feb 22 Python
Python爬虫入门有哪些基础知识点
Jun 02 Python
pytorch中的numel函数用法说明
May 13 Python
Python四款GUI图形界面库介绍
Jun 05 Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
基于python监控程序是否关闭
Jan 14 #Python
关于pytorch中全连接神经网络搭建两种模式详解
Jan 14 #Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
You might like
phpmyadmin中禁止外网使用的方法
2014/11/04 PHP
PHPExcel读取EXCEL中的图片并保存到本地的方法
2015/02/14 PHP
php实现头像上传预览功能
2017/04/27 PHP
PHP基于自定义函数生成笛卡尔积的方法示例
2017/09/30 PHP
php+redis消息队列实现抢购功能
2018/02/08 PHP
phpstorm 配置xdebug的示例代码
2019/03/31 PHP
php数组和链表的区别总结
2019/09/20 PHP
编写跨浏览器的javascript代码必备[js多浏览器兼容写法]
2008/10/29 Javascript
兼容ie、firefox的图片自动缩放的css跟js代码分享
2013/08/12 Javascript
js判断IE浏览器版本过低示例代码
2013/11/22 Javascript
javascript操作excel生成报表全攻略
2014/05/04 Javascript
Javascript 读取操作Sql中的Xml字段
2014/10/09 Javascript
IE中document.createElement的iframe无法设置属性name的解决方法
2015/09/14 Javascript
JSON对象转化为字符串详解
2017/08/11 Javascript
JS获取数组中出现次数最多及第二多元素的方法
2017/10/27 Javascript
javaScript实现鼠标在文字上悬浮时弹出悬浮层效果
2020/04/12 Javascript
如何使用VuePress搭建一个类型element ui文档
2019/02/14 Javascript
nodejs使用async模块同步执行的方法
2019/03/02 NodeJs
深入了解js原型模式
2019/05/30 Javascript
JavaScript代码实现微博批量取消关注功能
2021/02/05 Javascript
Python捕捉和模拟鼠标事件的方法
2015/06/03 Python
Django中使用locals()函数的技巧
2015/07/16 Python
Python选课系统开发程序
2016/09/02 Python
python爬取NUS-WIDE数据库图片
2016/10/05 Python
Python semaphore evevt生产者消费者模型原理解析
2020/03/18 Python
css3高级选择器使用方法
2013/12/02 HTML / CSS
购买限量版收藏品、珠宝和礼品:Bradford Exchange
2016/09/23 全球购物
Boutique 1美国:阿联酋奢侈时尚零售商
2017/10/16 全球购物
韩国美国时尚服装和美容在线全球市场:KOODING
2018/11/07 全球购物
英国珠宝和手表专家:Pleasance & Harper
2020/10/21 全球购物
人事代理委托书
2014/09/27 职场文书
出生证明范本
2015/06/15 职场文书
2015年汽车销售员工作总结
2015/07/24 职场文书
golang 如何用反射reflect操作结构体
2021/04/28 Golang
html+css实现分层金字塔的实例
2021/06/02 HTML / CSS
win10电脑双屏显示一个黑屏怎么办?win10电脑双屏显示一个黑屏解决方法
2022/07/15 数码科技