导师的课题需要用到图片分类;入门萌新啥也不会,只需要实现这个功能,给出初步效果,不需要花太多时间了解内部逻辑。经过一周的摸索,建好环境、pytorch,终于找到整套的代码和数据集,实现了一个小小的分类。记录一下使用方法,避免后续使用时遗忘。感谢各位大佬的开源代码和注释!

找到一个大佬的视频讲解和代码开源:

github:

https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/data_set

bilbil:

6.2 使用pytorch搭建ResNet并基于迁移学习训练_哔哩哔哩_bilibili

参考的拆分讲解:

pytorch图像分类篇:6. ResNet网络结构详解与迁移学习简介_fun1024-CSDN博客_resnet网络结构

一、数据处理

项目文件夹为Project2,使用的是五种花朵的数据集,首先有spilt_data的代码将已经分好文件夹的数据集分类成测试集和训练集。Project2下建data_set文件夹,data_set下文件目录为:

1.------spilt_data.py #用于分类训练集和测试集的代码

2-------flower_data #已经分好文件夹的数据

2.1--------------flower_photos

2.1.1-------------------daisy

2.1.2-------------------dandelion

2.1.3-------------------roses

2.1.4-------------------sunflowers

2.1.5-------------------tulips

2.1.6-------------------LICENSE.txt

Spilt_data的代码如下:

import os
from shutil import copy, rmtree
import random
#flower_data文件夹必须和spilt_data程序在同一级中def mk_file(file_path: str):if os.path.exists(file_path):# 如果文件夹存在,则先删除原文件夹在重新创建rmtree(file_path)os.makedirs(file_path)def main():# 保证随机可复现random.seed(0)# 将数据集中10%的数据划分到验证集中split_rate = 0.1# 指向你解压后的flower_photos文件夹cwd = os.getcwd()data_root = os.path.join(cwd, "position_data")origin_flower_path = os.path.join(data_root, "position_photos")assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)flower_class = [cla for cla in os.listdir(origin_flower_path)if os.path.isdir(os.path.join(origin_flower_path, cla))]# 建立保存训练集的文件夹train_root = os.path.join(data_root, "train")mk_file(train_root)for cla in flower_class:# 建立每个类别对应的文件夹mk_file(os.path.join(train_root, cla))# 建立保存验证集的文件夹val_root = os.path.join(data_root, "val")mk_file(val_root)for cla in flower_class:# 建立每个类别对应的文件夹mk_file(os.path.join(val_root, cla))for cla in flower_class:cla_path = os.path.join(origin_flower_path, cla)images = os.listdir(cla_path)num = len(images)# 随机采样验证集的索引eval_index = random.sample(images, k=int(num*split_rate))for index, image in enumerate(images):if image in eval_index:# 将分配至验证集中的文件复制到相应目录image_path = os.path.join(cla_path, image)new_path = os.path.join(val_root, cla)copy(image_path, new_path)else:# 将分配至训练集中的文件复制到相应目录image_path = os.path.join(cla_path, image)new_path = os.path.join(train_root, cla)copy(image_path, new_path)print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing barprint()print("processing done!")if __name__ == '__main__':main()

运行完成后文件夹目录为:

Project2

---data_set  #已经分好文件夹的数据集

1.------spilt_data.py #用于分类训练集和测试集的代码

2.-------flower_data

2.1--------------flower_photos

2.1.1-------------------daisy

2.1.2-------------------dandelion

2.1.3-------------------roses

2.1.4-------------------sunflowers

2.1.5-------------------tulips

2.1.6-------------------LICENSE.txt

2.2--------------train

2.2.1-------------------daisy

2.2.2-------------------dandelion

2.2.3-------------------roses

2.2.4-------------------sunflowers

2.2.5-------------------tulips

2.3--------------val

2.3.1-------------------daisy

2.3.2-------------------dandelion

2.3.3-------------------roses

2.3.4-------------------sunflowers

2.3.5-------------------tulips

至此,完成数据集分类。

二、Model

由于只是使用,没有对其中过多了解,总之感谢开源的各位大佬!

import torch.nn as nn
import torch# ResNet18/34的残差结构,用的是2个3x3的卷积
class BasicBlock(nn.Module):expansion = 1  # 残差结构中,主分支的卷积核个数是否发生变化,不变则为1def __init__(self, in_channel, out_channel, stride=1, downsample=None):  # downsample对应虚线残差结构super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:  # 虚线残差结构,需要下采样identity = self.downsample(x)  # 捷径分支 short cutout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += identityout = self.relu(out)return out# ResNet50/101/152的残差结构,用的是1x1+3x3+1x1的卷积
class Bottleneck(nn.Module):expansion = 4  # 残差结构中第三层卷积核个数是第一/二层卷积核个数的4倍def __init__(self, in_channel, out_channel, stride=1, downsample=None):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(out_channel)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(out_channel)# -----------------------------------------self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)  # 捷径分支 short cutout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return outclass ResNet(nn.Module):# block = BasicBlock or Bottleneck# block_num为残差结构中conv2_x~conv5_x中残差块个数,是一个列表def __init__(self, block, blocks_num, num_classes=1000, include_top=True):super(ResNet, self).__init__()self.include_top = include_topself.in_channel = 64self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, blocks_num[0])             # conv2_xself.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)  # conv3_xself.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)  # conv4_xself.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)  # conv5_xif self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')# channel为残差结构中第一层卷积核个数def _make_layer(self, block, channel, block_num, stride=1):downsample = None# ResNet50/101/152的残差结构,block.expansion=4if stride != 1 or self.in_channel != channel * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))self.in_channel = channel * block.expansionfor _ in range(1, block_num):layers.append(block(self.in_channel, channel))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)if self.include_top:x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet34(num_classes=1000, include_top=True):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

三、Train

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
image_path = data_root + "/pycharmProject/Project2/data_set/flower_data/"  # data set path 这里需要改,在这个路径里找traintrain_dataset = datasets.ImageFolder(root=image_path+"train",transform=data_transform["train"])
train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)validate_dataset = datasets.ImageFolder(root=image_path + "val",transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=0)net = resnet34()
# load pretrain weights
model_weight_path = "./resnet34-pre.pth"
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)
# for param in net.parameters():
#     param.requires_grad = False
# change fc layer structure
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)
net.to(device)loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)best_acc = 0.0
save_path = './resNet34.pth'
for epoch in range(3):# trainnet.train()running_loss = 0.0for step, data in enumerate(train_loader, start=0):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()# print train processrate = (step+1)/len(train_loader)a = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")print()# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))  # eval model only have last output layer# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == val_labels.to(device)).sum().item()val_accurate = acc / val_numif val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')

四、predict and batch_predict

单张图片Predict:

import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import jsondata_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load image
img = Image.open("G:\\pycharmProject/Project2/tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)# read class_indict
try:json_file = open('./class_indices.json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():# predict classoutput = torch.squeeze(model(img))predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.show()

batch predict:

import os
import jsonimport torch
from PIL import Image
from torchvision import transformsfrom model import resnet34def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load image# 指向需要遍历预测的图像文件夹imgs_root = "G:\pycharmProject\Project2\predict_batch\imgs"assert os.path.exists(imgs_root), f"file: '{imgs_root}' dose not exist."# 读取指定文件夹下所有jpg图像路径img_path_list = [os.path.join(imgs_root, i) for i in os.listdir(imgs_root) if i.endswith(".jpg")]# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), f"file: '{json_path}' dose not exist."json_file = open(json_path, "r")class_indict = json.load(json_file)# create modelmodel = resnet34(num_classes=5).to(device)# load model weightsweights_path = "./resNet34.pth"assert os.path.exists(weights_path), f"file: '{weights_path}' dose not exist."model.load_state_dict(torch.load(weights_path, map_location=device))# predictionmodel.eval()batch_size = 8  # 每次预测时将多少张图片打包成一个batchwith torch.no_grad():for ids in range(0, len(img_path_list) // batch_size):img_list = []for img_path in img_path_list[ids * batch_size: (ids + 1) * batch_size]:assert os.path.exists(img_path), f"file: '{img_path}' dose not exist."img = Image.open(img_path)img = data_transform(img)img_list.append(img)# batch img# 将img_list列表中的所有图像打包成一个batchbatch_img = torch.stack(img_list, dim=0)# predict classoutput = model(batch_img.to(device)).cpu()predict = torch.softmax(output, dim=1)probs, classes = torch.max(predict, dim=1)for idx, (pro, cla) in enumerate(zip(probs, classes)):print("image: {}  class: {}  prob: {:.3}".format(img_path_list[ids * batch_size + idx],class_indict[str(cla.numpy())],pro.numpy()))if __name__ == '__main__':main()

五、下载预训练的模型参数

下载地址如下所示,浏览器直接跳转下载就行。自己用的是resnet34就足够了。

model_urls = {'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth','resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth','resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth','resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth','resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth','resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth','resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth','wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth','wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

下载好的文件更名为resnet34-pre.pth

六、检查所有需要的程序、数据集和模型参数

打开pycharm,打开各py文件,点击train,右键运行,训练完毕之后,切换到predict.py程序。预测之前选取测试集中的一张图片,命名为tulip.jpg;如果是批预测,选取测试集中的各类别的图片组成文件夹predict_batch/imgs/ ..... .jpg.

预测程序中写的路径为:

G:\\pycharmProject/Project2/tulip.jpg

批预测程序的路径为:

G:\\pycharmProject\Project2\predict_batch\imgs

注意这里需要修改。

七、运行和结果

这样应该算是预测成功了。

[笔记]Pytorch框架下的入门应用:resnet34实现分类相关推荐

  1. 我的实践:pytorch框架下基于BERT实现文本情感分类

    当前,在BERT等预训练模型的基础上进行微调已经成了NLP任务的一个定式了.为了了解BERT怎么用,在这次实践中,我实现了一个最简单的NLP任务,即文本情感分类. 文章目录 1.基于BERT进行情感分 ...

  2. 1.Pytorch框架下使用yolov3-tiny网络模型 训练自己的数据集

    在Pytorch框架下使用yolov3-tiny网络模型 ,训练自己的数据集 1.本文参考链接如下: https://blog.csdn.net/gbz3300255/article/details/ ...

  3. 踩坑记录: Pytorch框架下--- 从零使用卷积神经网络实现人脸面部表情识别 (基于连续维度)

    之前一直在自学深度神经网络的知识,在跟着书本一步一步走的时候,感觉每一个思路,每一句代码都特别容易,实现思路清晰明了,实验代码简单易懂.但当我真正课题需要用到的时候,想跳出书本的框架,自行实现并通透其 ...

  4. pytorch框架下faster rcnn使用softnms

    pytorch faster rcnn softnms frcnn使用softnms方法一:pytorch复现版本的cpu版softnms(本方法可以跑通) 0. 首先overview一波:infer ...

  5. NLP学习笔记-Pytorch框架(补充)

    PDF Pytorch初步应用 使用Pytorch构建一个神经网络 学习目标 掌握用Pytorch构建神经网络的基本流程. 掌握用Pytorch构建神经网络的实现过程. 关于torch.nn: 使用P ...

  6. l1、l2正则化在pytorch框架下的实现方式

    转载PyTorch训练模型添加L1/L2正则化的两种实现方式_hlld__的博客-CSDN博客_pytorch添加正则化 在使用PyTorch训练模型时,可使用三种方式添加L1/L2正则化:一种是添加 ...

  7. 【个人网站搭建】GitHub pages+hexo框架下为next主题添加菜单分类页面

    0x00 前言 文章中的文字可能存在语法错误以及标点错误,请谅解: 如果在文章中发现代码错误或其它问题请告知,感谢! Hexo博客框架版本(hexo vesion):5.3.0 Next主题版本:v5 ...

  8. Pytorch框架中SGD&Adam优化器以及BP反向传播入门思想及实现

    因为这章内容比较多,分开来叙述,前面先讲理论后面是讲代码.最重要的是代码部分,结合代码去理解思想. SGD优化器 思想: 根据梯度,控制调整权重的幅度 公式: 权重(新) = 权重(旧) - 学习率 ...

  9. paddle.paramattr转换为torch框架下算法

    paddle.paramattr是PaddlePaddle框架中用于表示网络层参数的属性类.如果想要将其转换为PyTorch框架下的算法,需要使用torch.nn.Parameter类. 具体而言,需 ...

最新文章

  1. Android 2018最新验证手机号正则表达式
  2. 基于xml进行bean装配
  3. mysql $gt_mysql变量(用户+系统)
  4. python语言是编译性语音_最强编程语言 Java 和最受欢迎之 Python 的巅峰对决
  5. 小型云台用的是什么电机_直流电机的工作原理是什么?未来的电动车都会用直流电机吗?...
  6. Asp.Net 构架(HttpModule 介绍) - Part.3
  7. mysql8基本操作
  8. mysql 过滤单引号_python实现mysql的单引号字符串过滤方法
  9. 红帽干掉 CentOS 8,CentOS Stream 上位
  10. 软件平台与中间技术复习
  11. ArcMap进行地图标注与注记
  12. Android源码分析(十三)----SystemUI下拉状态栏如何添加快捷开关
  13. Z世代成为消费新主力,我国潮牌营销洞察报告​
  14. 操作系统正则符号知识点总结
  15. SQL查询语句(从单表到多表、从简单到复杂)
  16. 解决tar 时间戳XXX是未来的XXX秒之后问题
  17. 马尔可夫决策过程(Markov Decision Process)学习笔记
  18. 《电磁场与电磁波》---恒定电场思维导图
  19. 关于比尔盖茨的几点思考
  20. Codeforces 432C (哥德巴赫猜想的巧妙应用)

热门文章

  1. OSG_64位动态链接库+静态链接库的使用
  2. html文字抖动效果,CSS实现TikTok文字抖动效果示例
  3. python中文字符串比较模块_python比较字符串相似度,原创度检测工具
  4. 解析下载blob视频
  5. 最新国产电源厂家及具体型号pin-to-pin替代手册发布
  6. VIM 参 考 手 册[转]
  7. uniapp在onLoad事件中不显示showToast的问题
  8. grub4dos 启动ubuntu 12.04
  9. 哪个计算机无法做到双屏显示,如何实现笔记本电脑的双屏显示
  10. Python程序员必备——手把手教你配置最漂亮的PyCharm界面