Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中Collection的使用小技巧
Aug 18 Python
python实现用于测试网站访问速率的方法
May 26 Python
详解Django框架中的视图级缓存
Jul 23 Python
Python中functools模块函数解析
Mar 12 Python
python中学习K-Means和图片压缩
Nov 20 Python
Python遍历pandas数据方法总结
Feb 09 Python
解决Python3 抓取微信账单信息问题
Jul 19 Python
基于python操作ES实例详解
Nov 16 Python
Python函数的默认参数设计示例详解
Dec 01 Python
Python爬虫之Selenium中frame/iframe表单嵌套页面
Dec 04 Python
pycharm 实现光标快速移动到括号外或行尾的操作
Feb 05 Python
用Python远程登陆服务器的步骤
Apr 16 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
有道搜索和IP138的IP的API接口(PHP应用)
2012/11/29 PHP
页面版文本框智能提示JS代码
2009/11/20 Javascript
Jquery动态改变图片IMG的src地址示例
2013/06/25 Javascript
JQuery Highcharts 动态生成图表的方法
2013/11/15 Javascript
做好七件事帮你提升jQuery的性能
2014/02/06 Javascript
js实现网页右上角滑出会自动消失大幅广告的方法
2015/02/27 Javascript
PHP+MySQL+jQuery随意拖动层并即时保存拖动位置实例讲解
2015/10/09 Javascript
JS实现iframe自适应高度的方法(兼容IE与FireFox)
2016/06/24 Javascript
AngularJS中$http服务常用的应用及参数
2016/08/22 Javascript
分享一道关于闭包、bind和this的面试题
2017/02/20 Javascript
从零学习node.js之搭建http服务器(二)
2017/02/21 Javascript
js,jq,css多方面实现简易下拉菜单功能
2017/05/13 Javascript
Vue学习笔记进阶篇之过渡状态详解
2017/07/14 Javascript
JS基础之逻辑结构与循环操作示例
2020/01/19 Javascript
详解JSON.stringify()的5个秘密特性
2020/05/26 Javascript
[03:34]2014DOTA2西雅图国际邀请赛 淘汰赛7月15日TOPPLAY
2014/07/15 DOTA
[49:08]Secret vs VP 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/20 DOTA
Python读写Excel文件方法介绍
2014/11/22 Python
Python机器学习算法之k均值聚类(k-means)
2018/02/23 Python
运行django项目指定IP和端口的方法
2018/05/14 Python
python绘制散点图并标记序号的方法
2018/12/11 Python
python交易记录整合交易类详解
2019/07/03 Python
Python的几种主动结束程序方式
2019/11/22 Python
python [:3] 实现提取数组中的数
2019/11/27 Python
CSS3教程(1):什么是CSS3
2009/04/02 HTML / CSS
CSS3过渡transition效果实例介绍
2016/05/03 HTML / CSS
英国家喻户晓的高街品牌:River Island
2017/11/28 全球购物
中国跨境在线时尚零售商:Bellelily
2018/04/06 全球购物
如何进行有效的自我评价
2013/09/27 职场文书
2014年残疾人工作总结
2014/12/06 职场文书
2015年光棍节活动总结
2015/03/24 职场文书
业务员管理制度范本
2015/08/06 职场文书
2016猴年春节慰问信
2015/11/30 职场文书
《围炉夜话》110句人生箴言,精辟有内涵,引人深思
2019/10/23 职场文书
导游词之杭州岳王庙
2019/11/13 职场文书
Python爬虫之自动爬取某车之家各车销售数据
2021/06/02 Python