Python使用gluon/mxnet模块实现的mnist手写数字识别功能完整示例


Posted in Python onDecember 18, 2019

本文实例讲述了Python使用gluon/mxnet模块实现的mnist手写数字识别功能。分享给大家供大家参考,具体如下:

import gluonbook as gb
from mxnet import autograd,nd,init,gluon
from mxnet.gluon import loss as gloss,data as gdata,nn,utils as gutils
import mxnet as mx
net = nn.Sequential()
with net.name_scope():
  net.add(
    nn.Conv2D(channels=32, kernel_size=5, activation='relu'),
    nn.MaxPool2D(pool_size=2, strides=2),
    nn.Flatten(),
    nn.Dense(128, activation='sigmoid'),
    nn.Dense(10, activation='sigmoid')
  )
lr = 0.5
batch_size=256
ctx = mx.gpu()
net.initialize(init=init.Xavier(), ctx=ctx)
train_data, test_data = gb.load_data_fashion_mnist(batch_size)
trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate' : lr})
loss = gloss.SoftmaxCrossEntropyLoss()
num_epochs = 30
def train(train_data, test_data, net, loss, trainer,num_epochs):
  for epoch in range(num_epochs):
    total_loss = 0
    for x,y in train_data:
      with autograd.record():
        x = x.as_in_context(ctx)
        y = y.as_in_context(ctx)
        y_hat=net(x)
        l = loss(y_hat,y)
      l.backward()
      total_loss += l
      trainer.step(batch_size)
    mx.nd.waitall()
    print("Epoch [{}]: Loss {}".format(epoch, total_loss.sum().asnumpy()[0]/(batch_size*len(train_data))))
if __name__ == '__main__':
  try:
    ctx = mx.gpu()
    _ = nd.zeros((1,), ctx=ctx)
  except:
    ctx = mx.cpu()
  ctx
  gb.train(train_data,test_data,net,loss,trainer,ctx,num_epochs)

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python自动格式化json文件的方法
Mar 11 Python
Python自定义scrapy中间模块避免重复采集的方法
Apr 07 Python
在Python中使用PIL模块处理图像的教程
Apr 29 Python
在Django的视图中使用form对象的方法
Jul 18 Python
python中import reload __import__的区别详解
Oct 16 Python
python的Crypto模块实现AES加密实例代码
Jan 22 Python
python使用knn实现特征向量分类
Dec 26 Python
Python模块汇总(常用第三方库)
Oct 07 Python
python3.x 生成3维随机数组实例
Nov 28 Python
python基于opencv检测程序运行效率
Dec 28 Python
tensorflow之获取tensor的shape作为max_pool的ksize实例
Jan 04 Python
全网非常详细的pytest配置文件
Jul 15 Python
简单了解Python读取大文件代码实例
Dec 18 #Python
python 比较2张图片的相似度的方法示例
Dec 18 #Python
使用Python的Turtle库绘制森林的实例
Dec 18 #Python
python3 requests库实现多图片爬取教程
Dec 18 #Python
在notepad++中实现直接运行python代码
Dec 18 #Python
简单了解python装饰器原理及使用方法
Dec 18 #Python
修改Pandas的行或列的名字(重命名)
Dec 18 #Python
You might like
PHP 七大优势分析
2009/06/23 PHP
PHP基础陷阱题(变量赋值)
2012/09/12 PHP
浅析php面向对象public private protected 访问修饰符
2013/06/30 PHP
php通过两层过滤获取留言内容的方法
2016/07/11 PHP
JavaScript定义类或函数的几种方式小结
2011/01/09 Javascript
js 编程笔记 无名函数
2011/06/28 Javascript
Javascript 面向对象(二)封装代码
2012/05/23 Javascript
JS 页面计时器示例代码
2013/10/28 Javascript
js创建元素(节点)示例
2014/01/02 Javascript
jquery实现的table排序功能示例
2017/03/10 Javascript
对Angular中单向数据流的深入理解
2018/03/31 Javascript
详解js删除数组中的指定元素
2018/10/31 Javascript
Bootstrap告警框(alert)实现弹出效果和短暂显示后上浮消失的示例代码
2020/08/27 Javascript
[04:39]显微镜下的DOTA2第十三期—Pis卡尔个人秀
2014/04/04 DOTA
python向已存在的excel中新增表,不覆盖原数据的实例
2018/05/02 Python
Django将默认的SQLite更换为MySQL的实现
2019/11/18 Python
python字符串的拼接方法总结
2019/11/18 Python
python opencv根据颜色进行目标检测的方法示例
2020/01/15 Python
tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用
2020/01/20 Python
python计算波峰波谷值的方法(极值点)
2020/02/18 Python
Python常用编译器原理及特点解析
2020/03/23 Python
python 实现全球IP归属地查询工具
2020/12/18 Python
荷兰多品牌网上鞋店:Stoute Schoenen
2017/08/24 全球购物
Araks官网:纽约内衣品牌
2020/10/15 全球购物
一些PHP的面试题
2015/05/06 面试题
如何将字串String转换成整数int
2015/02/21 面试题
软件部经理岗位职责范本
2014/02/25 职场文书
园艺师求职信
2014/03/10 职场文书
幼儿园毕业典礼主持词
2014/03/21 职场文书
纪念一二九运动演讲稿
2014/09/16 职场文书
2014年财政所工作总结
2014/11/22 职场文书
公司前台接待岗位职责
2015/04/03 职场文书
离职信范文
2015/06/23 职场文书
行政处罚决定书
2015/06/24 职场文书
优秀员工演讲稿
2019/06/21 职场文书
Python网络编程之ZeroMQ知识总结
2021/04/25 Python