pytorch实现用Resnet提取特征并保存为txt文件的方法


Posted in Python onAugust 20, 2019

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-
 
import os.path
 
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 
 
import numpy as np
from PIL import Image 
 
features_dir = './features'
 
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
 
 
transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)
 
img = Image.open(img_path)
img1 = transform1(img)
 
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
 
for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True) 
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
 
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image 
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
 
 
def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )
  
  img = Image.open(img_path)
  img = transform(img)
  
 
 
  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')
  
if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    
  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))
    
  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  
    
  use_gpu = torch.cuda.is_available()
 
  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python切换hosts文件代码示例
Dec 31 Python
python实现类似ftp传输文件的网络程序示例
Apr 08 Python
Python中对列表排序实例
Jan 04 Python
python Django模板的使用方法
Jan 14 Python
利用python写个下载teahour音频的小脚本
May 08 Python
深入解析神经网络从原理到实现
Jul 26 Python
pycharm 批量修改变量名称的方法
Aug 01 Python
python opencv 简单阈值算法的实现
Aug 04 Python
opencv转换颜色空间更改图片背景
Aug 20 Python
解决Numpy中sum函数求和结果维度的问题
Dec 06 Python
详解Ubuntu环境下部署Django+uwsgi+nginx总结
Apr 02 Python
Python flask路由间传递变量实例详解
Jun 03 Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
Pytorch 抽取vgg各层并进行定制化处理的方法
Aug 20 #Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
You might like
弄了个检测传输的参数是否为数字的Function
2006/12/06 PHP
php下使用iconv需要注意的问题
2010/11/20 PHP
服务器web工具 php环境下
2010/12/29 PHP
PHP错误抑制符(@)导致引用传参失败Bug的分析
2011/05/02 PHP
PHP中鲜为人知的10个函数
2014/02/28 PHP
PHP微信开发之二维码生成类
2015/06/26 PHP
Yii控制器中filter过滤器用法分析
2016/07/15 PHP
Laravel框架定时任务2种实现方式示例
2018/12/08 PHP
PHP模版引擎原理、定义与用法实例
2019/03/29 PHP
JS实现简易图片轮播效果的方法
2015/03/25 Javascript
JS实现可展开折叠层的鼠标拖曳效果
2015/10/09 Javascript
JS控制弹出悬浮窗口(一览画面)的实例代码
2016/05/30 Javascript
Vue.js原理分析之observer模块详解
2017/02/17 Javascript
nodejs入门教程五:连接数据库的方法分析
2017/04/24 NodeJs
gulp解决跨域的配置文件问题
2017/06/08 Javascript
JavaScript中的一些隐式转换和总结(推荐)
2017/12/22 Javascript
Vue 表情包输入组件的实现代码
2019/01/21 Javascript
在Layui 的表格模板中,实现layer父页面和子页面传值交互的方法
2019/09/10 Javascript
详解JavaScript数据类型和判断方法
2020/09/04 Javascript
线程和进程的区别及Python代码实例
2015/02/04 Python
python网络编程之文件下载实例分析
2015/05/20 Python
Python基于回溯法子集树模板实现图的遍历功能示例
2017/09/05 Python
python实现基于信息增益的决策树归纳
2018/12/18 Python
我用Python抓取了7000 多本电子书案例详解
2019/03/25 Python
python 使用opencv 把视频分割成图片示例
2019/12/12 Python
django 外键创建注意事项说明
2020/05/20 Python
详解Selenium-webdriver绕开反爬虫机制的4种方法
2020/10/28 Python
上海天奕面试题笔试题
2015/04/19 面试题
MySQL面试题目集锦
2016/04/14 面试题
公司同意接收函
2014/01/13 职场文书
爱心捐助倡议书
2014/05/19 职场文书
个人师德师风自我剖析材料
2014/09/29 职场文书
预备党员思想汇报1000字
2014/10/07 职场文书
2014年护理部工作总结
2014/11/14 职场文书
Python-OpenCV教程之图像的位运算详解
2021/06/21 Python
Python实现简单得递归下降Parser
2022/05/02 Python