CV Algorithm Reproduction (Classification Algorithms 5/6): ResNet (2015, Microsoft Research Asia)
Acknowledgments: 霹雳吧啦Wz: https://space.bilibili.com/18161609
Contents
Acknowledgments
1 Key Points
1.1 Deep Learning Theory
2 Network Overview
2.1 Historical Significance
2.2 Highlights
2.3 Network Architecture
3 Code Structure
3.1 model.py
3.2 train.py
3.3 predict.py
1 Key Points
1.1 Deep Learning Theory
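- BN layer: normalizes the feature map so that each channel has zero mean and unit variance, with the statistics computed over the whole batch of images (not over a single image), and then applies a learnable per-channel scale and shift.
- Note 1: when using BN, make the batch size as large as possible, because the batch statistics then approximate the mean and variance of the whole dataset more closely; with a batch size of 1, BN may do more harm than good.
- Note 2: place BN between the convolution and the activation (e.g. ReLU), and do not give that convolution a bias: BN subtracts the per-channel mean, so any constant bias cancels out anyway.
- For details, see: (霹雳吧啦Wz) Batch Normalization详解以及pytorch实验 (Batch Normalization explained, with PyTorch experiments): https://blog.csdn.net/qq_37541097/article/details/104434557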
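- Transfer learning: when you use someone else's pretrained model, you must know (and reproduce) the preprocessing it was trained with.
- Common transfer learning approaches:
1. Load the weights, then train all parameters
2. Load the weights, then train only the last few layers
3. Load the weights, append one more fully connected layer to the original network, and train only that last FC layer
- Effect of transfer learning: on the flower classification task, training from scratch reached only 89% after dozens of epochs, while starting from an ImageNet-pretrained model reached 90.9% after a single epoch.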
- Residual structure: the shortcut and the main branch are combined by adding their feature maps element-wise.
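A minimal NumPy sketch of the BN notes above, covering training mode only (the running mean/variance used at inference time are left out). `batch_norm_train` is a hypothetical helper, not a PyTorch API. It takes per-channel statistics over the batch and spatial axes, and shows that a constant per-channel bias added before BN has no effect on the output:

```python
import numpy as np


def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Training-mode BN for an NCHW array: one mean/variance per channel over (N, H, W)."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)   # per-channel mean over the whole batch
    var = x.var(axis=(0, 2, 3), keepdims=True)     # per-channel (biased) variance
    x_hat = (x - mean) / np.sqrt(var + eps)        # zero mean, unit variance
    # learnable per-channel scale (gamma) and shift (beta)
    return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4, 5, 5))  # batch of 8 maps, 4 channels, 5x5
gamma, beta = np.ones(4), np.zeros(4)

y = batch_norm_train(x, gamma, beta)
print(y.mean(axis=(0, 2, 3)).round(6))  # ~0 for every channel
print(y.var(axis=(0, 2, 3)).round(3))   # ~1 for every channel

# Note 2: a per-channel conv bias is removed by the mean subtraction.
bias = rng.normal(size=4).reshape(1, -1, 1, 1)
y_with_bias = batch_norm_train(x + bias, gamma, beta)
print(np.allclose(y, y_with_bias))      # True
```

This is why the convolutions that feed a BN layer in model.py are built with `bias=False`.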
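A sketch of approach 2 in PyTorch, restricted to the final layer: freeze every loaded parameter, swap in a new 5-class head, and give the optimizer only the trainable parameters. `TinyNet` is a hypothetical stand-in for the pretrained ResNet (it only shares the `fc` attribute name), so no weights need to be downloaded; the commented-out freeze loop in train.py follows the same pattern.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained backbone with an `fc` classifier head."""

    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


net = TinyNet()  # in practice: build the model, then load_state_dict(pretrained weights)

# Freeze everything that was loaded.
for param in net.parameters():
    param.requires_grad = False

# Replace the head; a freshly created layer is trainable by default.
net.fc = nn.Linear(net.fc.in_features, 5)  # 5 flower classes

trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(trainable)  # ['fc.weight', 'fc.bias']

# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.Adam((p for p in net.parameters() if p.requires_grad), lr=1e-4)
```

Filtering the optimizer's parameters is optional, since frozen tensors never receive gradients, but it makes the intent explicit.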
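2 Network Overview
2.1 Historical Significance
ResNet was proposed by Microsoft Research Asia in 2015. That year it won first place in both the classification and the object detection tasks of the ImageNet competition, as well as first place in object detection and in image segmentation on the COCO dataset.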
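2.2 Highlights
- Residual modules (the shortcut is merged by addition rather than channel concatenation, so the two branches must agree in C, H, and W)
- Batch Normalization to speed up training (dropout is dropped)
- Ultra-deep architectures (more than 1000 layers)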
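To see why the addition requires identical C, H and W, here is a NumPy sketch that uses random arrays in place of real feature maps, with shapes matching the first block of layer2 in model.py (64 to 128 channels, stride 2). The projection shortcut is written as a strided 1x1 convolution via `np.einsum`; `w_proj` is a made-up kernel, and the BatchNorm that `_make_layer` places after the projection is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 64, 56, 56))       # block input: N, C, H, W

# Stand-in for the main branch output: H/W halved, C doubled.
main = rng.normal(size=(2, 128, 28, 28))

try:
    main + x                                # identity (solid-line) shortcut: shapes differ
except ValueError as err:
    print("solid-line shortcut fails:", err)

# Dashed-line shortcut: a 1x1 conv with stride 2 mixes channels at every second pixel.
w_proj = rng.normal(size=(128, 64)) * 0.01  # hypothetical 1x1 kernel, (out_c, in_c)
identity = np.einsum("oc,nchw->nohw", w_proj, x[:, :, ::2, ::2])

out = np.maximum(main + identity, 0)        # element-wise add, then ReLU
print(out.shape)                            # (2, 128, 28, 28)
```

Channel concatenation (as in GoogLeNet's Inception module) would only need H and W to match; addition needs all three dimensions to agree.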
2.3 Network Architecture
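The block type and the per-stage block counts (the `blocks_num` argument in model.py) determine the depth in each variant's name. The check below counts only weighted layers on the main path: conv1, the convolutions inside every block, and the final fc; 1x1 projection shortcuts and pooling are not counted. model.py defines only the 34- and 101-layer versions; the 18/50/152 entries are the standard configurations from the ResNet paper.

```python
# Depth = convs per block * total blocks + conv1 (stem) + fc.
CONFIGS = {
    "resnet18": ("BasicBlock", [2, 2, 2, 2]),
    "resnet34": ("BasicBlock", [3, 4, 6, 3]),
    "resnet50": ("Bottleneck", [3, 4, 6, 3]),
    "resnet101": ("Bottleneck", [3, 4, 23, 3]),
    "resnet152": ("Bottleneck", [3, 8, 36, 3]),
}
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}  # 3x3+3x3 vs 1x1+3x3+1x1


def depth(block, blocks_num):
    return CONVS_PER_BLOCK[block] * sum(blocks_num) + 2


for name, (block, blocks_num) in CONFIGS.items():
    print(name, depth(block, blocks_num))  # resnet18 18, resnet34 34, ..., resnet152 152
```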
3 Code Structure
- model.py
- train.py
- predict.py
3.1 model.py
import torch
import torch.nn as nn


# Residual block for ResNet-18/34 (supports both the solid-line identity shortcut
# and the dashed-line projection shortcut)
class BasicBlock(nn.Module):
    # Ratio of the block's output channels to out_channel; 1 means every conv in the block has the same number of kernels
    expansion = 1

    # downsample: the projection used by the dashed-line shortcut (None means identity shortcut)
    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)  # no bias: BN follows
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity  # add the shortcut before the final ReLU
        out = self.relu(out)
        return out


# Residual block for ResNet-50 and deeper (supports both shortcut types)
class Bottleneck(nn.Module):
    # The last 1x1 conv has 4x as many kernels as the first two
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)
        return out


# block: BasicBlock or Bottleneck
# blocks_num: number of residual blocks in each stage, e.g. [3, 4, 6, 3] or [2, 2, 2, 2]
# include_top: whether to build the avgpool + fc head (default True); setting it to False
#   makes it easier to build more complex structures on top of ResNet
class ResNet(nn.Module):
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            # Adaptive average pooling: whatever the input size, the output H x W is 1 x 1
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Initialize every conv, with or without the classification head
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    # Build one stage of residual blocks
    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)  # unpack the list into positional arguments

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
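As a sanity check of the shapes in model.py, the usual output-size formula floor((H + 2p - k) / s) + 1 can trace a 224x224 input (the crop size used in train.py) through the stem and the four stages. `conv_out` is a hypothetical helper; this is plain arithmetic, not part of the original project.

```python
def conv_out(size, kernel, stride, padding):
    """Output height/width of a conv or pooling layer (floor division, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


h = 224
h = conv_out(h, kernel=7, stride=2, padding=3)  # conv1: 224 -> 112
h = conv_out(h, kernel=3, stride=2, padding=1)  # maxpool: 112 -> 56 (layer1 keeps 56)
sizes = [h]
for _ in range(3):  # layer2..layer4: the first 3x3 conv of each stage has stride 2
    h = conv_out(h, kernel=3, stride=2, padding=1)
    sizes.append(h)
print(sizes)  # [56, 28, 14, 7]

# The stride-2 1x1 projection shortcut shrinks 56 to the same 28 as the main branch.
print(conv_out(56, kernel=1, stride=2, padding=0))  # 28
```

The 7x7 map left after layer4 is what `AdaptiveAvgPool2d((1, 1))` collapses before the fc layer, whose input size is therefore `512 * block.expansion`.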
3.2 train.py
import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets

from model import resnet34, resnet101
# import torchvision.models.resnet  # PyTorch's built-in ResNet; the model.py above is adapted from it


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Use the same ImageNet normalization the pretrained weights were trained with
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),  # scale the shorter side to 256
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # Instantiate with the default 1000 classes so the pretrained fc weights match
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path, map_location="cpu"),
                                                        strict=False)
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)  # the flower dataset has 5 classes (the default head has 1000)
    net.to(device)

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    best_acc = 0.0
    save_path = './resNet34.pth'
    for epoch in range(3):
        # train
        net.train()
        running_loss = 0.0
        for step, data in enumerate(train_loader, start=0):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            # print train process
            rate = (step + 1) / len(train_loader)
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss.item()), end="")
        print()

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            for val_data in validate_loader:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        # average over all batches (step is 0-based, so divide by len(train_loader))
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / len(train_loader), val_accurate))

    print('Finished Training')


if __name__ == '__main__':
    main()
3.3 predict.py
import json
import sys

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import resnet34

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# load image
img = Image.open("../tulip.jpg").convert("RGB")
plt.imshow(img)
# [C, H, W]
img = data_transform(img)
# expand batch dimension -> [N, C, H, W]
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    sys.exit(-1)

# create model
model = resnet34(num_classes=5).to(device)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img.to(device))).cpu()
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).item()
# json.dumps turned the integer keys written by train.py into strings, hence str()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()