pytorch实现用Resnet提取特征并保存为txt文件的方法


Posted in Python onAugust 20, 2019

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-
 
import os.path
 
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 
 
import numpy as np
from PIL import Image 
 
features_dir = './features'
 
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
 
 
transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)
 
img = Image.open(img_path)
img1 = transform1(img)
 
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
 
for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True) 
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
 
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image 
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
 
 
def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )
  
  img = Image.open(img_path)
  img = transform(img)
  
 
 
  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')
  
if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    
  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))
    
  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  
    
  use_gpu = torch.cuda.is_available()
 
  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中关于字符串对象的一些基础知识
Apr 08 Python
解析Python中的异常处理
Apr 28 Python
详解C++编程中一元运算符的重载
Jan 19 Python
Python IDLE 错误:IDLE''s subprocess didn''t make connection 的解决方案
Feb 13 Python
Python切片操作深入详解
Jul 27 Python
python tkinter组件使用详解
Sep 16 Python
Pytorch 多维数组运算过程的索引处理方式
Dec 27 Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 Python
python和JavaScript哪个容易上手
Jun 23 Python
appium+python自动化配置(adk、jdk、node.js)
Nov 17 Python
基于Django集成CAS实现流程详解
Nov 28 Python
python3代码中实现加法重载的实例
Dec 03 Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
Pytorch 抽取vgg各层并进行定制化处理的方法
Aug 20 #Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
You might like
DC动漫人物排行
2020/03/03 欧美动漫
PHP取进制余数函数代码
2012/01/19 PHP
php遍历文件夹下的所有文件和子文件夹示例
2014/03/20 PHP
10个简化PHP开发的工具
2014/12/25 PHP
PHP定时执行任务实现方法详解(Timer)
2015/07/30 PHP
适用于初学者的简易PHP文件上传类
2015/10/29 PHP
PHP实现绘制二叉树图形显示功能详解【包括二叉搜索树、平衡树及红黑树】
2017/11/16 PHP
ThinkPHP框架实现定时执行任务的两种方法分析
2018/09/04 PHP
PHP迭代器和生成器用法实例分析
2019/09/28 PHP
PHP读取文件,解决中文乱码UTF-8的方法分析
2020/01/22 PHP
Javascript技术技巧大全(五)
2007/01/22 Javascript
屏蔽F1~F12的快捷键的js函数
2010/05/06 Javascript
推荐JavaScript实现继承的最佳方式
2014/11/11 Javascript
使用jQuery Mobile框架开发移动端Web App的入门教程
2016/05/17 Javascript
JavaScript的Ext JS框架中的GridPanel组件使用指南
2016/05/21 Javascript
jQuery+css实现非常漂亮的水平导航菜单效果
2016/07/27 Javascript
深究AngularJS中$sce的使用
2017/06/12 Javascript
vue-awesome-swiper滑块插件使用方法详解
2017/11/27 Javascript
Vue-router 切换组件页面时进入进出动画方法
2018/09/01 Javascript
JS秒杀倒计时功能完整实例【使用jQuery3.1.1】
2019/09/03 jQuery
ng-alain的sf如何自定义部件的流程
2020/06/12 Javascript
详解Vue.js3.0 组件是如何渲染为DOM的
2020/11/10 Javascript
Python黑帽编程 3.4 跨越VLAN详解
2016/09/28 Python
Python实现FTP上传文件或文件夹实例(递归)
2017/01/16 Python
Python numpy 提取矩阵的某一行或某一列的实例
2018/04/03 Python
Django中create和save方法的不同
2019/08/13 Python
python requests证书问题解决
2019/09/05 Python
Django数据统计功能count()的使用
2020/11/30 Python
柯基袜:Corgi Socks
2017/01/26 全球购物
知识竞赛活动方案
2014/02/18 职场文书
党员评议表自我评价范文
2014/10/20 职场文书
在职证明格式样本
2015/06/15 职场文书
教你怎么用PyCharm为同一服务器配置多个python解释器
2021/05/31 Python
Python scrapy爬取起点中文网小说榜单
2021/06/13 Python
Django+Celery实现定时任务的示例
2021/06/23 Python
SQL Server中搜索特定的对象
2022/05/25 SQL Server