Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python计算对角线有理函数插值的方法
May 07 Python
对变量赋值的理解--Pyton中让两个值互换的实现方法
Nov 29 Python
基于python批量处理dat文件及科学计算方法详解
May 08 Python
uwsgi+nginx部署Django项目操作示例
Dec 04 Python
python字典一键多值实例代码分享
Jun 14 Python
Python pandas.DataFrame调整列顺序及修改index名的方法
Jun 21 Python
大家都说好用的Python命令行库click的使用
Nov 07 Python
pytorch 实现将自己的图片数据处理成可以训练的图片类型
Jan 08 Python
基于tf.shape(tensor)和tensor.shape()的区别说明
Jun 30 Python
Python基于爬虫实现全网搜索并下载音乐
Feb 14 Python
python Autopep8实现按PEP8风格自动排版Python代码
Mar 02 Python
python利用pandas分析学生期末成绩实例代码
Jul 09 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
咖啡语言
2021/03/03 咖啡文化
如何使用PHP中的字符串函数
2006/11/24 PHP
PHP aes (ecb)解密后乱码问题
2015/06/22 PHP
Mootools 1.2教程 设置和获取样式表属性
2009/09/15 Javascript
jquery+json 通用三级联动下拉列表
2010/04/19 Javascript
JavaScript 计算图片加载数量的代码
2011/01/01 Javascript
让innerText在firefox火狐和IE浏览器都能用的写法
2011/05/14 Javascript
javascript中的数字与字符串相加实例分析
2011/08/14 Javascript
新浪微博字数统计 textarea字数统计实现代码
2011/08/28 Javascript
js 利用image对象实现图片的预加载提高访问速度
2013/03/29 Javascript
当滚动条滚动到页面底部自动加载增加内容的js代码
2014/05/13 Javascript
Jquery性能优化详解
2014/05/15 Javascript
javascript实现禁止复制网页内容
2014/12/16 Javascript
JS传递对象数组为参数给后端,后端获取的实例代码
2016/06/28 Javascript
jQuery实现搜索页面关键字的功能
2017/02/16 Javascript
微信小程序分页加载的实例代码
2017/07/11 Javascript
vue实现多个元素或多个组件之间动画效果
2018/09/25 Javascript
JavaScript(js)处理的HTML事件、键盘事件、鼠标事件简单示例
2019/11/19 Javascript
[02:27]刀塔重生降临
2015/10/14 DOTA
在Python的Django框架中调用方法和处理无效变量
2015/07/15 Python
深入理解NumPy简明教程---数组1
2016/12/17 Python
详解Python中的分组函数groupby和itertools)
2018/07/11 Python
python读取word文档,插入mysql数据库的示例代码
2018/11/07 Python
Django REST framework 分页的实现代码
2019/06/19 Python
Flask框架学习笔记之消息提示与异常处理操作详解
2019/08/15 Python
python opencv实现信用卡的数字识别
2020/01/12 Python
在Ubuntu中安装并配置Pycharm教程的实现方法
2021/01/06 Python
Python实现粒子群算法的示例
2021/02/14 Python
什么叫应用程序域?什么是托管代码?什么是强类型系统?什么是装箱和拆箱?什么是重载?CTS、CLS和CLR分别作何解释?
2012/05/23 面试题
int和Integer有什么区别
2013/05/25 面试题
外贸公司实习自我鉴定
2013/09/24 职场文书
体育课外活动总结
2014/07/08 职场文书
学校领导干部民主生活会整改方案
2014/09/29 职场文书
2014年营销工作总结
2014/11/22 职场文书
保险公司2016开门红口号集锦
2015/12/24 职场文书
Pytorch中的数据集划分&正则化方法
2021/05/27 Python