Because the content is important, I am keeping these notes; they are notes only.

imports, create the logger, define global parameters

read the data, build the Dataset and DataLoader

build the network model

set up the optimizer, the loss function, the learning rate and its decay schedule

mixup, training, print/log training information, save model weights

validation, accumulate the confusion matrix, visualize some of the results

load model weights, run the test

train.py

import torch
import adabound
from tensorboardX import SummaryWriter
from torchvision.datasets import ImageFolder
from torch import nn, optim
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
from torchvision import transforms
from datetime import datetime
import numpy as np
import os
import json
import math
import time
from torch.cuda import amp

from mobilenetv2 import MobileNetV2
from Bnnet import Net, NetBN1Nopool
from resnet import resnet50, resnet101, resnet152
from ResNeSt.resnest.torch import resnest50, resnest101, resnest200, resnest269
from mixup import *
from mydatasets import *
from utils import validate, show_confMat

torch.cuda.set_device(0)  # gpu
result_dir = 'Result'  # directory where logs are saved
now_time = datetime.now()
time_str = datetime.strftime(now_time, '%m-%d_%H-%M-%S')  # timestamp
log_dir = os.path.join(result_dir, time_str)
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
classes_name = ['1', '2', '3', '4']
writer = SummaryWriter(log_dir=log_dir)

# alternative data pipeline based on torchvision ImageFolder, kept commented out:
"""
train_data_path = ''
valid_data_path = ''
train_set = ImageFolder(root=train_data_path,
                        transform=transforms.Compose([
                            transforms.Resize(64),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=(0.034528155, 0.033598177, 0.009853649),
                                                 std=(0.15804708, 0.16410254, 0.0643605))]))
valid_set = ImageFolder(root=valid_data_path,
                        transform=transforms.Compose([
                            transforms.Resize(64),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=(0.034528155, 0.033598177, 0.009853649),
                                                 std=(0.15804708, 0.16410254, 0.0643605))]))
train_loader = DataLoader(dataset=train_set, batch_size=1024, shuffle=True)
valid_loader = DataLoader(dataset=valid_set, batch_size=1024)
"""

np.random.seed(1)
# Loading in the dataset
img_dir = '/home/train/'
valid_img_dir = '/home/test/'
valid_size = 0.0
test_size = 0.0

# Dataset
# the training and validation sets are split in advance
classes, class_to_idx = find_classes(img_dir)  # dict
print(class_to_idx)
samples = make_dataset(img_dir, class_to_idx, extensions=IMG_EXTENSIONS, is_valid_file=None)
valid_samples = make_dataset(valid_img_dir, class_to_idx, extensions=IMG_EXTENSIONS, is_valid_file=None)

num_data = len(samples)
indices = list(range(num_data))
np.random.shuffle(indices)
valid_split = int(np.floor(valid_size * num_data))
test_split = int(np.floor((valid_size + test_size) * num_data))
print(valid_split, test_split)
valid_idx, test_idx, train_idx = indices[:valid_split], indices[valid_split:test_split], indices[test_split:]
print(len(valid_idx), len(test_idx), len(train_idx))
# end
# data = ImageFolder(img_dir, transform=strong_aug(p=1.0))
'''
data = ImageFolder(img_dir,
                   transform=transforms.Compose([
                       transforms.Resize((320, 320)),  # 64 for Bnnet, 224 for mobilenet / resnet
                       transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       transforms.Normalize(mean=(0.034528155, 0.033598177, 0.009853649),
                                            std=(0.15804708, 0.16410254, 0.0643605))
                       # transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
                   ]))
valid_data = ImageFolder(valid_img_dir,
                         transform=transforms.Compose([
                             transforms.Resize((224, 224)),  # 64 for Bnnet, 224 for mobilenet / resnet
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize(mean=(0.034528155, 0.033598177, 0.009853649),
                                                  std=(0.15804708, 0.16410254, 0.0643605))]))
'''
# np.array(samples)[train_idx].tolist()   transform=strong_aug(p=1.0)
train_data = loadImagesandLabels(np.array(samples)[train_idx].tolist(), transform=Resizeimg(p=1.0))
valid_data = loadImagesandLabels(valid_samples, transform=Resizeimg(p=1.0))
'''
# number of subprocesses to use for data loading
# percentage of training set to use as validation
# obtain training indices that will be used for validation
#
num_train = len(data)
indices = list(range(num_train))
np.random.shuffle(indices)
valid_split = int(np.floor((valid_size) * num_train))
test_split = int(np.floor((valid_size+test_size) * num_train))
print(valid_split,test_split)
valid_idx, test_idx, train_idx = indices[:valid_split], indices[valid_split:test_split], indices[test_split:]
print(len(valid_idx), len(test_idx), len(train_idx))
#
# define samplers for obtaining training and validation batches
#
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
test_sampler = SubsetRandomSampler(test_idx)
'''
# prepare data loaders (combine dataset and sampler)
# Bnet batch_size=1024   mobilenet batch_size=32   resnet batch_size=8
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
# Bnet batch_size=1024   mobilenet batch_size=32   resnet batch_size=8
valid_loader = DataLoader(valid_data, batch_size=32, shuffle=True)
# test_loader = DataLoader(data, batch_size=16, sampler=test_sampler)

if torch.cuda.is_available():
    # model = NetBN1Nopool(num_classes=3).cuda()
    # model = MobileNetV2(n_class=3).to('cuda')
    # model = resnet152(num_class=3, pretrained=False).to('cuda')
    model = resnest101(num_classes=4, pretrained=True).cuda()
    # model = resnest101(num_classes=4, pretrained=True).to('cuda')
    # print(model)
    print("cuda:0")
else:
    # note: no model is actually built in this branch, so the script assumes a GPU is available
    # model = NetBN1Nopool(num_classes=3).cuda()
    # model = MobileNetV2(n_class=9).to("cuda")
    # model = resnet50(num_class=9).to("cuda")
    print("cpu")
epochs = 15  # you know
fclayer = []
pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
for k, v in model.named_parameters():
    v.requires_grad = True
    if '.bias' in k:
        pg2.append(v)  # biases
    elif '.weight' in k and '.bn' not in k:
        pg1.append(v)  # apply weight decay
    # elif '.fc' in k:
    #     fclayer.append(v)
    else:
        pg0.append(v)  # all else

optimizer = optim.SGD(pg0, lr=0.01, momentum=0.9, nesterov=True)
# add pg1 with weight_decay, e.g. 'weight_decay': 0.1
optimizer.add_param_group({'params': pg1})
optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
# optimizer = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
del pg0, pg1, pg2
# ,weight_decay=0.03
# note: this re-creates the optimizer and overrides the parameter groups built above
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)
# scheduler=optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=10)
criterion = nn.CrossEntropyLoss()
lr = 0.3


def lf(x):
    # cosine decay from lr down to 0.01 * lr over `epochs`
    return ((1 + math.cos(x * math.pi / epochs)) / 2) * (lr - 0.01 * lr) + 0.01 * lr


warm_up_epochs = 3


def warm_up_with_cosine_lr(epoch):
    # linear warm-up for the first warm_up_epochs epochs, then cosine decay
    return ((epoch + 1) * lr) / warm_up_epochs if epoch < warm_up_epochs \
        else 0.5 * (math.cos((epoch - warm_up_epochs) / (epochs - warm_up_epochs) * math.pi) + 1) \
        * (lr - 0.01 * lr) + 0.01 * lr


scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

starttime = time.time()  # training
one_batch_mixup = False
for epoch in range(epochs):
    if epoch > 9:
        one_batch_mixup = False
    model.train()
    loss_sigma = 0.0
    correct = 0.0
    total = 0.0
    scheduler.step()
    for i, (data, label, Path) in enumerate(train_loader):
        if one_batch_mixup:
            data, labela, labelb, lam = mixup_data(data, label, alpha=0.5)
            data = data.cuda()
            label = label.cuda()
            labela = labela.cuda()
            labelb = labelb.cuda()
            lam = torch.FloatTensor([lam])[0].cuda()
            optimizer.zero_grad()
            with amp.autocast(enabled=True):
                outputs = model(data)
                # loss = criterion(outputs, label)
                loss = lam * criterion(outputs, labela) + \
                    (1 - lam) * criterion(outputs, labelb)
        else:  # no mixup
            data = data.permute(0, 3, 1, 2).float()
            data = data.cuda()
            label = label.cuda()
            optimizer.zero_grad()
            with amp.autocast(enabled=True):
                outputs = model(data)
                # loss = criterion(outputs, label)
                loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs.data, dim=1)
        total += label.size(0)
        correct += (predicted == label).squeeze().sum().cpu().numpy()
        loss_sigma += loss.item()

        # print training info periodically; the printed loss is loss_sigma averaged over 10 iterations
        # (note the condition actually fires every 100 iterations)
        if i % 100 == 0:
            loss_avg = loss_sigma / 10
            loss_sigma = 0.0
            print("Training: Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch + 1, epochs, i + 1, len(train_loader), loss_avg, correct / total))
            # record training loss
            writer.add_scalars('Loss_group', {'train_loss': loss_avg}, epoch)
            # record learning rate
            writer.add_scalar('learning rate', scheduler.get_lr()[0], epoch)
            # record accuracy
            writer.add_scalars('Accuracy_group', {'train_acc': correct / total}, epoch)

    # at the end of each epoch, record gradients and weights
    for name, layer in model.named_parameters():
        writer.add_histogram(name + '_grad', layer.grad.cpu().data.numpy(), epoch)
        writer.add_histogram(name + '_data', layer.cpu().data.numpy(), epoch)

    # -------------------- evaluate the model on the validation set --------------------
    if epoch % 1 == 0:
        oneepoch = time.time() - starttime
        print('one epoch time is:   ', oneepoch / (epoch + 1))
        torch.save(model.state_dict(), 'myresnest101a_train.pkl' + str(epoch))
        loss_sigma = 0.0
        cls_num = len(classes_name)
        conf_mat = np.zeros([cls_num, cls_num])  # confusion matrix
        model.eval()
        with torch.no_grad():
            for i, (data, label, Path) in enumerate(valid_loader):
                # forward
                data = data.to("cuda")
                label = label.to("cuda")
                outputs = model(data)
                outputs.detach_()
                # compute loss
                loss = criterion(outputs, label)
                loss_sigma += loss.item()
                # statistics
                _, predicted = torch.max(outputs.data, 1)
                # labels = labels.data    # Variable --> tensor
                # accumulate the confusion matrix
                for j in range(len(label)):
                    cate_i = label[j].cpu().numpy()
                    pre_i = predicted[j].cpu().numpy()
                    conf_mat[cate_i, pre_i] += 1.0
        print('{} {} set Accuracy:{:.2%}'.format(epoch, 'Valid', conf_mat.trace() / conf_mat.sum()))
        # record loss and accuracy
        writer.add_scalars('Loss_group', {'valid_loss': loss_sigma / len(valid_loader)}, epoch)
        writer.add_scalars('Accuracy_group', {'valid_acc': conf_mat.trace() / conf_mat.sum()}, epoch)
print('Finished Training')

# torch.save(model, 'net.pkl' + str(epochs))  # save the whole network structure plus parameters
# torch.save(model.state_dict(), 'net_params_BN1_Nopool.pkl' + str(epoch))  # save only the model parameters

conf_mat_train, train_acc = validate(model.cpu(), train_loader, 'train', classes_name)
# the test sampler/loader is commented out above, so the validation loader is reused here
conf_mat_valid, valid_acc = validate(model.cpu(), valid_loader, 'test', classes_name)
show_confMat(conf_mat_train, classes_name, 'train', log_dir)
show_confMat(conf_mat_valid, classes_name, 'test', log_dir)
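
A note on the learning-rate schedule: the script defines both a plain cosine lambda (lf) and a warm-up-plus-cosine lambda (warm_up_with_cosine_lr), but only plugs lf into LambdaLR. Below is a minimal, self-contained sketch of how the warm-up variant could be wired in; epochs, warm_up_epochs and lr follow the script, while the linear model and the base lr of 1.0 are placeholders chosen so that the lambda's return value becomes the effective learning rate.

import math
from torch import nn, optim

epochs, warm_up_epochs, lr = 15, 3, 0.3

def warm_up_with_cosine_lr(epoch):
    # linear warm-up for the first warm_up_epochs epochs, then cosine decay to 0.01 * lr
    if epoch < warm_up_epochs:
        return (epoch + 1) * lr / warm_up_epochs
    return 0.5 * (math.cos((epoch - warm_up_epochs) / (epochs - warm_up_epochs) * math.pi) + 1) \
        * (lr - 0.01 * lr) + 0.01 * lr

model = nn.Linear(10, 4)                               # placeholder model
optimizer = optim.SGD(model.parameters(), lr=1.0)      # base lr 1.0, so the lambda value is the effective lr
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_cosine_lr)

for epoch in range(epochs):
    # ... one epoch of training would run here ...
    scheduler.step()
    print(epoch, scheduler.get_last_lr()[0])           # inspect the schedule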

test.py

import torch
import cv2
import numpy as np
import os
import shutil

from ResNeSt.resnest.torch import resnest50, resnest101, resnest200, resnest269
from mobilenetv2 import MobileNetV2
from Bnnet import Net, NetBN1Nopool
from mydatasets import *

if __name__ == "__main__":
    model = resnest50(num_classes=20).cuda()
    # model = MobileNetV2(n_class=20)
    model.load_state_dict(torch.load('/home/data/weight/log/best.pkl'))
    model.eval()
    # print(torch.rand(1).long())
    # im = torch.rand(2, 3, 128, 128)
    path = '/home/runs/detect/exp1/'
    # imlist = os.listdir('/home/images/')
    imlist = os.listdir(path)
    for i in imlist:
        im = cv2.imread(path + i)
        # 'alansar/alansar000390.jpg'
        # '/alhind/alhind000550.jpg'
        im = cv2.resize(im, (128, 128))
        im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
        im /= 255.0  # 0 - 255 to 0.0 - 1.0
        # label = torch.full((18,), 1, dtype=torch.long)
        output = model(torch.tensor([im]).cuda())
        score = torch.sigmoid(output.data)
        # print(score)
        a, b = torch.max(score.data, dim=1)
        # shutil.copy(path + i, path.replace('exp4', 'ceshi') + str(str(b.item()) + '_' + str((a.item() - 0.99) * 10)[:15] + '.jpg'))
        if a.item() > 0.0:
            print(i)
            print(a, b)

    # tta: run each image through strong_aug several times and average the softmax outputs
    # for i in imlist:
    #     print(i)
    #     lastput = []
    #     for j in range(12):
    #         im = cv2.imread(path + i)
    #         # im = cv2.resize(im, (128, 128))
    #         im = strong_aug(p=1.0, width=128)(image=im)['image']
    #         im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
    #         im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
    #         im /= 255.0  # 0 - 255 to 0.0 - 1.0
    #         output = model(torch.tensor([im]))
    #         lastput += [torch.softmax(output, 1).detach().cpu().numpy()]
    #     image_preds_all = np.concatenate(lastput, axis=0)
    #     valid_preds = np.mean(image_preds_all / 12, axis=0)
    #     print(np.argmax(valid_preds, axis=0))
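
The commented-out block above sketches test-time augmentation. A hedged, working version of the same idea, written as a helper that assumes the model, image folder path, file list and strong_aug (from mydatasets.py) that test.py already sets up; the 12 passes follow the commented code.

def tta_predict(model, path, imlist, n_tta=12):
    # average softmax scores over n_tta strong_aug passes per image
    for name in imlist:
        preds = []
        for _ in range(n_tta):
            im = cv2.imread(os.path.join(path, name))
            im = strong_aug(p=1.0, width=128)(image=im)['image']
            im = im[:, :, ::-1].transpose(2, 0, 1)               # BGR to RGB, HWC to CHW
            im = np.ascontiguousarray(im, dtype=np.float32) / 255.0
            with torch.no_grad():
                out = model(torch.tensor(im).unsqueeze(0).cuda())
            preds.append(torch.softmax(out, 1).cpu().numpy())
        avg = np.concatenate(preds, axis=0).mean(axis=0)          # mean over the augmented passes
        print(name, int(avg.argmax()))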

utils.py

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch
from torch.autograd import Variable
import os
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    # weight initialization
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                m.bias.data.zero_()


class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))
        self.imgs = imgs  # the key step is building this list; DataLoader passes an index and __getitem__ reads the image
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        # pixel values are 0~255; transforms.ToTensor divides by 255 so they become 0~1
        img = Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)  # apply the transform here (convert to tensor, etc.)
        return img, label

    def __len__(self):
        return len(self.imgs)


def validate(net, data_loader, set_name, classes_name):
    """
    Run prediction over a dataset and return the confusion matrix and accuracy.
    :param net:
    :param data_loader:
    :param set_name: e.g. 'valid', 'train', 'test'
    :param classes_name:
    :return:
    """
    net.eval()
    cls_num = len(classes_name)
    conf_mat = np.zeros([cls_num, cls_num])  # confusion matrix
    for data in data_loader:
        images, labels, paths = data
        images = images.to('cuda')
        labels = labels.to('cuda')
        images = Variable(images)
        labels = Variable(labels)
        outputs = net(images)
        outputs.detach_()
        _, predicted = torch.max(outputs.data, 1)
        # accumulate the confusion matrix
        for i in range(len(labels)):
            cate_i = labels[i].cpu().numpy()
            pre_i = predicted[i].cpu().numpy()
            conf_mat[cate_i, pre_i] += 1.0
    for i in range(cls_num):
        print('class:{:<10}, total num:{:<6}, correct num:{:<5}  Recall: {:.2%} Precision: {:.2%}'.format(
            classes_name[i], np.sum(conf_mat[i, :]), conf_mat[i, i],
            conf_mat[i, i] / (1 + np.sum(conf_mat[i, :])),
            conf_mat[i, i] / (1 + np.sum(conf_mat[:, i]))))
    print('{} set Accuracy:{:.2%}'.format(set_name, np.trace(conf_mat) / np.sum(conf_mat)))
    return conf_mat, '{:.2}'.format(np.trace(conf_mat) / np.sum(conf_mat))


def show_confMat(confusion_mat, classes, set_name, out_dir):
    # normalize each row
    confusion_mat_N = confusion_mat.copy()
    for i in range(len(classes)):
        confusion_mat_N[i, :] = confusion_mat[i, :] / confusion_mat[i, :].sum()

    # choose a colormap
    # more colormaps: http://matplotlib.org/examples/color/colormaps_reference.html
    cmap = plt.cm.get_cmap('Greys')
    plt.imshow(confusion_mat_N, cmap=cmap)
    plt.colorbar()

    # axis labels
    xlocations = np.array(range(len(classes)))
    plt.xticks(xlocations, list(classes), rotation=60)
    plt.yticks(xlocations, list(classes))
    plt.xlabel('Predict label')
    plt.ylabel('True label')
    plt.title('Confusion_Matrix_' + set_name)

    # print the count inside each cell
    for i in range(confusion_mat_N.shape[0]):
        for j in range(confusion_mat_N.shape[1]):
            plt.text(x=j, y=i, s=int(confusion_mat[i, j]), va='center', ha='center', color='red', fontsize=10)
    # save
    plt.savefig(os.path.join(out_dir, 'Confusion_Matrix' + set_name + '.png'))
    plt.close()


def normalize_invert(tensor, mean, std):
    for t, m, s in zip(tensor, mean, std):
        t.mul_(s).add_(m)
    return tensor

mixup.py

import numpy as np
import torch


def mixup_data(x, y, alpha=1.0):
    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda'''
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.
    if lam > 1:
        lam = 1.
    batch_size = x.size()[0]
    index = torch.randperm(batch_size)
    mixed_x = lam * x + (1 - lam) * x[index, :]  # blend the batch with a shuffled copy of itself
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(y_a, y_b, lam):
    return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
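
A minimal sketch of how the two helpers fit together in one training step; the linear model, criterion and random batch below are placeholders, and train.py actually computes the mixed loss inline instead of calling mixup_criterion.

from torch import nn

model = nn.Linear(8, 4)                    # placeholder model
criterion = nn.CrossEntropyLoss()
x = torch.randn(16, 8)                     # placeholder batch
y = torch.randint(0, 4, (16,))

mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha=0.5)
loss_fn = mixup_criterion(y_a, y_b, lam)
loss = loss_fn(criterion, model(mixed_x))  # lam * CE(pred, y_a) + (1 - lam) * CE(pred, y_b)
loss.backward()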

mydatasets.py

from albumentations.pytorch import ToTensorV2
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout,
    CoarseDropout, CenterCrop, Resize
)
import matplotlib.pyplot as plt
import glob
import math
import os
import random
import shutil
import time
from pathlib import Path
from threading import Thread
from torchvision import transforms
import cv2
import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm
from torch.utils.data import DataLoader
from albumentations import (
    Blur, Flip, ShiftScaleRotate, GridDistortion, ElasticTransform, HorizontalFlip, CenterCrop,
    HueSaturationValue, Transpose, RandomBrightnessContrast, CLAHE, RandomCrop, Cutout, CoarseDropout,
    Normalize, ToFloat, OneOf, Compose, Resize, RandomRain, RandomFog, Lambda, ChannelDropout, ISONoise,
    VerticalFlip, RandomGamma, RandomRotate90, OpticalDistortion, MotionBlur, MedianBlur, GaussianBlur,
    GaussNoise, RGBShift, RandomBrightness, RandomContrast, InvertImg, ChannelShuffle, ToGray,
    RandomGridShuffle)
import albumentations as A
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return np.array(img.convert('RGB'))


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of the given extensions
    """
    return filename.lower().endswith(extensions)


def make_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
    instances = []
    directory = os.path.expanduser(directory)
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x):
            return has_file_allowed_extension(x, extensions)
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    instances.append(item)
    return instances


def find_classes(dir):
    """Finds the class folders in a dataset.

    Args:
        dir (string): Root directory path.

    Returns:
        tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

    Ensures:
        No class is a subdirectory of another.
    """
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    # change this to match the real class names
    classes.sort(key=lambda x: int(x))
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


class loadImagesandLabels(Dataset):
    def __init__(self, path, batch_size=16, loader=default_loader,
                 transform=None, mixup=None, valid=False):
        # path = ../../../../train/
        if isinstance(path, str):  # the original checked torch._six.string_classes
            self.root = path
            self.classes, self.class_to_idx = self._find_classes(self.root)
            samples = make_dataset(self.root, self.class_to_idx, extensions=IMG_EXTENSIONS, is_valid_file=None)
            self.samples = samples
        elif isinstance(path, list):
            self.samples = path
        # print(self.samples)
        if len(self.samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            msg += "Supported extensions are: {}".format(",".join(IMG_EXTENSIONS))
            raise RuntimeError(msg)
        self.transform = transform
        self.loader = loader
        self.mixup = mixup

    def get_image_labels(self):
        self.image_labels = []
        for i in self.samples:
            self.image_labels.append(int(i[-1]))
        return self.image_labels

    def get_class_count(self):
        self.classes = len(set(self.image_labels))
        self.class_count = dict(zip(range(self.classes),
                                    (self.image_labels.count(str(i)) for i in range(self.classes))))
        self.class_count = [self.image_labels.count(i) for i in range(self.classes)]
        return self.class_count

    def _find_classes(self, dir):
        """Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # returns the augmented image as-is (HWC numpy array); train.py permutes and scales it
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(image=sample)['image']
            # sample = transforms.Normalize(mean=(0.034528155, 0.033598177, 0.009853649),
            #                               std=(0.15804708, 0.16410254, 0.0643605))(sample)
        return sample, int(target), path


class loadImagesandLabels2(loadImagesandLabels):
    # identical to loadImagesandLabels except that __getitem__ already converts to a CHW float tensor in [0, 1]

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(image=sample)['image']
            sample = np.ascontiguousarray(sample, dtype=np.float32)
            sample = sample / 255.0
            sample = transforms.ToTensor()(sample)
            # sample = transforms.Normalize(mean=(0.034528155, 0.033598177, 0.009853649),
            #                               std=(0.15804708, 0.16410254, 0.0643605))(sample)
        return sample, int(target), path

    '''
    # @staticmethod
    def collate_fn(self, batch):
        # batch-level mixup, unfinished
        if not self.mixup:
            return batch
        img, target, path = zip(*batch)  # transposed
        print(type(img[0]))
        alpha = beta = 0.5
        lam = np.random.beta(alpha, beta)
        batch_size = len(img)
        index = torch.randperm(batch_size)
        print(batch_size, index)
        mixed_img = lam * img + (1 - lam) * img[index, :]  # blend the batch with a shuffled copy of itself
        target, target_random = target, target[index]
        return mixed_x, target, target_random, lam, path
    '''


def strong_aug(p=1.0, width=320):
    return Compose([
        # VerticalFlip(p=0.5),
        # RandomResizedCrop(width, width, scale=(0.95, 1.0)),
        Resize(height=width, width=width, p=1),
        # HorizontalFlip(p=0.5),
        OneOf([
            Blur(blur_limit=7, p=1.0),
            RandomGamma(gamma_limit=(80, 120), p=1.0),  # random gamma
            OpticalDistortion(p=1.0),  # optical distortion
            MotionBlur(p=1.0),  # motion blur
            MedianBlur(p=1.0),  # median blur
            GaussianBlur(p=1.0),  # Gaussian blur
            GaussNoise(p=1.0),  # Gaussian noise
            ISONoise(p=1.0),  # camera sensor noise
        ], p=0.2),
        OneOf([
            GridDistortion(p=0.5),  # grid distortion
            ElasticTransform(p=0.5),  # elastic transform
        ], p=0.2),
        OneOf([
            HueSaturationValue(p=1.0),  # hue / saturation
            RGBShift(p=1.0),  # RGB shift
            RandomBrightness(p=1.0),  # random brightness
            RandomContrast(p=1.0),  # random contrast
            RandomBrightnessContrast(p=1.0),  # random brightness and contrast together
            CLAHE(p=1.0),  # contrast-limited adaptive histogram equalization
            InvertImg(p=1.0),  # invert the image by subtracting pixel values from 255
            ChannelShuffle(p=1.0),  # randomly reorder the RGB channels
        ], p=0.2),
        ToGray(p=0.01),
        RandomGridShuffle(p=0.01),  # random grid shuffle
        # PadIfNeeded(p=1.0),  # padding
        OneOf([
            Cutout(max_h_size=30, max_w_size=30, p=0.5),  # cut square regions out of the image
            CoarseDropout(max_height=20, max_width=40, p=0.5),  # cut rectangular regions out of the image
        ], p=0.2),
        # Solarize(p=1.0),  # invert all pixel values above a threshold
        # RandomRotate90(),
        # Flip(),
        # Transpose(),
    ], p=p)


def Resizeimg(p=1.0, width=320):
    return Compose([Resize(height=width, width=width, p=1)], p=p)


def Resizeimg128(p=1.0):
    return Compose([Resize(height=128, width=128, p=1)], p=p)


def imread(image):
    image = cv2.imread(image)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image.astype(np.uint8)
    return np.array(image)


def show(image):
    plt.imshow(image)
    plt.axis('off')
    plt.show()


mushusize = 512


def get_train_transforms(mushusize=mushusize):
    return Compose([
        RandomResizedCrop(mushusize, mushusize),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        CoarseDropout(p=0.5),
        Cutout(p=0.5),
        ToTensorV2(p=1.0),
    ], p=1.)


def get_valid_transforms(mushusize=mushusize):
    return Compose([
        Resize(mushusize, mushusize),
        CenterCrop(mushusize, mushusize, p=1.),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ], p=1.)


def get_inference_transforms(mushusize=mushusize):
    return Compose([
        RandomResizedCrop(mushusize, mushusize),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ], p=1.)


if __name__ == "__main__":
    '''
    imgdir = '/home/all_train/train'
    dataset = loadImagesandLabels(imgdir, transform=transforms.Compose([
        transforms.Resize((224, 224)),  # 64 for Bnnet, 224 for mobilenet / resnet
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.034528155, 0.033598177, 0.009853649),
                             std=(0.15804708, 0.16410254, 0.0643605))]))
    train_loader = DataLoader(dataset, batch_size=8, num_workers=1, collate_fn=dataset.collate_fn)
    # collate_fn=dataset.collate_fn
    # , label_random, lam,
    for i, (img, label, label_random, lam, path) in enumerate(train_loader):
        print(i, img, label, path)
        break
    '''
    b = '/home/4713766.jpg'
    a = imread(b)
    print(type(a))
    image2 = Resize(320, 320, p=1)(image=a)['image']
    print(type(image2))
    # image2 = strong_aug(p=1.0)(image=a)['image']
    # show(a)
    # show(image2)
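
A small, hedged usage sketch of the pieces above (the directory path is a placeholder, as in the commented __main__ example): loadImagesandLabels2 already returns CHW float tensors, so its output can go straight into a DataLoader, with strong_aug providing the augmentation.

img_dir = '/home/all_train/train'  # placeholder directory with one sub-folder per class
dataset = loadImagesandLabels2(img_dir, transform=strong_aug(p=1.0, width=224))
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=1)
for imgs, labels, paths in loader:
    print(imgs.shape, labels)      # e.g. torch.Size([8, 3, 224, 224])
    break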

mobilenetv2.py

Bnnet.py

resnet.py

resnest.py
