Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python程序员鲜为人知但你应该知道的17个问题
Jun 04 Python
Python实现的径向基(RBF)神经网络示例
Feb 06 Python
Python实现在某个数组中查找一个值的算法示例
Jun 27 Python
解决python3 urllib 链接中有中文的问题
Jul 16 Python
python实现多层感知器MLP(基于双月数据集)
Jan 18 Python
python中update的基本使用方法详解
Jul 17 Python
如何使用selenium和requests组合实现登录页面
Feb 03 Python
关于matplotlib-legend 位置属性 loc 使用说明
May 16 Python
jupyter notebook的安装与使用详解
May 18 Python
Python Tkinter实例——模拟掷骰子
Oct 24 Python
详解Python Celery和RabbitMQ实战教程
Jan 20 Python
详解MindSpore自定义模型损失函数
Jun 30 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
PHP与SQL语句写一句话木马总结
2019/10/11 PHP
php设计模式之观察者模式实例详解【星际争霸游戏案例】
2020/03/30 PHP
用js判断浏览器是否是IE的比较好的办法
2007/05/08 Javascript
关于jQuery参考实例 1.0 jQuery的哲学
2013/04/07 Javascript
JavaScript中判断页面关闭、页面刷新的实现代码
2014/08/27 Javascript
同一个网页中实现多个JavaScript特效的方法
2015/02/02 Javascript
分享一个原生的JavaScript拖动方法
2016/09/25 Javascript
angular.js 路由及页面传参示例
2017/02/24 Javascript
js中DOM三级列表(代码分享)
2017/03/20 Javascript
微信小程序实现给循环列表添加点击样式实例
2017/04/26 Javascript
基于复选框demo(分享)
2017/09/27 Javascript
AngularJS2 与 D3.js集成实现自定义可视化的方法
2017/12/01 Javascript
快速解决brew安装特定版本flow的问题
2018/05/17 Javascript
Vue.js中 v-model 指令的修饰符详解
2018/12/03 Javascript
javascript实现5秒倒计时并跳转功能
2019/06/20 Javascript
微信小程序在线客服自动回复功能(基于node)
2019/07/03 Javascript
[01:00]DOTA2 store: Collection of Artisan's Wonders
2015/08/12 DOTA
[03:12]2016完美“圣”典风云人物:单车专访
2016/12/02 DOTA
Python实例之wxpython中Frame使用方法
2014/06/09 Python
Python实现的数据结构与算法之队列详解
2015/04/22 Python
Python遍历numpy数组的实例
2018/04/04 Python
python中的print()输出
2019/04/12 Python
python实现集中式的病毒扫描功能详解
2019/07/09 Python
使用python对多个txt文件中的数据进行筛选的方法
2019/07/10 Python
python中的测试框架
2020/11/13 Python
美国电子产品折扣网站:Daily Steals
2017/05/20 全球购物
工程部主管岗位职责
2013/11/17 职场文书
仓库管理计划书
2014/05/04 职场文书
房产公证委托书范本
2014/09/20 职场文书
2015幼儿园新学期寄语
2015/02/27 职场文书
聋哑人盗窃罪辩护词
2015/05/21 职场文书
繁星春水读书笔记
2015/06/30 职场文书
运动会开幕式主持词
2015/07/01 职场文书
小学运动会加油稿
2015/07/22 职场文书
离婚协议书范本(2016最新版)
2016/03/18 职场文书
Oracle创建只读账号的详细步骤
2021/06/07 Oracle