pytorch实现用Resnet提取特征并保存为txt文件的方法


Posted in Python onAugust 20, 2019

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-
 
import os.path
 
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 
 
import numpy as np
from PIL import Image 
 
features_dir = './features'
 
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
 
 
transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)
 
img = Image.open(img_path)
img1 = transform1(img)
 
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
 
for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True) 
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
 
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image 
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
 
 
def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )
  
  img = Image.open(img_path)
  img = transform(img)
  
 
 
  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')
  
if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    
  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))
    
  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  
    
  use_gpu = torch.cuda.is_available()
 
  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现获取序列中最小的几个元素
Sep 25 Python
Python数据结构与算法之二叉树结构定义与遍历方法详解
Dec 12 Python
如何利用python制作时间戳转换工具详解
Sep 12 Python
Pycharm+Scrapy安装并且初始化项目的方法
Jan 15 Python
python实现全盘扫描搜索功能的方法
Feb 14 Python
如何使用django的MTV开发模式返回一个网页
Jul 22 Python
python 接口实现 供第三方调用的例子
Aug 13 Python
python爬虫 基于requests模块的get请求实现详解
Aug 20 Python
Python获取时间戳代码实例
Sep 24 Python
Python如何基于Tesseract实现识别文字功能
Jun 05 Python
使用Python通过oBIX协议访问Niagara数据的示例
Dec 04 Python
用Python监控你的朋友都在浏览哪些网站?
May 27 Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
Pytorch 抽取vgg各层并进行定制化处理的方法
Aug 20 #Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
You might like
linux系统上支持php的 iconv()函数的方法
2011/10/01 PHP
php将一维数组转换为每3个连续值组成的二维数组
2016/05/06 PHP
php+mysql实现的二级联动菜单效果详解
2016/05/10 PHP
PHP session 会话处理函数
2016/06/06 PHP
php类的自动加载操作实例详解
2016/09/28 PHP
php+ajax无刷新上传图片的实现方法
2016/12/06 PHP
PHP pthreads v3下worker和pool的使用方法示例
2020/02/21 PHP
不懂JavaScript应该怎样学
2008/04/16 Javascript
javascript实现的距离现在多长时间后的一个格式化的日期
2009/10/29 Javascript
详解AngularJS Filter(过滤器)用法
2015/12/28 Javascript
浏览器环境下JavaScript脚本加载与执行探析之defer与async特性
2016/01/14 Javascript
Js实现简单的小球运动特效
2016/02/18 Javascript
JS和jQuery使用submit方法无法提交表单的原因分析及解决办法
2016/05/17 Javascript
改变checkbox默认选中状态及取值的实现代码
2016/05/26 Javascript
JS实现的适合做faq或menu滑动效果示例
2016/11/17 Javascript
js 递归和定时器的实例解析
2017/02/03 Javascript
JS去除字符串中空格的方法
2017/02/14 Javascript
vue2.0 better-scroll 实现移动端滑动的示例代码
2018/01/25 Javascript
javascript与PHP动态往类中添加方法对比
2018/03/21 Javascript
详解js访问对象的属性和方法
2018/10/25 Javascript
Scrapy框架CrawlSpiders的介绍以及使用详解
2017/11/29 Python
利用 python 对目录下的文件进行过滤删除
2017/12/27 Python
python使用opencv按一定间隔截取视频帧
2018/03/06 Python
Python实现读取SQLServer数据并插入到MongoDB数据库的方法示例
2018/06/09 Python
深入浅析python3中的unicode和bytes问题
2019/07/03 Python
Python实现FTP文件传输的实例
2019/07/07 Python
利用Python实现Json序列化库的方法步骤
2020/09/09 Python
Anaconda使用IDLE的实现示例
2020/09/23 Python
matplotlib部件之矩形选区(RectangleSelector)的实现
2021/02/01 Python
Pycharm制作搞怪弹窗的实现代码
2021/02/19 Python
NYX Professional Makeup英国官网:美国平价专业彩妆品牌
2019/11/13 全球购物
简爱电影观后感
2015/06/10 职场文书
小学运动会入场口号
2015/12/24 职场文书
外出学习心得体会范文
2016/01/18 职场文书
导游词之吉林吉塔
2019/11/11 职场文书
索尼ICF-5900W收音机测评
2022/04/24 无线电