Python Convolutional Neural Network Image Classification Framework Explained


Posted in Python on November 07, 2021

[AI Project] A CNN Image Classification Framework



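This hands-on post shares the image classification work I did at the time. It is mainly an organized image classification framework: to switch to a different model or add a new one, you only need to change the config. Let's get started!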


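Overall structure

The project is organized into the following parts, each described below: config, data, dataloader, factory, models, and the training script train.py.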

config

config.py in the config folder mainly defines the dataset locations, the number of training epochs, batch_size, and the model used for this run.

# Paths to the training images, the training annotations, and the test images
train_data_path = "./data/train/"
train_anno_path = "./data/train.csv"
test_data_path = "./data/test/"
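# Number of worker processes used by the DataLoader
num_workers = 8
# Batch size
batch_size = 8

# Number of training epochs
epochs = 20
# Number of folds for k-fold cross-validation
k = 5
# Model selection; supported names: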
# inception_v3_google inceptionv4
# vgg16
# resnet50 resnet101 resnet152 resnext50_32x4d resnext101_32x8d wide_resnet50_2  wide_resnet101_2
# senet154 se_resnet50 se_resnet101  se_resnet152  se_resnext50_32x4d  se_resnext101_32x4d
# nasnetalarge  pnasnet5large
# densenet121 densenet161 densenet169 densenet201
# efficientnet-b0 efficientnet-b1 efficientnet-b2 efficientnet-b3 efficientnet-b4 efficientnet-b5 efficientnet-b6 efficientnet-b7
# xception
# squeezenet1_0 squeezenet1_1
# mobilenet_v2
# mnasnet0_5 mnasnet0_75 mnasnet1_0 mnasnet1_3
# shufflenet_v2_x0_5 shufflenet_v2_x1_0
model_name = "vgg16"

# Number of classes
num_classes = 102
# Input image size
img_width = 320
img_height = 320

data

The data folder holds the train and test images.



train.csv stores each image's file name and its corresponding label.
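The annotation file is a plain two-column CSV. As a minimal sketch, assuming each row is `image_name,label` with no quoting, it can also be read with Python's csv module. `read_annotations` and its `has_header` flag are hypothetical names used only for illustration; the framework's own reader is get_anno in dataloader/data.py.

```python
import csv
import os
import tempfile


def read_annotations(csv_path, images_dir, has_header=False):
    """Return a list of (image_path, label) pairs from a two-column CSV.

    Hypothetical helper: assumes every row is `image_name,label`;
    `has_header` says whether the first row is a column header to skip.
    """
    pairs = []
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        if has_header:
            next(reader, None)  # skip the header row
        for row in reader:
            if not row:  # csv.reader yields [] for blank lines
                continue
            name, label = row[0].strip(), int(row[1])
            pairs.append((os.path.join(images_dir, name), label))
    return pairs


if __name__ == "__main__":
    # Write a tiny example file and read it back
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "train.csv")
        with open(path, "w") as f:
            f.write("0001.jpg,3\n0002.jpg,17\n")
        print(read_annotations(path, "./data/train/"))
```

Using the csv module instead of a bare `split(',')` also tolerates blank lines and quoted fields, at the cost of nothing extra.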


dataloader

The dataloader package mainly contains data.py and data_augmentation.py: the first reads the data, the second implements the data augmentation operations.

# ---- dataloader/data.py ----
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
import numpy as np
import PIL
from torchvision import transforms
from config import config
import os
import cv2
# Dataset and transform definitions


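# Read train.csv (one "image_name,label" per line) into a NumPy array of (image path, label) rows.
# np.array stores the mixed tuples as strings, so the Dataset casts labels back with int().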
def get_anno(path, images_path):
    data = []
    with open(path) as f:
        for line in f:
            idx, label = line.strip().split(',')
            data.append((os.path.join(images_path, idx), int(label)))
    return np.array(data)

# Training/validation Dataset built from rows of the annotation array:
# index idx selects the image_path and label to load
class trainDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        img = Image.open(img_path).convert('RGB')
        #img = cv2.imread(img_path)
        #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            img = self.transform(img)
        return img, int(label)

    def __len__(self):
        return len(self.data)



# Test Dataset: reads test images directly from a list of file paths (no labels)
class testDataset(Dataset):
    def __init__(self, img_path, transform=None):
        self.img_path = img_path
        self.transform = transform

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')
        # img = cv2.imread(self.img_path[index])
        # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.img_path)


# train_transform = transforms.Compose([
#     transforms.Resize([config.img_width, config.img_height]),
#     transforms.RandomRotation(10),
#     transforms.ColorJitter(brightness=0.3, contrast=0.2),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

train_transform = transforms.Compose([
    transforms.Pad(4, padding_mode='reflect'),
    transforms.RandomRotation(10),
    # torchvision sizes are given as (height, width)
    transforms.RandomResizedCrop([config.img_height, config.img_width]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Validation and test use a deterministic resize instead of a random crop,
# so that evaluation results are reproducible
val_transform = transforms.Compose([
    transforms.Resize([config.img_height, config.img_width]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize([config.img_height, config.img_width]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
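All of these pipelines end with ToTensor followed by Normalize with the ImageNet mean and std. As a NumPy sketch of that arithmetic (not torchvision's implementation), assuming an RGB uint8 image of shape (H, W, 3); `to_tensor_and_normalize` is a hypothetical name:

```python
import numpy as np

# ImageNet statistics used by the transforms above (RGB order)
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_tensor_and_normalize(img_hwc_uint8):
    """Approximate ToTensor() + Normalize() for a (H, W, 3) uint8 RGB array.

    Returns a (3, H, W) float32 array.
    """
    x = img_hwc_uint8.astype(np.float32) / 255.0  # ToTensor: [0, 255] -> [0.0, 1.0]
    x = x.transpose(2, 0, 1)                      # ToTensor: HWC -> CHW
    # Normalize: per-channel (x - mean) / std, broadcasting over H and W
    return (x - MEAN[:, None, None]) / STD[:, None, None]


if __name__ == "__main__":
    white = np.full((2, 2, 3), 255, dtype=np.uint8)
    out = to_tensor_and_normalize(white)
    print(out.shape, out[:, 0, 0])
```

The custom Normalize class in data_augmentation.py performs the same transpose, division by 255, and per-channel standardization on OpenCV arrays.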
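# ---- dataloader/data_augmentation.py ----
from __future__ import division  # must be the first statement in the module
import math

import cv2
import numpy as np
# Note: `random` in this file is numpy.random, whose randint(low, high) excludes `high`
from numpy import random
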
# Rotate a PIL image by an angle chosen at random from a fixed list
class FixedRotation(object):
    def __init__(self, angles):
        self.angles = angles

    def __call__(self, img):
        return fixed_rotate(img, self.angles)


def fixed_rotate(img, angles):
    angles = list(angles)
    angles_num = len(angles)
    index = random.randint(0, angles_num)  # numpy randint excludes the upper bound, so every angle can be chosen
    return img.rotate(angles[index])



__all__ = ['Compose','RandomHflip', 'RandomUpperCrop', 'Resize', 'UpperCrop', 'RandomBottomCrop',"RandomErasing",
           'BottomCrop', 'Normalize', 'RandomSwapChannels', 'RandomRotate', 'RandomHShift',"CenterCrop","RandomVflip",
           'ExpandBorder', 'RandomResizedCrop','RandomDownCrop', 'DownCrop', 'ResizedCrop',"FixRandomRotate"]

def rotate_nobound(image, angle, center=None, scale=1.):
    (h, w) = image.shape[:2]


    # if the center is None, initialize it as the center of
    # the image
    if center is None:
        center = (w // 2, h // 2)

    # perform the rotation
    M = cv2.getRotationMatrix2D(center, angle, scale)
    rotated = cv2.warpAffine(image, M, (w, h))

    return rotated

def scale_down(src_size, size):
    w, h = size
    sw, sh = src_size
    if sh < h:
        w, h = float(w * sh) / h, sh
    if sw < w:
        w, h = sw, float(h * sw) / w
    return int(w), int(h)


def fixed_crop(src, x0, y0, w, h, size=None):
    out = src[y0:y0 + h, x0:x0 + w]
    if size is not None and (w, h) != size:
        out = cv2.resize(out, (size[0], size[1]), interpolation=cv2.INTER_CUBIC)
    return out

class FixRandomRotate(object):
    def __init__(self, angles=(0, 90, 180, 270), bound=False):
        self.angles = angles
        self.bound = bound

    def __call__(self, img):
        # pick any of the configured angles (numpy randint excludes the upper bound)
        angle = self.angles[random.randint(0, len(self.angles))]
        if self.bound:
            img = rotate_bound(img, angle)
        else:
            img = rotate_nobound(img, angle)
        return img

def center_crop(src, size):
    h, w = src.shape[0:2]
    new_w, new_h = scale_down((w, h), size)

    x0 = int((w - new_w) / 2)
    y0 = int((h - new_h) / 2)

    out = fixed_crop(src, x0, y0, new_w, new_h, size)
    return out


def bottom_crop(src, size):
    h, w = src.shape[0:2]
    new_w, new_h = scale_down((w, h), size)

    x0 = int((w - new_w) / 2)
    y0 = int((h - new_h) * 0.75)

    out = fixed_crop(src, x0, y0, new_w, new_h, size)
    return out

def rotate_bound(image, angle):
    # grab the dimensions of the image and then determine the
    # center
    h, w = image.shape[:2]

    (cX, cY) = (w // 2, h // 2)

    M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
    cos = np.abs(M[0, 0])
    sin = np.abs(M[0, 1])

    # compute the new bounding dimensions of the image
    nW = int((h * sin) + (w * cos))
    nH = int((h * cos) + (w * sin))

    # adjust the rotation matrix to take into account translation
    M[0, 2] += (nW / 2) - cX
    M[1, 2] += (nH / 2) - cY

    rotated = cv2.warpAffine(image, M, (nW, nH))

    return rotated


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms
    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class RandomRotate(object):
    def __init__(self, angles, bound=False):
        self.angles = angles
        self.bound = bound

    def __call__(self,img):
        do_rotate = random.randint(0, 2)
        if do_rotate:
            angle = np.random.uniform(self.angles[0], self.angles[1])
            if self.bound:
                img = rotate_bound(img, angle)
            else:
                img = rotate_nobound(img, angle)
        return img


class RandomBrightness(object):
    def __init__(self, delta=10):
        assert delta >= 0
        assert delta <= 255
        self.delta = delta

    def __call__(self, image):
        if random.randint(2):
            delta = random.uniform(-self.delta, self.delta)
            image = (image + delta).clip(0.0, 255.0)
            # print('RandomBrightness,delta ',delta)
        return image


class RandomContrast(object):
    def __init__(self, lower=0.9, upper=1.05):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "contrast upper must be >= lower."
        assert self.lower >= 0, "contrast lower must be non-negative."

    # expects float image
    def __call__(self, image):
        if random.randint(2):
            alpha = random.uniform(self.lower, self.upper)
            # print('contrast:', alpha)
            image = (image * alpha).clip(0.0,255.0)
        return image


class RandomSaturation(object):
    def __init__(self, lower=0.8, upper=1.2):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "saturation upper must be >= lower."
        assert self.lower >= 0, "saturation lower must be non-negative."

    def __call__(self, image):
        if random.randint(2):
            alpha = random.uniform(self.lower, self.upper)
            image[:, :, 1] *= alpha
            # print('RandomSaturation,alpha',alpha)
        return image


class RandomHue(object):
    def __init__(self, delta=18.0):
        assert delta >= 0.0 and delta <= 360.0
        self.delta = delta

    def __call__(self, image):
        if random.randint(2):
            alpha = random.uniform(-self.delta, self.delta)
            image[:, :, 0] += alpha
            image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0
            image[:, :, 0][image[:, :, 0] < 0.0] += 360.0
            # print('RandomHue,alpha:', alpha)
        return image


class ConvertColor(object):
    def __init__(self, current='BGR', transform='HSV'):
        self.transform = transform
        self.current = current

    def __call__(self, image):
        if self.current == 'BGR' and self.transform == 'HSV':
            image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        elif self.current == 'HSV' and self.transform == 'BGR':
            image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
        else:
            raise NotImplementedError
        return image

class RandomSwapChannels(object):
    def __call__(self, img):
        if np.random.randint(2):
            order = np.random.permutation(3)
            return img[:,:,order]
        return img

class RandomCrop(object):
    def __init__(self, size):
        self.size = size
    def __call__(self, image):
        h, w, _ = image.shape
        new_w, new_h = scale_down((w, h), self.size)

        if w == new_w:
            x0 = 0
        else:
            x0 = random.randint(0, w - new_w)

        if h == new_h:
            y0 = 0
        else:
            y0 = random.randint(0, h - new_h)

        out = fixed_crop(image, x0, y0, new_w, new_h, self.size)
        return out



class RandomResizedCrop(object):
    def __init__(self, size,scale=(0.49, 1.0), ratio=(1., 1.)):
        self.size = size
        self.scale = scale
        self.ratio = ratio

    def __call__(self,img):
        if random.random() < 0.2:
            return cv2.resize(img,self.size)
        h, w, _ = img.shape
        area = h * w
        d=1
        for attempt in range(10):
            target_area = random.uniform(self.scale[0], self.scale[1]) * area
            aspect_ratio = random.uniform(self.ratio[0], self.ratio[1])


            new_w = int(round(math.sqrt(target_area * aspect_ratio)))
            new_h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                new_h, new_w = new_w, new_h

            if new_w < w and new_h < h:
                x0 = random.randint(0, w - new_w)
                y0 = (random.randint(0, h - new_h))//d
                out = fixed_crop(img, x0, y0, new_w, new_h, self.size)

                return out

        # Fallback
        return center_crop(img, self.size)


class DownCrop():
    def __init__(self, size,  select, scale=(0.36,0.81)):
        self.size = size
        self.scale = scale
        self.select = select

    def __call__(self, img, attr_idx):
        if attr_idx not in self.select:
            return img, attr_idx
        # use a local scale so attr_idx == 0 does not permanently overwrite self.scale
        scale = (0.64, 1.0) if attr_idx == 0 else self.scale
        h, w, _ = img.shape
        area = h * w

        s = (scale[0] + scale[1]) / 2.0

        target_area = s * area

        new_w = int(round(math.sqrt(target_area)))
        new_h = int(round(math.sqrt(target_area)))

        if new_w < w and new_h < h:
            dw = w-new_w
            x0 = int(0.5*dw)
            y0 = h-new_h
            out = fixed_crop(img, x0, y0, new_w, new_h, self.size)
            return out, attr_idx

        # Fallback
        return center_crop(img, self.size), attr_idx


class ResizedCrop(object):
    def __init__(self, size, select,scale=(0.64, 1.0), ratio=(3. / 4., 4. / 3.)):
        self.size = size
        self.scale = scale
        self.ratio = ratio
        self.select = select

    def __call__(self,img, attr_idx):
        if attr_idx not in self.select:
            return img, attr_idx
        h, w, _ = img.shape
        area = h * w
        d = 1
        # local copy: per-attribute overrides must not leak into later calls
        scale = self.scale
        if attr_idx == 2:
            scale = (0.36, 0.81)
            d = 2
        if attr_idx == 0:
            scale = (0.81, 1.0)

        target_area = (scale[0] + scale[1]) / 2.0 * area
        # aspect_ratio = random.uniform(self.ratio[0], self.ratio[1])


        new_w = int(round(math.sqrt(target_area)))
        new_h = int(round(math.sqrt(target_area)))

        # if random.random() < 0.5:
        #     new_h, new_w = new_w, new_h

        if new_w < w and new_h < h:
            x0 =  (w - new_w)//2
            y0 = (h - new_h)//d//2
            out = fixed_crop(img, x0, y0, new_w, new_h, self.size)
            # cv2.imshow('{}_img'.format(idx2attr_map[attr_idx]), img)
            # cv2.imshow('{}_crop'.format(idx2attr_map[attr_idx]), out)
            #
            # cv2.waitKey(0)
            return out, attr_idx

        # Fallback
        return center_crop(img, self.size), attr_idx

class RandomHflip(object):
    def __call__(self, image):
        if random.randint(2):
            return cv2.flip(image, 1)
        else:
            return image


class RandomVflip(object):
    def __call__(self, image):
        if random.randint(2):
            return cv2.flip(image, 0)
        else:
            return image


class Hflip(object):
    def __init__(self,doHflip):
        self.doHflip = doHflip

    def __call__(self, image):
        if self.doHflip:
            return cv2.flip(image, 1)
        else:
            return image


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        return center_crop(image, self.size)

class UpperCrop():
    def __init__(self, size, scale=(0.09, 0.64)):
        self.size = size
        self.scale = scale

    def __call__(self,img):
        h, w, _ = img.shape
        area = h * w

        s = (self.scale[0]+self.scale[1])/2.0

        target_area = s * area

        new_w = int(round(math.sqrt(target_area)))
        new_h = int(round(math.sqrt(target_area)))

        if new_w < w and new_h < h:
            dw = w-new_w
            x0 = int(0.5*dw)
            y0 = 0
            out = fixed_crop(img, x0, y0, new_w, new_h, self.size)
            return out

        # Fallback
        return center_crop(img, self.size)



class RandomUpperCrop(object):
    def __init__(self, size, select, scale=(0.09, 0.64), ratio=(3. / 4., 4. / 3.)):
        self.size = size
        self.scale = scale
        self.ratio = ratio
        self.select = select

    def __call__(self,img, attr_idx):
        if random.random() < 0.2:
            return img, attr_idx
        if attr_idx not in self.select:
            return img, attr_idx

        h, w, _ = img.shape
        area = h * w
        for attempt in range(10):
            s = random.uniform(self.scale[0], self.scale[1])
            d = 0.1 + (0.3 - 0.1) / (self.scale[1] - self.scale[0]) * (s - self.scale[0])
            target_area = s * area
            aspect_ratio = random.uniform(self.ratio[0], self.ratio[1])
            new_w = int(round(math.sqrt(target_area * aspect_ratio)))
            new_h = int(round(math.sqrt(target_area / aspect_ratio)))


            # new_w = int(round(math.sqrt(target_area)))
            # new_h = int(round(math.sqrt(target_area)))

            if new_w < w and new_h < h:
                dw = w-new_w
                x0 = random.randint(int((0.5-d)*dw), int((0.5+d)*dw)+1)
                y0 = (random.randint(0, h - new_h))//10
                out = fixed_crop(img, x0, y0, new_w, new_h, self.size)
                return out, attr_idx

        # Fallback
        return center_crop(img, self.size), attr_idx


class RandomDownCrop(object):
    def __init__(self, size, select, scale=(0.36, 0.81), ratio=(3. / 4., 4. / 3.)):
        self.size = size
        self.scale = scale
        self.ratio = ratio
        self.select = select

    def __call__(self,img, attr_idx):
        if random.random() < 0.2:
            return img, attr_idx
        if attr_idx not in self.select:
            return img, attr_idx
        # local scale so attr_idx == 0 does not permanently overwrite self.scale
        scale = (0.64, 1.0) if attr_idx == 0 else self.scale

        h, w, _ = img.shape
        area = h * w
        for attempt in range(10):
            s = random.uniform(scale[0], scale[1])
            d = 0.1 + (0.3 - 0.1) / (scale[1] - scale[0]) * (s - scale[0])
            target_area = s * area
            aspect_ratio = random.uniform(self.ratio[0], self.ratio[1])
            new_w = int(round(math.sqrt(target_area * aspect_ratio)))
            new_h = int(round(math.sqrt(target_area / aspect_ratio)))
            #
            # new_w = int(round(math.sqrt(target_area)))
            # new_h = int(round(math.sqrt(target_area)))

            if new_w < w and new_h < h:
                dw = w-new_w
                x0 = random.randint(int((0.5-d)*dw), int((0.5+d)*dw)+1)
                y0 = (random.randint((h - new_h)*9//10, h - new_h))
                out = fixed_crop(img, x0, y0, new_w, new_h, self.size)

                # cv2.imshow('{}_img'.format(idx2attr_map[attr_idx]), img)
                # cv2.imshow('{}_crop'.format(idx2attr_map[attr_idx]), out)
                #
                # cv2.waitKey(0)

                return out, attr_idx

        # Fallback
        return center_crop(img, self.size), attr_idx

class RandomHShift(object):
    def __init__(self, select, scale=(0.0, 0.2)):
        self.scale = scale
        self.select = select

    def __call__(self,img, attr_idx):
        if attr_idx not in self.select:
            return img, attr_idx
        do_shift_crop = random.randint(0, 2)
        if do_shift_crop:
            h, w, _ = img.shape
            min_shift = int(w*self.scale[0])
            max_shift = int(w*self.scale[1])
            shift_idx = random.randint(min_shift, max_shift)
            direction = random.randint(0,2)
            if direction:
                right_part = img[:, -shift_idx:, :]
                left_part = img[:, :-shift_idx, :]
            else:
                left_part = img[:, :shift_idx, :]
                right_part = img[:, shift_idx:, :]
            img = np.concatenate((right_part, left_part), axis=1)

        # Fallback
        return img, attr_idx


class RandomBottomCrop(object):
    def __init__(self, size, select, scale=(0.4, 0.8)):
        self.size = size
        self.scale = scale
        self.select = select

    def __call__(self,img, attr_idx):
        if attr_idx not in self.select:
            return img, attr_idx

        h, w, _ = img.shape
        area = h * w
        for attempt in range(10):
            s = random.uniform(self.scale[0], self.scale[1])
            d = 0.25 + (0.45 - 0.25) / (self.scale[1] - self.scale[0]) * (s - self.scale[0])
            target_area = s * area

            new_w = int(round(math.sqrt(target_area)))
            new_h = int(round(math.sqrt(target_area)))

            if new_w < w and new_h < h:
                dw = w-new_w
                dh = h - new_h
                x0 = random.randint(int((0.5-d)*dw), min(int((0.5+d)*dw)+1,dw))
                y0 = (random.randint(max(0,int(0.8*dh)-1), dh))
                out = fixed_crop(img, x0, y0, new_w, new_h, self.size)
                return out, attr_idx

        # Fallback
        return bottom_crop(img, self.size), attr_idx


class BottomCrop():
    def __init__(self, size,  select, scale=(0.4, 0.8)):
        self.size = size
        self.scale = scale
        self.select = select

    def __call__(self,img, attr_idx):
        if attr_idx not in self.select:
            return img, attr_idx

        h, w, _ = img.shape
        area = h * w

        s = (self.scale[0]+self.scale[1])/3.*2.

        target_area = s * area

        new_w = int(round(math.sqrt(target_area)))
        new_h = int(round(math.sqrt(target_area)))

        if new_w < w and new_h < h:
            dw = w-new_w
            dh = h-new_h
            x0 = int(0.5*dw)
            y0 = int(0.9*dh)
            out = fixed_crop(img, x0, y0, new_w, new_h, self.size)
            return out, attr_idx

        # Fallback
        return bottom_crop(img, self.size), attr_idx



class Resize(object):
    def __init__(self, size, inter=cv2.INTER_CUBIC):
        self.size = size
        self.inter = inter

    def __call__(self, image):
        # cv2.resize expects dsize as (width, height)
        return cv2.resize(image, (self.size[0], self.size[1]), interpolation=self.inter)

class ExpandBorder(object):
    def __init__(self,  mode='constant', value=255, size=(336,336), resize=False):
        self.mode = mode
        self.value = value
        self.resize = resize
        self.size = size

    def __call__(self, image):
        h, w, _ = image.shape
        if h > w:
            pad1 = (h-w)//2
            pad2 = h - w - pad1
            if self.mode == 'constant':
                image = np.pad(image, ((0, 0), (pad1, pad2), (0, 0)),
                               self.mode, constant_values=self.value)
            else:
                image = np.pad(image,((0,0), (pad1, pad2),(0,0)), self.mode)
        elif h < w:
            pad1 = (w-h)//2
            pad2 = w-h - pad1
            if self.mode == 'constant':
                image = np.pad(image, ((pad1, pad2),(0, 0), (0, 0)),
                               self.mode,constant_values=self.value)
            else:
                image = np.pad(image, ((pad1, pad2), (0, 0), (0, 0)),self.mode)
        if self.resize:
            image = cv2.resize(image, (self.size[0], self.size[1]), interpolation=cv2.INTER_LINEAR)
        return image


class AstypeToInt():
    def __call__(self, image, attr_idx):
        return image.clip(0,255.0).astype(np.uint8), attr_idx

class AstypeToFloat():
    def __call__(self, image, attr_idx):
        return image.astype(np.float32), attr_idx

class Normalize(object):
    def __init__(self,mean, std):
        '''
        :param mean: RGB order
        :param std:  RGB order
        '''
        self.mean = np.array(mean).reshape(3,1,1)
        self.std = np.array(std).reshape(3,1,1)
    def __call__(self, image):
        '''
        :param image:  (H,W,3)  RGB
        :return:
        '''
        # plt.figure(1)
        # plt.imshow(image)
        # plt.show()
        return (image.transpose((2, 0, 1)) / 255. - self.mean) / self.std

class RandomErasing(object):
    def __init__(self, select,EPSILON=0.5,sl=0.02, sh=0.09, r1=0.3, mean=[0.485, 0.456, 0.406]):
        self.EPSILON = EPSILON
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1
        self.select = select

    def __call__(self, img,attr_idx):
        if attr_idx not in self.select:
            return img,attr_idx

        if random.uniform(0, 1) > self.EPSILON:
            return img,attr_idx

        for attempt in range(100):
            area = img.shape[1] * img.shape[2]

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img.shape[2] and h <= img.shape[1]:
                # numpy randint excludes `high`; +1 avoids an empty range when the box fills a side
                x1 = random.randint(0, img.shape[1] - h + 1)
                y1 = random.randint(0, img.shape[2] - w + 1)
                if img.shape[0] == 3:
                    # img[0, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                    # img[1, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                    # img[2, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                    img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
                    img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
                    # img[:, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(3, h, w))
                else:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[1]
                    # img[0, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(1, h, w))
                return img,attr_idx

        return img,attr_idx

# if __name__ == '__main__':
#     import matplotlib.pyplot as plt
#
#
#     class FSAug(object):
#         def __init__(self):
#             self.augment = Compose([
#                 AstypeToFloat(),
#                 # RandomHShift(scale=(0.,0.2),select=range(8)),
#                 # RandomRotate(angles=(-20., 20.), bound=True),
#                 ExpandBorder(select=range(8), mode='symmetric'),# symmetric
#                 # Resize(size=(336, 336), select=[ 2, 7]),
#                 AstypeToInt()
#             ])
#
#         def __call__(self, spct,attr_idx):
#             return self.augment(spct,attr_idx)
#
#
#     trans = FSAug()
#
#     img_path = '/media/gserver/data/FashionAI/round2/train/Images/coat_length_labels/0b6b4a2146fc8616a19fcf2026d61d50.jpg'
#     img = cv2.cvtColor(cv2.imread(img_path),cv2.COLOR_BGR2RGB)
#     img_trans,_ = trans(img,5)
#     # img_trans2,_ = trans(img,6)
#     print img_trans.max(), img_trans.min()
#     print img_trans.dtype
#
#     plt.figure()
#     plt.subplot(221)
#     plt.imshow(img)
#
#     plt.subplot(222)
#     plt.imshow(img_trans)
#
#     # plt.subplot(223)
#     # plt.imshow(img_trans2)
#     # plt.imshow(img_trans2)
#     plt.show()
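The core of RandomResizedCrop above is the box-sampling loop: draw an area fraction within `scale` and an aspect ratio within `ratio`, optionally swap width and height, and accept the first box that fits. A plain-Python sketch of just that loop follows; `sample_crop_box` is a hypothetical name, the 20% "resize the whole image" branch and the final resize to `size` are omitted, and it uses Python's `random`, whose randint includes the upper bound (unlike numpy's).

```python
import math
import random


def sample_crop_box(h, w, scale=(0.49, 1.0), ratio=(1.0, 1.0), attempts=10, rng=random):
    """Return (x0, y0, box_w, box_h) for a random crop of an h x w image, or None.

    None means no draw fit; the class above then falls back to center_crop.
    """
    area = h * w
    for _ in range(attempts):
        target_area = rng.uniform(scale[0], scale[1]) * area
        aspect_ratio = rng.uniform(ratio[0], ratio[1])
        new_w = int(round(math.sqrt(target_area * aspect_ratio)))
        new_h = int(round(math.sqrt(target_area / aspect_ratio)))
        if rng.random() < 0.5:  # randomly swap width and height
            new_w, new_h = new_h, new_w
        if new_w < w and new_h < h:  # box must fit strictly inside the image
            x0 = rng.randint(0, w - new_w)  # inclusive: box may touch the right edge
            y0 = rng.randint(0, h - new_h)
            return x0, y0, new_w, new_h
    return None


if __name__ == "__main__":
    box = sample_crop_box(320, 480, rng=random.Random(0))
    if box is not None:
        x0, y0, bw, bh = box
        print("crop rows", y0, "to", y0 + bh, "and columns", x0, "to", x0 + bw)
```

With the default `ratio=(1., 1.)` every accepted box is square, and with `scale` up to 1.0 a square often cannot fit a non-square image, which is why the attempt loop and the center-crop fallback exist.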

factory

factory mainly defines components such as the learning rate schedule, loss functions, and optimizers.
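train.py uses `LSR()` from factory.LabelSmoothing as its criterion, but that module's source is not shown here. Assuming LSR stands for label smoothing regularization, the idea can be sketched in NumPy (the real module is PyTorch): the one-hot target is mixed with a uniform distribution before taking cross-entropy. The function name and the default `epsilon=0.1` are assumptions.

```python
import numpy as np


def label_smoothing_cross_entropy(logits, targets, epsilon=0.1):
    """Mean cross-entropy against smoothed one-hot targets.

    logits:  (N, C) raw scores; targets: (N,) integer class ids.
    Each class gets epsilon / C and the true class gets an extra 1 - epsilon.
    """
    n, c = logits.shape
    # numerically stable log-softmax
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    smoothed = np.full((n, c), epsilon / c)
    smoothed[np.arange(n), targets] += 1.0 - epsilon
    return float(-(smoothed * log_probs).sum(axis=1).mean())


if __name__ == "__main__":
    logits = np.array([[4.0, 0.0, 0.0], [0.0, 3.0, 1.0]])
    targets = np.array([0, 1])
    print(label_smoothing_cross_entropy(logits, targets))
    print(label_smoothing_cross_entropy(logits, targets, epsilon=0.0))  # plain cross-entropy
```

With `epsilon=0` this reduces to ordinary cross-entropy; a positive epsilon penalizes over-confident predictions, which tends to help on small datasets such as a 102-class task.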


models

models defines the common classification models.
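The models/model.py source is not shown, but train.py calls `Model(model_type, config.num_classes, criterion, ...)` with `model_type = config.model_name`, so something must map the name string to a network. One common way is a registry; the sketch below is hypothetical (`register_model`, `build_model`, and the placeholder constructors are illustrative, not the author's `Model` class) and returns plain dicts where a real project would return a network.

```python
from typing import Callable, Dict

# Hypothetical registry: maps a config.model_name string to a constructor
_MODEL_REGISTRY: Dict[str, Callable[[int], object]] = {}


def register_model(name: str):
    """Decorator that records a model constructor under `name`."""
    def decorator(fn):
        if name in _MODEL_REGISTRY:
            raise KeyError(f"model {name!r} already registered")
        _MODEL_REGISTRY[name] = fn
        return fn
    return decorator


def build_model(name: str, num_classes: int):
    """Look up `name` (e.g. config.model_name) and build the model."""
    try:
        constructor = _MODEL_REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(_MODEL_REGISTRY))
        raise ValueError(f"unknown model {name!r}; available: {known}") from None
    return constructor(num_classes)


# Placeholder constructors; a real project would return e.g. a torch.nn.Module
@register_model("vgg16")
def _vgg16(num_classes):
    return {"arch": "vgg16", "num_classes": num_classes}


@register_model("resnet50")
def _resnet50(num_classes):
    return {"arch": "resnet50", "num_classes": num_classes}


if __name__ == "__main__":
    print(build_model("vgg16", 102))
```

With this pattern, adding a model means registering one constructor, and switching models stays a one-line change to `model_name` in config.py.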


train.py

import os
from sklearn.model_selection import KFold
from torchvision import transforms
import torch.utils.data
from dataloader.data import trainDataset,train_transform,val_transform,get_anno
from factory.loss import *
from models.model import Model
from config import config
import numpy as np
from utils import utils
from factory.LabelSmoothing import LSR


def train(model_type, prefix):
    # Load the annotations as a NumPy array of (image path, label) rows
    data = get_anno(config.train_anno_path, config.train_data_path)
    # k-fold cross-validation (config.k folds, 5 by default)
    skf = KFold(n_splits=config.k, random_state=233, shuffle=True)

    for fold_idx, (train_indices, val_indices) in enumerate(skf.split(data)):
        train_loader = torch.utils.data.DataLoader(
            trainDataset(data[train_indices],
                         train_transform),
            batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True
        )

        val_loader = torch.utils.data.DataLoader(
            trainDataset(data[val_indices],
                         val_transform),
            batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers, pin_memory=True
        )

        #criterion = FocalLoss(0.5)
        criterion = LSR()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = Model(model_type, config.num_classes, criterion, device=device, prefix=prefix, suffix=str(fold_idx))

        for epoch in range(config.epochs):
            print('Epoch: ', epoch)

            model.fit(train_loader)
            model.validate(val_loader)


if __name__ == '__main__':
    model_type_list = [config.model_name]
    for model_type in model_type_list:
        train(model_type, "resize")
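train.py trains one fresh model per fold, with `KFold(n_splits=config.k, shuffle=True, random_state=233)` doing the splitting. To make the splitting concrete, here is a plain-Python sketch of shuffled k-fold without sklearn; `kfold_indices` is a hypothetical helper. It uses the same fold-size rule as sklearn (the first `n % k` folds get one extra sample), but the permutation for a given seed is not the same as sklearn's.

```python
import random


def kfold_indices(n_samples, k=5, seed=233):
    """Yield (fold_idx, train_indices, val_indices) for shuffled k-fold CV.

    Shuffle once, cut into k nearly equal folds, and use each fold once as
    the validation set while the remaining folds form the training set.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    base, extra = divmod(n_samples, k)
    start = 0
    for fold_idx in range(k):
        size = base + (1 if fold_idx < extra else 0)
        val = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield fold_idx, train, val
        start += size


if __name__ == "__main__":
    for fold_idx, train_idx, val_idx in kfold_indices(12, k=5):
        print(fold_idx, len(train_idx), sorted(val_idx))
```

Every sample lands in exactly one validation fold, so the k validation scores together cover the whole training set.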

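Summary

This post presented an image classification framework that makes switching models quick and easy.
See you next time! Likes and comments are very welcome!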
