用pytorch的nn.Module构造简单全链接层实例


Posted in Python onJanuary 14, 2020

python版本3.7,用的是虚拟环境安装的pytorch,这样随便折腾,不怕影响其他的python框架

1、先定义一个类Linear,继承nn.Module

import torch as t
from torch import nn
from torch.autograd import Variable as V
 
class Linear(nn.Module):

  '''因为Variable自动求导,所以不需要实现backward()'''
  def __init__(self, in_features, out_features):
    super().__init__()
    self.w = nn.Parameter( t.randn( in_features, out_features ) ) #权重w 注意Parameter是一个特殊的Variable
    self.b = nn.Parameter( t.randn( out_features ) )   #偏值b
  
  def forward( self, x ): #参数 x 是一个Variable对象
    x = x.mm( self.w )
    return x + self.b.expand_as( x ) #让b的形状符合 输出的x的形状

2、验证一下

layer = Linear( 4,3 )
input = V ( t.randn( 2 ,4 ) )#包装一个Variable作为输入
out = layer( input )
out

#成功运行,结果如下:

tensor([[-2.1934, 2.5590, 4.0233], [ 1.1098, -3.8182, 0.1848]], grad_fn=<AddBackward0>)

下面利用Linear构造一个多层网络

class Perceptron( nn.Module ):
  def __init__( self,in_features, hidden_features, out_features ):
    super().__init__()
    self.layer1 = Linear( in_features , hidden_features )
    self.layer2 = Linear( hidden_features, out_features )
  def forward ( self ,x ):
    x = self.layer1( x )
    x = t.sigmoid( x ) #用sigmoid()激活函数
    return self.layer2( x )

测试一下

perceptron = Perceptron ( 5,3 ,1 )
 
for name,param in perceptron.named_parameters(): 
  print( name, param.size() )

输出如预期:

layer1.w torch.Size([5, 3])
layer1.b torch.Size([3])
layer2.w torch.Size([3, 1])
layer2.b torch.Size([1])

以上这篇用pytorch的nn.Module构造简单全链接层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用requests库制作Python爬虫
Mar 25 Python
Python Selenium Cookie 绕过验证码实现登录示例代码
Apr 10 Python
对python中矩阵相加函数sum()的使用详解
Jan 28 Python
教你如何编写、保存与运行Python程序的方法
Jul 12 Python
django Admin文档生成器使用详解
Jul 22 Python
如何在Django配置文件里配置session链接
Aug 06 Python
python并发编程多进程 互斥锁原理解析
Aug 20 Python
python matplotlib 画dataframe的时间序列图实例
Nov 20 Python
Python基本类型的连接组合和互相转换方式(13种)
Dec 16 Python
python实现音乐播放和下载小程序功能
Apr 26 Python
Python检测端口IP字符串是否合法
Jun 05 Python
Python必须了解的35个关键词
Jul 16 Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
基于python监控程序是否关闭
Jan 14 #Python
关于pytorch中全连接神经网络搭建两种模式详解
Jan 14 #Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
You might like
在线短消息收发的程序,不用数据库
2006/10/09 PHP
PHP性能优化准备篇图解PEAR安装
2011/12/05 PHP
PHP易混淆函数的区别及用法汇总
2014/11/22 PHP
Symfony页面的基本创建实例详解
2015/01/26 PHP
PHP 接入微信扫码支付总结(总结篇)
2016/11/03 PHP
laravel 中如何使用ajax和vue总结
2017/08/16 PHP
PHP面向对象之里氏替换原则简单示例
2018/04/08 PHP
php设计模式之原型模式分析【星际争霸游戏案例】
2020/03/23 PHP
一个用javascript写的select支持上下键、首字母筛选以及回车取值的功能
2009/09/09 Javascript
jQuery实现鼠标移到元素上动态提示消息框效果
2013/10/20 Javascript
JS 数字转换研究总结
2013/12/26 Javascript
简单实现AngularJS轮播图效果
2020/04/10 Javascript
用nodejs实现json和jsonp服务的方法
2017/08/25 NodeJs
浅谈KOA2 Restful方式路由初探
2019/03/14 Javascript
云服务器部署Node.js项目的方法步骤(小白系列)
2020/03/23 Javascript
[02:44]DOTA2英雄基础教程 钢背兽
2013/12/19 DOTA
python实现的jpg格式图片修复代码
2015/04/21 Python
Python实现的三层BP神经网络算法示例
2018/02/07 Python
让Django支持Sql Server作后端数据库的方法
2018/05/29 Python
PyCharm鼠标右键不显示Run unittest的解决方法
2018/11/30 Python
python操作日志的封装方法(两种方法)
2019/05/23 Python
解决python多行注释引发缩进错误的问题
2019/08/23 Python
python3 dict ndarray 存成json,并保留原数据精度的实例
2019/12/06 Python
Python第三方包之DingDingBot钉钉机器人
2020/04/09 Python
Python偏函数实现原理及应用
2020/11/20 Python
一文读懂python Scrapy爬虫框架
2021/02/24 Python
萌新HTML5 入门指南(二)
2020/11/09 HTML / CSS
中国最大的潮流商品购物网站:YOHO!BUY有货
2017/01/07 全球购物
Nike荷兰官方网站:Nike.com (NL)
2018/04/19 全球购物
波兰购物网站:MALL.PL
2019/05/01 全球购物
一套SQL笔试题
2016/08/14 面试题
酒店司机岗位职责
2013/12/14 职场文书
《花的勇气》教后反思
2014/02/12 职场文书
室内设计专业自荐信
2014/05/31 职场文书
2015年工程师工作总结
2015/04/30 职场文书
幼儿园端午节活动总结
2015/05/05 职场文书