Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python比较2个时间大小的实现方法
Apr 10 Python
Python selenium抓取微博内容的示例代码
May 17 Python
python删除文本中行数标签的方法
May 31 Python
python 实现UTC时间加减的方法
Dec 31 Python
Python3 itchat实现微信定时发送群消息的实例代码
Jul 12 Python
flask框架配置mysql数据库操作详解
Nov 29 Python
python 导入数据及作图的实现
Dec 03 Python
python实现将视频按帧读取到自定义目录
Dec 10 Python
Python属性和内建属性实例解析
Jan 14 Python
python实现双人五子棋(终端版)
Dec 30 Python
python链表类中获取元素实例方法
Feb 23 Python
深入探讨opencv图像矫正算法实战
May 21 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
PHP三元运算的2种写法代码实例
2014/05/12 PHP
Yii中使用PHPExcel导出Excel的方法
2014/12/26 PHP
深入学习微信网址链接解封的防封原理visit_type
2019/08/15 PHP
jquery中的sortable排序之后的保存状态的解决方法
2010/01/28 Javascript
javascript中的new使用
2010/03/20 Javascript
不用锚点也可以平滑滚动到页面的指定位置实现代码
2013/05/08 Javascript
jquery中页面Ajax方法$.load的功能使用介绍
2014/10/20 Javascript
jQuery扁平化风格下拉框美化插件FancySelect使用指南
2015/02/10 Javascript
Javascript实现字数统计
2015/07/03 Javascript
js实现超酷的照片墙展示效果图附源码下载
2015/10/08 Javascript
jQuery ready()和onload的加载耗时分析
2016/09/08 Javascript
jQuery中fadein与fadeout方法用法示例
2016/09/16 Javascript
Javascript for in的缺陷总结
2017/02/03 Javascript
javascript实现table单元格点击展开隐藏效果(实例代码)
2017/04/10 Javascript
JS正则表达式验证中文字符
2017/05/08 Javascript
JavaScrip数组删除特定元素的几种方法总结
2017/09/06 Javascript
微信小程序实现倒计时调用相机自动拍照功能
2018/06/10 Javascript
微信小程序实现自定义加载图标功能
2018/07/19 Javascript
Vue.js 中的 v-model 指令及绑定表单元素的方法
2018/12/03 Javascript
js实现双人五子棋小游戏
2020/05/28 Javascript
你真的了解Python的random模块吗?
2017/12/12 Python
详解用TensorFlow实现逻辑回归算法
2018/05/02 Python
python2.7实现爬虫网页数据
2018/05/25 Python
不到40行代码用Python实现一个简单的推荐系统
2019/05/10 Python
Django 实现前端图片压缩功能的方法
2019/08/07 Python
python super函数使用方法详解
2020/02/14 Python
详解python中groupby函数通俗易懂
2020/05/14 Python
css3实现文字首尾衔接跑马灯的示例代码
2020/10/16 HTML / CSS
商场消防管理制度
2014/01/12 职场文书
财务人员担保书
2014/05/13 职场文书
幼师中班个人总结
2015/02/12 职场文书
大学毕业生自我评价
2015/03/02 职场文书
2015年团队工作总结范文
2015/05/04 职场文书
骆驼祥子读书笔记
2015/06/26 职场文书
一文搞懂Redis中String数据类型
2022/04/03 Redis
CSS实现背景图片全屏铺满自适应的3种方式
2022/07/07 HTML / CSS