pytorch实现用Resnet提取特征并保存为txt文件的方法


Posted in Python onAugust 20, 2019

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-
 
import os.path
 
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 
 
import numpy as np
from PIL import Image 
 
features_dir = './features'
 
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
 
 
transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)
 
img = Image.open(img_path)
img1 = transform1(img)
 
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
 
for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True) 
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
 
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image 
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
 
 
def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )
  
  img = Image.open(img_path)
  img = transform(img)
  
 
 
  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')
  
if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    
  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))
    
  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  
    
  use_gpu = torch.cuda.is_available()
 
  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的ORM框架SQLAlchemy入门教程
Apr 28 Python
python非递归全排列实现方法
Apr 10 Python
使用python 爬虫抓站的一些技巧总结
Jan 10 Python
python 读取视频,处理后,实时计算帧数fps的方法
Jul 10 Python
Python玩转加密的技巧【推荐】
May 13 Python
Python连接字符串过程详解
Jan 06 Python
利用Python制作动态排名图的实现代码
Apr 09 Python
python和c语言哪个更适合初学者
Jun 22 Python
python之语音识别speech模块
Sep 09 Python
基于python爬取链家二手房信息代码示例
Oct 21 Python
matplotlib之属性组合包(cycler)的使用
Feb 24 Python
pycharm无法安装cv2模块问题
May 20 Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
Pytorch 抽取vgg各层并进行定制化处理的方法
Aug 20 #Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
You might like
PHP截取中文字符串的问题
2006/07/12 PHP
解析php中mysql_connect与mysql_pconncet的区别详解
2013/05/15 PHP
php jq jquery getJSON跨域提交数据完整版
2013/09/13 PHP
学习面向对象之面向对象的基本概念:对象和其他基本要素
2010/11/30 Javascript
关于javascript function对象那些迷惑分析
2011/10/24 Javascript
5个javascript的数字格式化函数分享
2011/12/07 Javascript
js优化针对IE6.0起作用(详细整理)
2012/12/25 Javascript
jquery动态改变onclick属性导致失效的问题解决方法
2013/12/04 Javascript
jQuery中验证表单提交方式及序列化表单内容的实现
2014/01/06 Javascript
jQuery与getJson结合的用法实例
2015/08/07 Javascript
smartcrop.js智能图片裁剪库
2015/10/14 Javascript
基于jQuery的checkbox全选问题分析
2016/11/18 Javascript
Angular 4依赖注入学习教程之简介(一)
2017/06/04 Javascript
微信小程序 Buffer缓冲区的详解
2017/07/06 Javascript
javascript 日期相减-在线教程(附代码)
2017/08/17 Javascript
vue实现样式之间的切换及vue动态样式的实现方法
2017/12/19 Javascript
bootstrap+jquery项目引入文件报错的解决方法
2018/01/22 jQuery
vue路由跳转传递参数的方式总结
2020/05/10 Javascript
Python版的文曲星猜数字游戏代码
2013/09/02 Python
在Python的Django框架中使用通用视图的方法
2015/07/21 Python
Python实现删除文件中含“指定内容”的行示例
2017/06/09 Python
python爬取拉勾网职位数据的方法
2018/01/24 Python
基于Python的ModbusTCP客户端实现详解
2019/07/13 Python
Python OpenCV实现测量图片物体宽度
2020/05/27 Python
Django实现前台上传并显示图片功能
2020/05/29 Python
详解Python设计模式之策略模式
2020/06/15 Python
天巡全球:Skyscanner Global
2017/06/20 全球购物
智能家居、吸尘器、滑板车、电动自行车网上购物:Geekmaxi
2021/01/18 全球购物
2013年大学生的自我鉴定
2013/10/24 职场文书
工程项目经理岗位职责
2013/12/15 职场文书
竞聘演讲稿
2014/04/24 职场文书
暑期政治学习心得体会
2014/09/02 职场文书
个人房屋买卖协议书(范本)
2014/10/04 职场文书
创业计划书之零食店(进口)
2019/09/24 职场文书
Python爬虫入门案例之爬取二手房源数据
2021/10/16 Python
Java9新特性对HTTP2协议支持与非阻塞HTTP API
2022/03/16 Java/Android