用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 表达式和语句及for、while循环练习实例
Jul 07 Python
Python实现模拟分割大文件及多线程处理的方法
Oct 10 Python
对Python 网络设备巡检脚本的实例讲解
Apr 22 Python
对Python中DataFrame选择某列值为XX的行实例详解
Jan 29 Python
python接口自动化(十六)--参数关联接口后传(详解)
Apr 16 Python
解决Pyinstaller 打包exe文件 取消dos窗口(黑框框)的问题
Jun 21 Python
Python3 执行Linux Bash命令的方法
Jul 12 Python
Python:二维列表下标互换方式(矩阵转置)
Dec 02 Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 Python
python程序文件扩展名知识点详解
Feb 27 Python
使用Python三角函数公式计算三角形的夹角案例
Apr 15 Python
什么是python的列表推导式
May 26 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
PHP查询数据库中满足条件的记录条数(两种实现方法)
2013/01/29 PHP
学习PHP Cookie处理函数
2016/08/09 PHP
Yii2中添加全局函数的方法分析
2017/05/04 PHP
PHP基于rabbitmq操作类的生产者和消费者功能示例
2018/06/16 PHP
响应鼠标变换表格背景或者颜色的代码
2009/03/30 Javascript
JavaScript字符串String和Array操作的有趣方法
2012/12/18 Javascript
javascript实现yield的方法
2013/11/06 Javascript
详解WordPress开发中get_current_screen()函数的使用
2016/01/11 Javascript
基于JS判断iframe是否加载成功的方法(多种浏览器)
2016/05/13 Javascript
NodeJS基础API搭建服务器详细过程记录
2017/04/01 NodeJs
vue路由嵌套的SPA实现步骤
2017/11/06 Javascript
原生js调用json方法总结
2018/02/22 Javascript
JavaScript捕捉事件和阻止冒泡事件实例分析
2018/08/03 Javascript
详解JavaScript修改注册表的方法
2020/01/05 Javascript
vue页面更新patch的实现示例
2020/03/25 Javascript
[49:29]LGD vs Winstrike 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/18 DOTA
python发送邮件接收邮件示例分享
2014/01/21 Python
巧用python和libnmapd,提取Nmap扫描结果
2016/08/23 Python
Python基于Matplotlib库简单绘制折线图的方法示例
2017/08/14 Python
python、java等哪一门编程语言适合人工智能?
2017/11/13 Python
python 文件转成16进制数组的实例
2018/07/09 Python
如何使用Python进行OCR识别图片中的文字
2019/04/01 Python
python实现图像检索的三种(直方图/OpenCV/哈希法)
2019/08/08 Python
django 做 migrate 时 表已存在的处理方法
2019/08/31 Python
python读取word 中指定位置的表格及表格数据
2019/10/23 Python
Python @property装饰器原理解析
2020/01/22 Python
关于Python错误重试方法总结
2021/01/03 Python
介绍CSS3使用技巧5个
2009/04/02 HTML / CSS
浅析HTML5 meta viewport参数
2020/10/28 HTML / CSS
美国本地交易和折扣网站:LocalFlavor.com
2017/10/26 全球购物
校园活动宣传方案
2014/03/28 职场文书
白血病捐款倡议书
2014/05/14 职场文书
竞聘演讲稿开场白
2014/08/25 职场文书
群众路线表态发言材料
2014/10/17 职场文书
幼儿教师辞职信
2015/02/27 职场文书
初中数学教学随笔
2015/08/15 职场文书