Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python日期操作学习笔记
Oct 07 Python
python处理文本文件并生成指定格式的文件
Jul 31 Python
Python开发之快速搭建自动回复微信公众号功能
Apr 22 Python
举例讲解Python中字典的合并值相加与异或对比
Jun 04 Python
视觉直观感受若干常用排序算法
Apr 13 Python
在centos7中分布式部署pyspider
May 03 Python
Python使用PIL模块生成随机验证码
Nov 21 Python
django框架之cookie/session的使用示例(小结)
Oct 15 Python
python程序封装为win32服务的方法
Mar 07 Python
Django渲染Markdown文章目录的方法示例
Jan 02 Python
python对execl 处理操作代码
Jun 22 Python
TensorFlow Autodiff自动微分详解
Jul 06 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
php数组编码转换示例详解
2014/03/11 PHP
php+xml编程之SimpleXML的应用实例
2015/01/24 PHP
PHP框架Laravel中实现supervisor执行异步进程的方法
2017/06/07 PHP
Prototype Object对象 学习
2009/07/12 Javascript
动态创建样式表在各浏览器中的差异测试代码
2011/09/13 Javascript
js获取电脑分辨率的思路及操作
2013/11/22 Javascript
js中this的用法实例分析
2015/01/10 Javascript
Angular.js跨controller实现参数传递的两种方法
2017/02/20 Javascript
webpack2.0搭建前端项目的教程详解
2017/04/05 Javascript
详解在Angularjs中ui-sref和$state.go如何传递参数
2017/04/24 Javascript
JS组件系列之MVVM组件构建自己的Vue组件
2017/04/28 Javascript
JS设置手机验证码60s等待实现代码
2017/06/14 Javascript
Angular移动端页面input无法输入的解决方法
2017/11/14 Javascript
vue中使用element-ui进行表单验证的实例代码
2018/06/22 Javascript
vue项目使用.env文件配置全局环境变量的方法
2019/10/24 Javascript
Windows下安装 node 的版本控制工具 nvm
2020/02/06 Javascript
详解Vue中的watch和computed
2020/11/09 Javascript
[49:30]DOTA2-DPC中国联赛正赛 Dragon vs Dynasty BO3 第二场 3月4日
2021/03/11 DOTA
Python 出现错误TypeError: ‘NoneType’ object is not iterable解决办法
2017/01/12 Python
python的构建工具setup.py的方法使用示例
2017/10/23 Python
Python 网络编程之UDP发送接收数据功能示例【基于socket套接字】
2019/10/11 Python
python连接打印机实现打印文档、图片、pdf文件等功能
2020/02/07 Python
python 实现的IP 存活扫描脚本
2020/12/10 Python
BISSELL官网:北美吸尘器第一品牌
2019/03/14 全球购物
医生实习工作总结的自我评价
2013/09/27 职场文书
大学专科生推荐信范文
2013/11/23 职场文书
计算机专业学生求职信分享
2013/12/15 职场文书
黄继光的英雄事迹材料
2014/02/13 职场文书
农民工工资支付承诺函
2014/03/31 职场文书
《花木兰》教学反思
2014/04/09 职场文书
诉讼授权委托书
2014/10/15 职场文书
大学生旷课检讨书1000字
2015/02/19 职场文书
会计手工模拟做账心得体会
2016/01/22 职场文书
HTML中table表格拆分合并(colspan、rowspan)
2021/04/07 HTML / CSS
vue+spring boot实现校验码功能
2021/05/27 Vue.js
Linux中一对多配置日志服务器的详细步骤
2022/07/23 Servers