Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django1.3添加app提示模块不存在的解决方法
Aug 26 Python
web.py在模板中输出美元符号的方法
Aug 26 Python
Python的Flask框架中Flask-Admin库的简单入门指引
Apr 07 Python
Python进程间通信之共享内存详解
Oct 30 Python
python验证码识别实例代码
Feb 03 Python
python ftp 按目录结构上传下载的实现代码
Sep 12 Python
Python对HTML转义字符进行反转义的实现方法
Apr 28 Python
python实现批量视频分帧、保存视频帧
May 31 Python
PyQt QCombobox设置行高的方法
Jun 20 Python
Pytorch GPU显存充足却显示out of memory的解决方式
Jan 13 Python
python 解决print数组/矩阵无法完整输出的问题
Feb 19 Python
Python递归求出列表(包括列表中的子列表)的最大值实例
Feb 27 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
PHP基于yii框架实现生成ICO图标
2015/11/13 PHP
php高清晰度无损图片压缩功能的实现代码
2018/12/09 PHP
jquery插件制作 手风琴Panel效果实现
2012/08/17 Javascript
jquery的选择器的使用技巧之如何选择input框
2013/09/22 Javascript
读取input:file的路径并显示本地图片的方法
2013/09/23 Javascript
javascript在网页中实现读取剪贴板粘贴截图功能
2014/06/07 Javascript
node.js中的fs.writeSync方法使用说明
2014/12/15 Javascript
javascript弹出拖动窗口
2015/08/11 Javascript
jQuery ui实现动感的圆角渐变网站导航菜单效果代码
2015/08/26 Javascript
Prototype框架详解
2015/11/25 Javascript
基于javascript实现的快速排序
2016/12/02 Javascript
jQuery插件echarts实现的单折线图效果示例【附demo源码下载】
2017/03/04 Javascript
vue动态路由配置及路由传参的方式
2018/05/23 Javascript
webpack分离css单独打包的方法
2018/06/12 Javascript
vue使用pdfjs显示PDF可复制的实现方法
2018/12/14 Javascript
[jQuery] 事件和动画详解
2019/03/05 jQuery
vue路由守卫,限制前端页面访问权限的例子
2019/11/11 Javascript
使用Python神器对付12306变态验证码
2016/01/05 Python
Python数据类型详解(四)字典:dict
2016/05/12 Python
python字典DICT类型合并详解
2017/08/17 Python
Python使用回溯法子集树模板解决迷宫问题示例
2017/09/01 Python
python中协程实现TCP连接的实例分析
2018/10/14 Python
详解Windows下PyCharm安装Numpy包及无法安装问题解决方案
2020/06/18 Python
CSS3教程(4):网页边框和网页文字阴影
2009/04/02 HTML / CSS
浅谈CSS3鼠标移入图片动态提示效果(transform)
2017/11/06 HTML / CSS
HTML5+CSS3模仿优酷视频截图功能示例
2017/01/05 HTML / CSS
印尼在线精品店:Berrybenka.com
2016/10/22 全球购物
瑞典香水、须后水和美容产品购物网站:Parfym-Klick.se
2019/12/29 全球购物
毕业生精彩的自我评价分享
2013/10/06 职场文书
艺术应用与设计专业个人的自我评价
2013/11/19 职场文书
母亲七十大寿答谢词
2014/01/18 职场文书
收银员的岗位职责范本
2014/02/04 职场文书
大一新生期末自我评价
2014/09/12 职场文书
2014年内勤工作总结
2014/11/24 职场文书
人口与计划生育责任书
2015/05/09 职场文书
2015年节能降耗工作总结
2015/05/22 职场文书