pytorch实现用Resnet提取特征并保存为txt文件的方法


Posted in Python onAugust 20, 2019

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-
 
import os.path
 
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 
 
import numpy as np
from PIL import Image 
 
features_dir = './features'
 
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
 
 
transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)
 
img = Image.open(img_path)
img1 = transform1(img)
 
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
 
for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True) 
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
 
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image 
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
 
 
def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )
  
  img = Image.open(img_path)
  img = transform(img)
  
 
 
  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')
  
if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    
  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))
    
  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  
    
  use_gpu = torch.cuda.is_available()
 
  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现抓取网页并且解析的实例
Sep 20 Python
python实现数独算法实例
Jun 09 Python
基于Python实现一个简单的银行转账操作
Mar 06 Python
python Django框架实现自定义表单提交
Mar 25 Python
python处理html转义字符的方法详解
Jul 01 Python
Python实现的井字棋(Tic Tac Toe)游戏示例
Jan 31 Python
Python实现定时执行任务的三种方式简单示例
Mar 30 Python
利用Python如何实时检测自身内存占用
May 09 Python
使用keras内置的模型进行图片预测实例
Jun 17 Python
python和js交互调用的方法
Jun 23 Python
Python中使用aiohttp模拟服务器出现错误问题及解决方法
Oct 31 Python
Python中的np.argmin()和np.argmax()函数用法
Jun 02 Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
Pytorch 抽取vgg各层并进行定制化处理的方法
Aug 20 #Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
You might like
Sony CFR 320 修复改造
2020/03/14 无线电
php+iframe实现隐藏无刷新上传文件
2012/02/10 PHP
PHP与MongoDB简介|安全|M+PHP应用实例详解
2013/06/17 PHP
PHP实现把文本中的URL转换为链接的auolink()函数分享
2014/07/29 PHP
PHP连接SQL Server的方法分析【基于thinkPHP5.1框架】
2019/05/06 PHP
提高代码性能技巧谈—以创建千行表格为例
2006/07/01 Javascript
浅谈javascript中自定义模版
2015/01/29 Javascript
JavaScript+html5 canvas实现图片破碎重组动画特效
2016/02/22 Javascript
javascript实现不同颜色Tab标签切换效果
2016/04/27 Javascript
微信开发 js实现tabs选项卡效果
2016/10/28 Javascript
jQuery 移动端拖拽(模块化开发,触摸事件,webpack)
2016/10/28 Javascript
使用Node.js实现简易MVC框架的方法
2017/08/07 Javascript
基于cropper.js封装vue实现在线图片裁剪组件功能
2018/03/01 Javascript
JS中原始值和引用值的储存方式示例详解
2018/03/23 Javascript
详解Vue.js中.native修饰符
2018/04/24 Javascript
koa2实现登录注册功能的示例代码
2018/12/03 Javascript
解决vue v-for src 图片路径问题 404
2019/11/12 Javascript
基于原生js实现九宫格算法代码实例
2020/07/03 Javascript
DataFrame中的object转换成float的方法
2018/04/10 Python
python 常见字符串与函数的用法详解
2018/11/23 Python
python语言元素知识点详解
2019/05/15 Python
解决django后台样式丢失,css资源加载失败的问题
2019/06/11 Python
python实现超市商品销售管理系统
2019/11/22 Python
Tensorflow 卷积的梯度反向传播过程
2020/02/10 Python
Python Django view 两种return的实现方式
2020/03/16 Python
python 代码运行时间获取方式详解
2020/09/18 Python
移动端Web页面的CSS3 flex布局快速上手指南
2016/05/31 HTML / CSS
html5各种页面切换效果和模态对话框用法总结
2014/12/15 HTML / CSS
室内设计实习自我鉴定
2013/09/25 职场文书
大学生职业生涯规划范文
2014/01/08 职场文书
红色故事演讲稿
2014/05/22 职场文书
同志主要表现材料
2014/08/21 职场文书
《实心球》教学反思
2016/02/23 职场文书
python如何读取.mtx文件
2021/04/22 Python
详解MySQL主从复制及读写分离
2021/05/07 MySQL
聊聊Lombok中的@Builder注解使用教程
2021/11/17 Java/Android