Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现Mysql数据库连接池实例详解
Apr 11 Python
Python实现字典(dict)的迭代操作示例
Jun 05 Python
python使用matplotlib画饼状图
Sep 25 Python
一百多行python代码实现抢票助手
Sep 25 Python
python实现推箱子游戏
Mar 25 Python
对python:threading.Thread类的使用方法详解
Jan 31 Python
Python调用百度根据经纬度查询地址的示例代码
Jul 07 Python
python3实现微型的web服务器
Sep 03 Python
Python日志syslog使用原理详解
Feb 18 Python
基于python检查SSL证书到期情况代码实例
Apr 04 Python
Python3爬虫RedisDump的安装步骤
Feb 20 Python
使用python生成大量数据写入es数据库并查询操作(2)
Sep 23 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
整合了前面的PHP数据库连接类~~做成一个分页类!
2006/11/25 PHP
laravel实现登录时监听事件,添加登录用户的记录方法
2019/09/30 PHP
PHP+MySQL实现在线测试答题实例
2020/01/02 PHP
获取Javscript执行函数名称的方法
2006/12/22 Javascript
收藏一些不常用,但是有用的代码
2007/03/12 Javascript
用Javascript 获取页面元素的位置的代码
2009/09/25 Javascript
Javascript 刷新全集常用代码
2009/11/22 Javascript
jquery 仿QQ校友的DIV模拟窗口效果源码
2010/03/24 Javascript
5个最佳的Javascript日期处理类库分享
2012/04/15 Javascript
jquery简单瀑布流实现原理及ie8下测试代码
2013/01/23 Javascript
JS链式调用的实现方法
2013/03/07 Javascript
JS常用表单验证方法总结
2014/05/22 Javascript
JavaScript学习笔记(三):JavaScript也有入口Main函数
2015/09/12 Javascript
jQuery实现下拉框左右移动(全部移动,已选移动)
2016/04/15 Javascript
JavaScript中的各种操作符使用总结
2016/05/26 Javascript
js智能获取浏览器版本UA信息的方法
2016/08/08 Javascript
微信小程序 tabs选项卡效果的实现
2017/01/05 Javascript
微信小程序scroll-view实现横向滚动和上拉加载示例
2017/03/06 Javascript
Bootstrap模态框插入视频的实现代码
2017/06/25 Javascript
React中常见的动画实现的几种方式
2018/01/10 Javascript
使用vue如何构建一个自动建站项目
2018/02/05 Javascript
Vue2.0用户权限控制解决方案的示例
2018/02/10 Javascript
Vue2.0点击切换类名改变样式的方法
2018/08/22 Javascript
[31:00]2014 DOTA2华西杯精英邀请赛5 24 NewBee VS iG
2014/05/25 DOTA
[01:06]DOTA2小知识课堂 Ep.02 吹风竟可解梦境缠绕
2019/12/05 DOTA
Python函数参数操作详解
2018/08/03 Python
Python 加密与解密小结
2018/12/06 Python
20行python代码的入门级小游戏的详解
2019/05/05 Python
python如何将两个txt文件内容合并
2019/10/18 Python
wxPython之wx.DC绘制形状
2019/11/19 Python
基于Python 中函数的 收集参数 机制
2019/12/21 Python
CSS3 Notes: -webkit-box-reflect实现倒影的实例
2016/12/08 HTML / CSS
考察现实表现材料
2014/05/19 职场文书
食堂厨师岗位职责
2014/08/25 职场文书
英文导游词
2015/02/13 职场文书
大学学生会辞职信
2015/05/13 职场文书