pytorch实现用Resnet提取特征并保存为txt文件的方法


Posted in Python onAugust 20, 2019

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-
 
import os.path
 
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 
 
import numpy as np
from PIL import Image 
 
features_dir = './features'
 
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
 
 
transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)
 
img = Image.open(img_path)
img1 = transform1(img)
 
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
 
for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True) 
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
 
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image 
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
 
 
def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )
  
  img = Image.open(img_path)
  img = transform(img)
  
 
 
  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')
  
if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    
  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))
    
  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  
    
  use_gpu = torch.cuda.is_available()
 
  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python正则表达式判断字符串是否是全部小写示例
Dec 25 Python
Python实现字典依据value排序
Feb 24 Python
深入理解python中的闭包和装饰器
Jun 12 Python
python 上下文管理器使用方法小结
Oct 10 Python
Python callable()函数用法实例分析
Mar 17 Python
PyQt5每天必学之布局管理
Apr 19 Python
tensorflow实现简单的卷积网络
May 24 Python
Python3用tkinter和PIL实现看图工具
Jun 21 Python
Python学习笔记之Zip和Enumerate用法实例分析
Aug 14 Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 Python
Python Selenium参数配置方法解析
Jan 19 Python
python如何支持并发方法详解
Jul 25 Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
Pytorch 抽取vgg各层并进行定制化处理的方法
Aug 20 #Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
You might like
php 用sock技术发送邮件的函数
2007/07/21 PHP
php验证session无效的解决方法
2014/11/04 PHP
php使用CURL伪造IP和来源实例详解
2015/01/15 PHP
在Linux系统的服务器上隐藏PHP版本号的方法
2015/06/06 PHP
Laravel框架中Blade模板的用法示例
2017/08/30 PHP
比较详细的javascript对象的property和prototype是什么一种关系
2007/08/06 Javascript
jquery插件 cluetip 关键词注释
2010/01/12 Javascript
判断控件是否已加载完成的代码
2010/02/24 Javascript
jQuery 获取对象 基本选择与层级
2010/05/31 Javascript
jquery radio 操作代码
2011/03/16 Javascript
基于jQuery捕获超链接事件进行局部刷新代码
2012/05/10 Javascript
js从Cookies里面取值的简单实现
2014/06/30 Javascript
关于session和cookie的简单理解
2016/06/08 Javascript
详解闭包解决jQuery中AJAX的外部变量问题
2017/02/22 Javascript
js排序与重组的实例讲解
2017/08/28 Javascript
vue.js获得当前元素的文字信息方法
2018/03/09 Javascript
对layui中table组件工具栏的使用详解
2019/09/19 Javascript
Python比较两个图片相似度的方法
2015/03/13 Python
开始着手第一个Django项目
2015/07/15 Python
python pandas 组内排序、单组排序、标号的实例
2018/04/12 Python
浅谈Python在pycharm中的调试(debug)
2018/11/29 Python
一文秒懂python读写csv xml json文件各种骚操作
2019/07/04 Python
python flask几分钟实现web服务的例子
2019/07/26 Python
Python数据库小程序源代码
2019/09/15 Python
Django Xadmin多对多字段过滤实例
2020/04/07 Python
python解释器安装教程的方法步骤
2020/07/02 Python
Python基于unittest实现测试用例执行
2020/11/25 Python
毕业生自荐信的主要内容
2013/10/29 职场文书
护理学专业推荐信
2013/12/03 职场文书
大一自我鉴定范文
2013/12/27 职场文书
职业生涯规划书范文
2014/03/10 职场文书
党员学习新党章思想汇报
2014/10/25 职场文书
2015年世界卫生日活动总结
2015/02/09 职场文书
致短跑运动员加油稿
2015/07/21 职场文书
七年级作文之下雨天
2019/12/23 职场文书
为什么MySQL 删除表数据 磁盘空间还一直被占用
2021/10/16 MySQL