pytorch实现用Resnet提取特征并保存为txt文件的方法


Posted in Python onAugust 20, 2019

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-
 
import os.path
 
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 
 
import numpy as np
from PIL import Image 
 
features_dir = './features'
 
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
 
 
transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)
 
img = Image.open(img_path)
img1 = transform1(img)
 
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
 
for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True) 
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
 
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image 
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
 
 
def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )
  
  img = Image.open(img_path)
  img = transform(img)
  
 
 
  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')
  
if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    
  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))
    
  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  
    
  use_gpu = torch.cuda.is_available()
 
  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常见文件操作的函数示例代码
Nov 15 Python
django admin添加数据自动记录user到表中的实现方法
Jan 05 Python
对Python中创建进程的两种方式以及进程池详解
Jan 14 Python
详解Pandas之容易让人混淆的行选择和列选择
Jul 10 Python
python匿名函数用法实例分析
Aug 03 Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 Python
python日期与时间戳的各种转换示例
Feb 12 Python
Python Handler处理器和自定义Opener原理详解
Mar 05 Python
Python文件时间操作步骤代码详解
Apr 13 Python
anaconda3安装及jupyter环境配置全教程
Aug 24 Python
python 制作简单的音乐播放器
Nov 25 Python
python的列表生成式,生成器和generator对象你了解吗
Mar 16 Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
Pytorch 抽取vgg各层并进行定制化处理的方法
Aug 20 #Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
You might like
用PHP将数据导入到Foxmail
2006/10/09 PHP
PHP网上调查系统
2006/10/09 PHP
用PHP和ACCESS写聊天室(十)
2006/10/09 PHP
php 缓存函数代码
2008/08/27 PHP
PHP教程 变量定义
2009/10/23 PHP
PHP设计模式之外观模式(Facade)入门与应用详解
2019/12/13 PHP
用JavaScript编写COM组件的步骤
2009/03/17 Javascript
从父页面读取和操作iframe中内容方法
2009/07/25 Javascript
通过继承IHttpHandle实现JS插件的组织与管理
2010/07/13 Javascript
深入理解Javascript闭包 新手版
2010/12/28 Javascript
jquery数组之存放checkbox全选值示例代码
2013/12/20 Javascript
关闭页面window.location事件未执行的原因及解决方法
2014/09/01 Javascript
js遮罩效果制作弹出注册界面效果
2017/01/25 Javascript
ES6入门教程之Iterator与for...of循环详解
2017/05/17 Javascript
JS实现页面内跳转的简单代码
2017/09/03 Javascript
在AngularJs中设置请求头信息(headers)的方法及不同方法的比较
2018/09/04 Javascript
vue、react等单页面项目部署到服务器的方法及vue和react的区别
2018/09/29 Javascript
微信网页登录逻辑与实现方法
2019/04/29 Javascript
浅谈Vuex注入Vue生命周期的过程
2019/05/20 Javascript
详谈Vue.js框架下main.js,App.vue,page/index.vue之间的区别
2020/08/12 Javascript
[01:11]回顾历届DOTA2国际邀请赛中国区预选赛
2017/06/26 DOTA
python的dict,set,list,tuple应用详解
2014/07/24 Python
详解Python3操作Mongodb简明易懂教程
2017/05/25 Python
Python3中列表list合并的四种方法
2019/04/19 Python
Python 3 判断2个字典相同
2019/08/06 Python
TensorFlow tf.nn.conv2d实现卷积的方式
2020/01/03 Python
Python实现一个优先级队列的方法
2020/07/31 Python
在html页面中取得session中的值的方法
2020/08/11 HTML / CSS
Abe’s of Maine:自1979以来销售相机和电子产品
2016/11/21 全球购物
彪马俄罗斯官网:PUMA俄罗斯
2019/07/13 全球购物
如何写好升职自荐信
2014/01/06 职场文书
化学教学随笔感言
2014/02/19 职场文书
劲霸男装广告词
2014/03/21 职场文书
党员作风建设整改方案
2014/10/27 职场文书
2015年宣传思想工作总结
2015/05/22 职场文书
靠谱的活动总结
2019/04/16 职场文书