Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python的web.py框架实现类似Django的ORM查询的教程
May 02 Python
Python2.7基于笛卡尔积算法实现N个数组的排列组合运算示例
Nov 23 Python
Python实现PS图像明亮度调整效果示例
Jan 23 Python
利用python和百度地图API实现数据地图标注的方法
May 13 Python
详解pandas数据合并与重塑(pd.concat篇)
Jul 09 Python
Django forms表单 select下拉框的传值实例
Jul 19 Python
Python解析json时提示“string indices must be integers”问题解决方法
Jul 31 Python
Python 限定函数参数的类型及默认值方式
Dec 24 Python
Python中filter与lambda的结合使用详解
Dec 24 Python
Python图像处理库PIL的ImageFilter模块使用介绍
Feb 26 Python
用Python制作灯光秀短视频的思路详解
Apr 13 Python
浅谈Python中的正则表达式
Jun 28 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
PHP+ajax实现获取新闻数据简单示例
2018/05/08 PHP
PHP使用redis位图bitMap 实现签到功能
2019/10/08 PHP
Jquery Ajax学习实例5 向WebService发出请求,返回泛型集合数据的异步调用
2010/03/17 Javascript
javascript面向对象编程(一) 实例代码
2010/06/25 Javascript
兼容IE和Firefox的javascript获取iframe文档内容的函数
2011/08/15 Javascript
extjs4 treepanel动态改变行高度示例
2013/12/17 Javascript
Jquery Uploadify上传带进度条的简单实例
2014/02/12 Javascript
优化Node.js Web应用运行速度的10个技巧
2014/09/03 Javascript
jQuery实现选中弹出窗口选择框内容后赋值给文本框的方法
2015/11/23 Javascript
巧用Vue.js+Vuex制作专门收藏微信公众号的app
2016/11/03 Javascript
Angular JS 生成动态二维码的方法
2017/02/23 Javascript
vue+Java后端进行调试时解决跨域问题的方式
2017/10/19 Javascript
ejsExcel模板在Vue.js项目中的实际运用
2018/01/27 Javascript
vue-lazyload图片延迟加载插件的实例讲解
2018/02/09 Javascript
详解处理bootstrap4不支持远程静态框问题
2018/07/20 Javascript
vue 中引用gojs绘制E-R图的方法示例
2018/08/24 Javascript
vue首次赋值不触发watch的解决方法
2018/09/11 Javascript
jQuery实现图片下载代码
2019/07/18 jQuery
如何利用Node.js与JSON搭建简单的动态服务器
2020/06/16 Javascript
vue实现单一筛选、删除筛选条件
2020/10/26 Javascript
[56:20]LGD vs VP Supermajor 败者组决赛 BO3 第三场 6.10
2018/07/04 DOTA
python实现simhash算法实例
2014/04/25 Python
深入解析Python中的urllib2模块
2015/11/13 Python
浅谈python字符串方法的简单使用
2016/07/18 Python
利用python写个下载teahour音频的小脚本
2017/05/08 Python
Python实现打印螺旋矩阵功能的方法
2017/11/21 Python
Windows下的Python 3.6.1的下载与安装图文详解(适合32位和64位)
2018/02/21 Python
python 输出上个月的月末日期实例
2018/04/11 Python
Pycharm激活码激活两种快速方式(附最新激活码和插件)
2020/03/12 Python
pytorch掉坑记录:model.eval的作用说明
2020/06/23 Python
Pandas直接读取sql脚本的方法
2021/01/21 Python
开发中都用到了那些设计模式?用在什么场合?
2014/08/21 面试题
文员个人的求职信范文
2013/09/26 职场文书
电影地道战观后感
2015/06/04 职场文书
谢师宴学生致辞
2015/07/27 职场文书
JavaScript原型链中函数和对象的理解
2022/06/16 Javascript