Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现CET查分的方法
Mar 10 Python
Python简单实现enum功能的方法
Apr 25 Python
python实现读取并显示图片的两种方法
Jan 13 Python
python3制作捧腹网段子页爬虫
Feb 12 Python
利用Opencv中Houghline方法实现直线检测
Feb 11 Python
python3.x 将byte转成字符串的方法
Jul 17 Python
Python openpyxl 遍历所有sheet 查找特定字符串的方法
Dec 10 Python
浅析python redis的连接及相关操作
Nov 07 Python
python错误调试及单元文档测试过程解析
Dec 19 Python
Pytorch 实现数据集自定义读取
Jan 18 Python
如何利用Python写个坦克大战
Nov 18 Python
Appium+Python实现简单的自动化登录测试的实现
Jan 26 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
php5编程中的异常处理详细方法介绍
2008/07/29 PHP
ThinkPHP控制器间实现相互调用的方法
2014/10/31 PHP
微信公众平台接口开发入门示例
2014/12/24 PHP
php绘制圆形的方法
2015/01/24 PHP
php定义参数数量可变的函数用法实例
2015/03/16 PHP
PHP dirname功能及原理实例解析
2020/10/28 PHP
jquery下组织javascript代码(js函数化)
2010/08/25 Javascript
jQuery图片滚动图片的效果(另类实现)
2013/06/02 Javascript
JQuery实现绚丽的横向下拉菜单
2013/12/19 Javascript
jqGrid随窗口大小变化自适应大小的示例代码
2013/12/28 Javascript
浅析node.js中close事件
2014/11/26 Javascript
JavaScript 学习笔记之数据类型
2015/01/14 Javascript
javascript实现框架高度随内容改变的方法
2015/07/23 Javascript
JavaScript+html5 canvas绘制的小人效果
2016/01/27 Javascript
js+canvas绘制矩形的方法
2016/01/28 Javascript
JavaScript获取select中text值的方法
2017/02/13 Javascript
jQuery中clone()函数实现表单中增加和减少输入项
2017/05/13 jQuery
javascript 开发之网页兼容各种浏览器
2017/09/28 Javascript
使用Vue-Router 2实现路由功能实例详解
2017/11/14 Javascript
浅谈vuex 闲置状态重置方案
2018/01/04 Javascript
vue鼠标悬停事件实例详解
2019/04/01 Javascript
React传值 组件传值 之间的关系详解
2019/08/26 Javascript
JavaScript创建、读取和删除cookie
2019/09/03 Javascript
python如何实现内容写在图片上
2018/03/23 Python
基于django micro搭建网站实现加水印功能
2020/05/22 Python
德国柯吉澳趣味家居:Koziol
2017/08/24 全球购物
日本动漫周边服饰销售网站:Atsuko
2019/12/16 全球购物
Windows和Linux动态库应用异同
2016/07/28 面试题
大学生护理专业自荐信
2013/10/03 职场文书
大学生第一学年自我鉴定
2014/09/12 职场文书
主持人大赛开场白
2015/05/29 职场文书
立案决定书范文
2015/06/24 职场文书
交通事故协议书范本
2016/03/19 职场文书
如何撰写出一份完美的商业计划书?
2019/07/12 职场文书
GO语言异常处理分析 err接口及defer延迟
2022/04/14 Golang
mysql 排序失效
2022/05/20 MySQL