pytorch实现用Resnet提取特征并保存为txt文件的方法


Posted in Python onAugust 20, 2019

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-
 
import os.path
 
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 
 
import numpy as np
from PIL import Image 
 
features_dir = './features'
 
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
 
 
transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)
 
img = Image.open(img_path)
img1 = transform1(img)
 
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
 
for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True) 
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
 
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image 
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
 
 
def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )
  
  img = Image.open(img_path)
  img = transform(img)
  
 
 
  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')
  
if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    
  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))
    
  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  
    
  use_gpu = torch.cuda.is_available()
 
  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多线程编程(七):使用Condition实现复杂同步
Apr 05 Python
【Python】Python的urllib模块、urllib2模块批量进行网页下载文件
Nov 19 Python
python入门前的第一课 python怎样入门
Mar 06 Python
python实现图书管理系统
Mar 12 Python
python针对excel的操作技巧
Mar 13 Python
Python3.4学习笔记之列表、数组操作示例
Mar 01 Python
python实现对输入的密文加密
Mar 20 Python
python set内置函数的具体使用
Jul 02 Python
TensorFlow实现自定义Op方式
Feb 04 Python
keras 自定义loss层+接受输入实例
Jun 28 Python
python可视化 matplotlib画图使用colorbar工具自定义颜色
Dec 07 Python
如何基于python实现单目三维重建详解
Jun 25 Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
Pytorch 抽取vgg各层并进行定制化处理的方法
Aug 20 #Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
You might like
PHP编码规范之注释和文件结构说明
2010/07/09 PHP
php设计模式之备忘模式分析【星际争霸游戏案例】
2020/03/24 PHP
JavaScript去掉空格的方法集合
2010/12/28 Javascript
多个datatable共存造成多个表格的checkbox都被选中
2013/07/11 Javascript
优化Jquery,提升网页加载速度
2013/11/14 Javascript
javascript将浮点数转换成整数的三个方法
2014/06/23 Javascript
JavaScript中合并数组的N种方法
2014/09/16 Javascript
JavaScript中的ArrayBuffer详细介绍
2014/12/08 Javascript
使用mouse事件实现简单的鼠标经过特效
2015/01/30 Javascript
Javascript实现div层渐隐效果的方法
2015/05/30 Javascript
Javascript基础_嵌入图像的简单实现
2016/06/14 Javascript
深入理解js generator数据类型
2016/08/16 Javascript
微信小程序开发的四十个技术窍门总结(推荐)
2017/01/23 Javascript
详解VueJs异步动态加载块
2017/03/09 Javascript
js获取ip和地区
2017/03/10 Javascript
使用bootstrap插件实现模态框效果
2017/05/10 Javascript
NodeJS实现微信公众号关注后自动回复功能
2017/05/31 NodeJs
echarts饼图扇区添加点击事件的实例
2017/10/16 Javascript
[00:36]TI7不朽珍藏III——斯温不朽展示
2017/07/15 DOTA
[43:03]LGD vs Newbee 2019国际邀请赛小组赛 BO2 第一场 8.16
2019/08/19 DOTA
Python实现批量检测HTTP服务的状态
2016/10/27 Python
Flask解决跨域的问题示例代码
2018/02/12 Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
2020/01/02 Python
Python 解决相对路径问题:"No such file or directory"
2020/06/05 Python
css3中检验表单的required,focus,valid和invalid样式
2014/02/21 HTML / CSS
Java多态性的定义以及类型
2014/09/16 面试题
工作会议欢迎词
2014/01/16 职场文书
企业总经理岗位职责
2014/02/13 职场文书
十佳护士先进事迹
2014/05/08 职场文书
无财产离婚协议书范本
2014/10/28 职场文书
教师党的群众路线教育实践活动学习笔记
2014/11/05 职场文书
初中生300字旷课检讨书
2014/11/19 职场文书
教师求职信怎么写
2015/03/20 职场文书
运动会开幕式致辞
2015/07/29 职场文书
Python实现数据的序列化操作详解
2022/07/07 Python
vue实现input输入模糊查询的三种方式
2022/08/14 Vue.js