pytorch实现用Resnet提取特征并保存为txt文件的方法


Posted in Python onAugust 20, 2019

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-
 
import os.path
 
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 
 
import numpy as np
from PIL import Image 
 
features_dir = './features'
 
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
 
 
transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)
 
img = Image.open(img_path)
img1 = transform1(img)
 
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
 
for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True) 
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
 
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image 
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
 
 
def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )
  
  img = Image.open(img_path)
  img = transform(img)
  
 
 
  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')
  
if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    
  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))
    
  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  
    
  use_gpu = torch.cuda.is_available()
 
  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取远程文件大小的函数代码分享
May 13 Python
Python contextlib模块使用示例
Feb 18 Python
举例介绍Python中的25个隐藏特性
Mar 30 Python
Python调用命令行进度条的方法
May 05 Python
Python中的getopt函数使用详解
Jul 28 Python
tensorflow: variable的值与variable.read_value()的值区别详解
Jul 30 Python
pandas 把数据写入txt文件每行固定写入一定数量的值方法
Dec 28 Python
python3实现猜数字游戏
Dec 07 Python
Python中栈、队列与优先级队列的实现方法
Jun 30 Python
pytorch 实现张量tensor,图片,CPU,GPU,数组等的转换
Jan 13 Python
python对Excel的读取的示例代码
Feb 14 Python
python 实现表情识别
Nov 21 Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
Pytorch 抽取vgg各层并进行定制化处理的方法
Aug 20 #Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
You might like
Smarty模板快速入门
2007/01/04 PHP
php 正则 过滤html 的超链接
2009/06/02 PHP
利用php生成验证码
2017/02/23 PHP
JS实现点击图片在当前页面放大并可关闭的漂亮效果
2013/10/18 Javascript
node.js适合游戏后台开发吗?
2014/09/03 Javascript
如何在MVC应用程序中使用Jquery
2014/11/17 Javascript
DWR3 访问WEB元素的两种方法实例详解
2017/01/03 Javascript
完美解决spring websocket自动断开连接再创建引发的问题
2017/03/02 Javascript
js实现悬浮窗效果(支持拖动)
2017/03/09 Javascript
ReactNative之键盘Keyboard的弹出与消失示例
2017/07/11 Javascript
JavaScript使用Ajax上传文件的示例代码
2017/08/10 Javascript
Vue ElementUi同时校验多个表单(巧用new promise)
2018/06/06 Javascript
bootstrap实现点击删除按钮弹出确认框的实例代码
2018/08/16 Javascript
Node.js EventEmmitter事件监听器用法实例分析
2019/01/07 Javascript
vue cli 3.0 搭建项目的图文教程
2019/05/17 Javascript
用python删除java文件头上版权信息的方法
2014/07/31 Python
python中执行shell命令的几个方法小结
2014/09/18 Python
Python基于matplotlib绘制栈式直方图的方法示例
2017/08/09 Python
python实现按长宽比缩放图片
2018/06/07 Python
Python字典对象实现原理详解
2019/07/01 Python
如何利用Python开发一个简单的猜数字游戏
2019/09/22 Python
Python实现ATM系统
2020/02/17 Python
使用Python和百度语音识别生成视频字幕的实现
2020/04/09 Python
python中requests模拟登录的三种方式(携带cookie/session进行请求网站)
2020/11/17 Python
彻底弄明白CSS3的Media Queries(跨平台设计)
2010/07/27 HTML / CSS
Stefania Mode美国:奢华设计师和时尚服装
2018/01/07 全球购物
美国家居装饰网上商店:Lulu & Georgia
2019/09/14 全球购物
《得道多助,失道寡助》教学反思
2014/04/19 职场文书
群众路线领导对照材料
2014/08/23 职场文书
防灾减灾日活动总结
2014/08/26 职场文书
党员志愿者活动方案
2014/08/28 职场文书
停电调休通知
2015/04/16 职场文书
高一语文教学反思
2016/02/16 职场文书
纯CSS如何禁止用户复制网页的内容
2021/11/01 HTML / CSS
JS数组去重详情
2021/11/07 Javascript
基于docker安装zabbix的详细教程
2022/06/05 Servers