Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中利用sqrt()方法进行平方根计算的教程
May 15 Python
浅析Python 中整型对象存储的位置
May 16 Python
Python3如何解决字符编码问题详解
Apr 23 Python
python web.py开发httpserver解决跨域问题实例解析
Feb 12 Python
Python3中的列表生成式、生成器与迭代器实例详解
Jun 11 Python
Python wxPython库使用wx.ListBox创建列表框示例
Sep 03 Python
python数据批量写入ScrolledText的优化方法
Oct 11 Python
Django使用unittest模块进行单元测试过程解析
Aug 02 Python
Python3 A*寻路算法实现方式
Dec 24 Python
python通过nmap扫描在线设备并尝试AAA登录(实例代码)
Dec 30 Python
如何使用 Flask 做一个评论系统
Nov 27 Python
python推导式的使用方法实例
Feb 28 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
phpmyadmin的#1251问题
2006/11/25 PHP
学习php设计模式 php实现合成模式(composite)
2015/12/08 PHP
学习thinkphp5.0验证类使用方法
2017/11/16 PHP
IE8 chrome中table隔行换色解决办法
2010/07/09 Javascript
兼容IE和Firefox的javascript获取iframe文档内容的函数
2011/08/15 Javascript
IE下JS读取xml文件示例代码
2013/08/05 Javascript
js实现收缩菜单效果实例代码
2013/10/30 Javascript
JavaScript中的object转换函数toString()与valueOf()介绍
2014/12/31 Javascript
js简单实现点击左右运动的方法
2015/04/10 Javascript
JavaScript多线程详解
2015/08/12 Javascript
js replace(a,b)之替换字符串中所有指定字符的方法
2016/08/17 Javascript
Bootstrap源码解读按钮(5)
2016/12/23 Javascript
利用js定义一个导航条菜单
2017/03/14 Javascript
vue.js路由跳转详解
2017/08/28 Javascript
VSCode配置react开发环境的步骤
2017/12/27 Javascript
详解从react转职到vue开发的项目准备
2019/01/14 Javascript
JS实现可切换图片的幻灯切换效果示例
2019/05/24 Javascript
jQuery实现图片下载代码
2019/07/18 jQuery
layui自定义插件citySelect实现省市区三级联动选择
2019/07/26 Javascript
js实现图片区域可点击大小随意改变(适用移动端)代码实例
2019/09/11 Javascript
Vue v-model组件封装(类似弹窗组件)
2020/01/08 Javascript
vue实现购物车列表
2020/06/30 Javascript
[48:44]2014 DOTA2国际邀请赛中国区预选赛5.21 TongFu VS HGT
2014/05/22 DOTA
Python中enumerate函数代码解析
2017/10/31 Python
Python实现Mysql数据统计及numpy统计函数
2019/07/15 Python
python删除文件夹下相同文件和无法打开的图片
2019/07/16 Python
python判断单向链表是否包括环,若包含则计算环入口的节点实例分析
2019/10/23 Python
Python OpenCV读取显示视频的方法示例
2020/02/20 Python
TensorFlow的环境配置与安装方法
2021/02/20 Python
Nordgreen台湾官网:极简北欧设计手表
2019/08/21 全球购物
什么是makefile? 如何编写makefile?
2012/08/08 面试题
给下属加薪申请报告
2015/05/15 职场文书
2015年民兵整组工作总结
2015/07/24 职场文书
vue实现同时设置多个倒计时
2021/05/20 Vue.js
pandas进行数据输入和输出的方法详解
2022/03/23 Python
分享一个vue实现的记事本功能案例
2022/04/11 Vue.js