Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python编写一个基于终端的实现翻译的脚本
Apr 24 Python
python比较两个列表大小的方法
Jul 11 Python
Python单元测试简单示例
Jul 03 Python
Python为何不能用可变对象作为默认参数的值
Jul 01 Python
python移位运算的实现
Jul 15 Python
Flask框架学习笔记之模板操作实例详解
Aug 15 Python
Python实现某论坛自动签到功能
Aug 20 Python
PageFactory设计模式基于python实现
Apr 14 Python
python实现在列表中查找某个元素的下标示例
Nov 16 Python
python二维图制作的实例代码
Dec 03 Python
Python的Tqdm模块实现进度条配置
Feb 24 Python
聊聊pytorch测试的时候为何要加上model.eval()
May 23 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
Zend 输出产生XML解析错误
2009/03/03 PHP
部署PHP项目应该注意的几点事项分享
2013/12/20 PHP
ThinkPHP中U方法的使用浅析
2014/06/13 PHP
PHP之预定义接口详解
2015/07/29 PHP
CSS+Table图文混排中实现文本自适应图片宽度(超简单+跨所有浏览器)
2009/02/14 Javascript
javascript 支持链式调用的异步调用框架Async.Operation
2009/08/04 Javascript
js 创建书签小工具之理论
2011/02/25 Javascript
jQuery $命名冲突解决方案汇总
2014/11/13 Javascript
js查找节点的方法小结
2015/01/13 Javascript
jQuery使用$.ajax进行异步刷新的方法(附demo下载)
2015/12/04 Javascript
prototype框架中美元符号$用法分析
2016/01/22 Javascript
理解javascript正则表达式
2016/03/08 Javascript
详解React native全局变量的使用(跨组件的通信)
2017/09/07 Javascript
js中json对象和字符串的理解及相互转化操作实现方法
2017/09/22 Javascript
node结合swig渲染摸板的方法
2018/04/11 Javascript
jQuery简单实现的HTML页面文本框模糊匹配查询功能完整示例
2018/05/09 jQuery
详解微信JS-SDK选择图片遇到的坑
2018/08/15 Javascript
解决layui的form里的元素进行动态生成,验证失效的问题
2019/09/14 Javascript
jQuery设置下拉框显示与隐藏效果的方法分析
2019/09/15 jQuery
详解如何在Javascript和Sass之间共享变量
2019/11/13 Javascript
Python学习小技巧之利用字典的默认行为
2017/05/20 Python
使用Python实现博客上进行自动翻页
2017/08/23 Python
Python 字符串转换为整形和浮点类型的方法
2018/07/17 Python
python Gunicorn服务器使用方法详解
2019/07/22 Python
python爬取音频下载的示例代码
2020/10/19 Python
怎样创建、运行java程序
2014/08/01 面试题
2014升学宴答谢词
2014/01/26 职场文书
贯彻落实“八项规定”思想汇报
2014/09/13 职场文书
2014年政协委员工作总结
2014/12/01 职场文书
个人党性分析材料
2014/12/19 职场文书
硕士论文致谢范文
2015/05/14 职场文书
办公室规章制度范本
2015/08/04 职场文书
大学校园餐饮创业计划书
2019/08/07 职场文书
python中的被动信息搜集
2021/04/29 Python
这样写python注释让代码更加的优雅
2021/06/02 Python
Vertica集成Apache Hudi重磅使用指南
2022/03/31 Servers