Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python下的常用下载安装工具pip的安装方法
Nov 13 Python
Python中用字符串调用函数或方法示例代码
Aug 04 Python
python实现基于信息增益的决策树归纳
Dec 18 Python
Python小程序 控制鼠标循环点击代码实例
Oct 08 Python
Python字典常见操作实例小结【定义、添加、删除、遍历】
Oct 25 Python
Python 连接 MySQL 的几种方法
Sep 09 Python
python如何爬取动态网站
Sep 09 Python
用Python进行websocket接口测试
Oct 16 Python
GitHub上值得推荐的8个python 项目
Oct 30 Python
Python Unittest原理及基本使用方法
Nov 06 Python
如何利用python 读取配置文件
Jan 06 Python
matplotlib源码解析标题实现(窗口标题,标题,子图标题不同之间的差异)
Feb 22 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
在PWS上安装PHP4.0正式版
2006/10/09 PHP
PHP简单系统数据添加以及数据删除模块源文件下载
2008/06/07 PHP
Ubuntu中搭建Nginx、PHP环境最简单的方法
2015/03/05 PHP
php文件读取方法实例分析
2015/06/20 PHP
PHP 使用位运算实现四则运算的代码
2021/03/09 PHP
js导出格式化的excel 实例方法
2013/07/17 Javascript
javascript页面加载完执行事件代码
2014/02/11 Javascript
JS清空多文本框、文本域示例代码
2014/02/24 Javascript
iframe实用操作锦集
2014/04/22 Javascript
js简单抽奖代码
2015/01/16 Javascript
浅谈Jquery中Ajax异步请求中的async参数的作用
2016/06/06 Javascript
jQuery实现带遮罩层效果的blockUI弹出层示例【附demo源码下载】
2016/09/14 Javascript
简单实现Vue的observer和watcher
2016/12/21 Javascript
jQuery插件zTree实现单独选中根节点中第一个节点示例
2017/03/08 Javascript
angular仿支付宝密码框输入效果
2017/03/25 Javascript
浅谈js-FCC算法Friendly Date Ranges(详解)
2017/04/10 Javascript
Javascript操作dom对象之select全面解析
2017/04/24 Javascript
jQuery实现点击DIV同时点击CheckBox,并为DIV上背景色的实例
2017/12/18 jQuery
vue中element-ui表格缩略图悬浮放大功能的实例代码
2018/06/26 Javascript
使用Angular Cli如何创建Angular私有库详解
2019/01/30 Javascript
vue实现图片上传功能
2020/05/28 Javascript
angular中的post请求处理示例详解
2020/06/30 Javascript
Python批量重命名同一文件夹下文件的方法
2015/05/25 Python
使用Python将数组的元素导出到变量中(unpacking)
2016/10/27 Python
python批量将excel内容进行翻译写入功能
2019/10/10 Python
tornado+celery的简单使用详解
2019/12/21 Python
Anaconda+spyder+pycharm的pytorch配置详解(GPU)
2020/10/18 Python
PyCharm 解决找不到新打开项目的窗口问题
2021/01/15 Python
matplotlib之多边形选区(PolygonSelector)的使用
2021/02/24 Python
幼儿园的门卫岗位职责
2014/04/10 职场文书
态度决定一切演讲稿
2014/05/20 职场文书
医院护士党的群众路线教育实践活动对照检查材料思想汇报
2014/10/04 职场文书
党员干部学习心得体会
2016/01/23 职场文书
乡镇干部学习心得体会
2016/01/23 职场文书
创业计划书之寿司
2019/07/19 职场文书
JavaScript 事件捕获冒泡与捕获详情
2021/11/11 Javascript