用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之玩转字符串(2)
Sep 14 Python
Python中的列表知识点汇总
Apr 14 Python
Python时间模块datetime、time、calendar的使用方法
Jan 13 Python
Python编程实现及时获取新邮件的方法示例
Aug 10 Python
django上传图片并生成缩略图方法示例
Dec 11 Python
python 3.6.2 安装配置方法图文教程
Sep 18 Python
详解Django中CBV(Class Base Views)模型源码分析
Feb 25 Python
pycharm新建一个python工程步骤
Jul 16 Python
python破解bilibili滑动验证码登录功能
Sep 11 Python
PYTHON实现SIGN签名的过程解析
Oct 28 Python
python路径的写法及目录的获取方式
Dec 26 Python
python闭包、深浅拷贝、垃圾回收、with语句知识点汇总
Mar 11 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
全国FM电台频率大全 - 15 山东省
2020/03/11 无线电
ThinkPHP 3.2 数据分页代码分享
2014/10/14 PHP
smarty中英文多编码字符截取乱码问题解决方法
2014/10/28 PHP
php定时执行任务设置详解
2015/02/06 PHP
php中final关键字用法分析
2016/12/07 PHP
php设计模式之备忘模式分析【星际争霸游戏案例】
2020/03/24 PHP
ExtJS自定义主题(theme)样式详解
2013/11/18 Javascript
判断及设置浏览器全屏模式
2014/04/20 Javascript
一个html5播放视频的video控件只支持android的默认格式mp4和3gp
2014/05/08 Javascript
js不能获取隐藏的div的宽度只能先显示后获取
2014/09/04 Javascript
分享一则javascript 调试技巧
2015/01/02 Javascript
JS+CSS实现带关闭按钮DIV弹出窗口的方法
2015/02/27 Javascript
javascript模拟评分控件实现方法
2015/05/13 Javascript
JavaScript实现强制重定向至HTTPS页面
2015/06/10 Javascript
教你如何终止JQUERY的$.AJAX请求
2016/02/23 Javascript
javascript 数组的正态分布排序的问题
2016/07/31 Javascript
js中利用cookie实现记住密码功能
2020/08/20 Javascript
详解NodeJS框架express的路径映射(路由)功能及控制
2017/03/24 NodeJs
使用jquery-easyui的布局layout写后台管理页面的代码详解
2019/06/19 jQuery
Element中Slider滑块的具体使用
2020/07/29 Javascript
Vue 实现拨打电话操作
2020/11/16 Javascript
在java中如何定义一个抽象属性示例详解
2017/08/18 Python
python中requests库session对象的妙用详解
2017/10/30 Python
tensorflow 恢复指定层与不同层指定不同学习率的方法
2018/07/26 Python
Python日志模块logging基本用法分析
2018/08/23 Python
Pycharm配置autopep8实现流程解析
2020/11/28 Python
Selenium关闭INFO:CONSOLE提示的解决
2020/12/07 Python
matplotlib实现数据实时刷新的示例代码
2021/01/05 Python
简单几步用纯CSS3实现3D翻转效果
2019/01/17 HTML / CSS
HTML5 File接口在web页面上使用文件下载
2017/02/27 HTML / CSS
X/HTML5 和 XHTML2
2008/10/17 HTML / CSS
秋天的怀念教学反思
2014/04/28 职场文书
学校安全工作汇报材料
2014/08/16 职场文书
2014年科研工作总结
2014/12/03 职场文书
航班延误投诉信
2015/07/02 职场文书
JavaScript模拟实现网易云轮播效果
2022/04/04 Javascript