用pytorch的nn.Module构造简单全链接层实例


Posted in Python onJanuary 14, 2020

python版本3.7,用的是虚拟环境安装的pytorch,这样随便折腾,不怕影响其他的python框架

1、先定义一个类Linear,继承nn.Module

import torch as t
from torch import nn
from torch.autograd import Variable as V
 
class Linear(nn.Module):

  '''因为Variable自动求导,所以不需要实现backward()'''
  def __init__(self, in_features, out_features):
    super().__init__()
    self.w = nn.Parameter( t.randn( in_features, out_features ) ) #权重w 注意Parameter是一个特殊的Variable
    self.b = nn.Parameter( t.randn( out_features ) )   #偏值b
  
  def forward( self, x ): #参数 x 是一个Variable对象
    x = x.mm( self.w )
    return x + self.b.expand_as( x ) #让b的形状符合 输出的x的形状

2、验证一下

layer = Linear( 4,3 )
input = V ( t.randn( 2 ,4 ) )#包装一个Variable作为输入
out = layer( input )
out

#成功运行,结果如下:

tensor([[-2.1934, 2.5590, 4.0233], [ 1.1098, -3.8182, 0.1848]], grad_fn=<AddBackward0>)

下面利用Linear构造一个多层网络

class Perceptron( nn.Module ):
  def __init__( self,in_features, hidden_features, out_features ):
    super().__init__()
    self.layer1 = Linear( in_features , hidden_features )
    self.layer2 = Linear( hidden_features, out_features )
  def forward ( self ,x ):
    x = self.layer1( x )
    x = t.sigmoid( x ) #用sigmoid()激活函数
    return self.layer2( x )

测试一下

perceptron = Perceptron ( 5,3 ,1 )
 
for name,param in perceptron.named_parameters(): 
  print( name, param.size() )

输出如预期:

layer1.w torch.Size([5, 3])
layer1.b torch.Size([3])
layer2.w torch.Size([3, 1])
layer2.b torch.Size([1])

以上这篇用pytorch的nn.Module构造简单全链接层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python web框架学习笔记
May 03 Python
python搭建虚拟环境的步骤详解
Sep 27 Python
Python实现简单遗传算法(SGA)
Jan 29 Python
Python中将变量按行写入txt文本中的方法
Apr 03 Python
对numpy 数组和矩阵的乘法的进一步理解
Apr 04 Python
Python中关键字global和nonlocal的区别详解
Sep 03 Python
Python使用百度翻译开发平台实现英文翻译为中文功能示例
Aug 08 Python
在Python中使用turtle绘制多个同心圆示例
Nov 23 Python
pytorch 限制GPU使用效率详解(计算效率)
Jun 27 Python
python3环境搭建过程(利用Anaconda+pycharm)完整版
Aug 19 Python
解决python便携版无法直接运行py文件的问题
Sep 01 Python
python 进制转换 int、bin、oct、hex的原理
Jan 13 Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
基于python监控程序是否关闭
Jan 14 #Python
关于pytorch中全连接神经网络搭建两种模式详解
Jan 14 #Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
You might like
php GeoIP的使用教程
2011/03/09 PHP
PHP根据IP判断地区名信息的示例代码
2014/03/03 PHP
lnmp安装多版本PHP共存的方法详解
2018/08/02 PHP
jQuery AJAX回调函数this指向问题
2010/02/08 Javascript
SWFObject 2.1以上版本语法介绍
2010/07/10 Javascript
jquery中ajax学习笔记一
2011/10/16 Javascript
jquery改变disabled的boolean状态的三种方法
2013/12/13 Javascript
js读取json的两种常用方法示例介绍
2014/10/19 Javascript
移动端触屏幻灯片图片切换插件idangerous swiper.js
2017/04/10 Javascript
BootStrap Table 后台数据绑定、特殊列处理、排序功能
2017/05/27 Javascript
js 奇葩技巧之隐藏代码
2017/08/11 Javascript
浅谈Node.js爬虫之网页请求模块
2018/01/11 Javascript
深入浅析Vue全局组件与局部组件的区别
2018/06/15 Javascript
详解从Vue-router到html5的pushState
2018/07/21 Javascript
JavaScript switch语句使用方法简介
2019/12/30 Javascript
[04:04]DOTA2亚洲邀请赛比赛场馆&酒店全攻略
2017/03/23 DOTA
[10:14]2018DOTA2国际邀请赛寻真——paiN Gaming不仅为自己而战
2018/08/14 DOTA
python通过字典dict判断指定键值是否存在的方法
2015/03/21 Python
在Python的web框架中编写创建日志的程序的教程
2015/04/30 Python
Python编程之多态用法实例详解
2015/05/19 Python
python学生管理系统学习笔记
2019/03/19 Python
使用Python3内置文档高效学习以及官方中文文档
2019/05/19 Python
python3实现斐波那契数列(4种方法)
2019/07/15 Python
关于Numpy中的行向量和列向量详解
2019/11/30 Python
Python无损压缩图片的示例代码
2020/08/06 Python
Otticanet意大利:最顶尖的世界名牌眼镜, 能得到打折季的价格
2019/03/10 全球购物
Set里的元素是不能重复的,那么用什么方法来区分重复与否呢?
2016/08/18 面试题
普师专业个人自荐信范文
2013/11/26 职场文书
参观监狱心得体会
2014/01/02 职场文书
办理护照介绍信
2014/01/16 职场文书
2014年校长工作总结
2014/12/11 职场文书
浪漫的婚礼主持词
2015/06/30 职场文书
小学远程教育工作总结
2015/08/13 职场文书
详解python字符串驻留技术
2021/05/21 Python
pytorch分类模型绘制混淆矩阵以及可视化详解
2022/04/07 Python
DQL数据查询语句使用示例
2022/12/24 MySQL