基于ResNet迁移学习的LFW人脸识别分类

LFW数据集(Labeled Faces in the Wild)是马萨诸塞大学阿姆斯特分校计算机视觉研究所整理制作的一个非限制环境下人脸数据集,包含5749人合计13233张图片,图片大小都是250x250

本代码背景是一份CNN的人脸分类报告,仅需要完成简单的人脸分类即可,不需要完成人脸识别,因此就当作是人脸识别的简单入门,之后的话可能会根据自己的兴趣做一个人脸识别检测的demo程序用在树莓派上面

PS. 基于pytorch-gpu 1.5.1实现,但是为了通用性所以改成了cpu版本,需要使用gpu的同学请自行添加相应代码

数据集准备

下载数据集

可以到LFW官网上下载数据集,下载之后会有好几个压缩包,我们只需要其中的lfw.tgz文件,解压之后就得到了包含所有图片的文件夹

也可以直接拿我下好的数据集,下面是度娘链接

链接:https://pan.baidu.com/s/152iVUmPoMDQN_B94hJWETA
提取码:7a6h
复制这段内容后打开百度网盘手机App,操作更方便哦–来自百度网盘超级会员V5的分享(炫耀下我的v5的(~ ̄▽ ̄)~)

制作DataSet

考虑到LFW原始数据集中有很多人只有一张照片,也有部分名人,像布什这种一个人就有上百张照片,一方面为了保持每个人对应的人脸照片量合适,另一方面尽量减少需要分类的人的个数以减小网络大小方便训练,因此需要从LFW数据集中挑选一部分照片用于本次实验。这里最终挑选的是拥有30-100张照片的这部分人,共有29人,也就是说最终的CNN需要分类的个数为29类,对于小实验而言可以接受了

制作过程分为以下几步:

  1. 读取文件夹,获取图片及人名
  2. 挑选其中符合要求的人脸图片并将人名转换为整数标签
  3. 对人脸图片进行变换后和人名标签一起存入DataSet
  4. 定义DataLoader用于后续训练

PS. 在图像处理的时候,因为ResNet的图片输入大小是224x224,因此做了一个中心裁剪

class MyDataSet(Dataset):'''定义数据集,用于将读取到的图片数据转换并处理成CNN神经网络需要的格式'''def __init__(self, DataArray, LabelArray):super(MyDataSet, self).__init__()self.data = DataArrayself.label = LabelArraydef __getitem__(self, index):# 对图片的预处理步骤# 1. 中心缩放至224(ResNet的输入大小)# 2. 随机旋转0-30°# 3. 对图片进行归一化,参数来源为pytorch官方文档im_trans = transforms.Compose([transforms.ToPILImage(),transforms.CenterCrop(size=224),transforms.RandomRotation((0, 30)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])return im_trans(self.data[index]), t.tensor(self.label[index], dtype=t.long)def __len__(self):return self.label.shape[0]# 读取LFW数据集,将图片数据读入数组并将名字转换为标签
path = r'face+\lfw'
pathlist = map(lambda x: '\\'.join([path, x]), os.listdir(path))
namedict = {}
data, label = [], []
idx = 0
for item in pathlist:dirlist = os.listdir(item)# 选取拥有30-100张照片的人作为数据来源# 太少网络不容易学习到其人脸特征,太多的话则容易过拟合if not (30<= len(dirlist) <= 100):continue# data:     存储人像照片的三通道数据# label:    存储人像的对应标签(整数)# namedict: 记录label中整数与人名的对应关系for picpath in dirlist:data.append(image.imread(item + '\\' + picpath))label.append(idx)namedict[str(idx)] = item.split('\\')[-1]idx += 1# 随机打乱数据,重新排序并按照8:2的比例分割训练集和测试集
data, label = np.stack(data), np.array(label)
idx = np.random.permutation(data.shape[0])
data, label = data[idx], label[idx]
train_X, test_X, train_Y, test_Y = train_test_split(data, label, test_size=0.2)# 将分割好的训练集和测试集处理为pytorch所需的格式
TrainSet = MyDataSet(train_X, train_Y)
TestSet = MyDataSet(test_X, test_Y)
TrainLoader = DataLoader(TrainSet, batch_size=32, shuffle=True, drop_last=True)
TestLoader = DataLoader(TestSet, batch_size=32, shuffle=True, drop_last=True)

调用ResNet18

pytorch官方提供了很多CNN网络的现成版本可以直接调用,就不用自己费力去写了。而且官方提供的网络都有预训练版本,可以直接拿在ImageNet训练过的CNN网络在我们的简易LFW数据集上稍微训练微调,从而实现迁移学习,效果一般都会比较好。

考虑到我们简易LFW数据集的规模,用ResNet18就可以了,把pretrained属性设置为True使用预训练版本,初始使用的话会自动下载网络参数,需要等一会。ResNet18模型没办法直接运用在我们的数据集上,需要做如下三点变换

  1. 将输入图片的大小转为N x C x 224 x 244
  2. ResNet18网络中的requires_grad置为False,使其后续不参与训练更新(可设置也可以不设置,看哪个效果好而定,不过不更新ResNet网络参数的话训练更新会更快,但是通常效果会差一些)
  3. ResNet18网络的fc分类头改为适合我们数据集的大小
# 调用预训练的resnet18进行迁移学习
# resnet50参数量过多,训练效果不太好
resnet = models.resnet18(pretrained=True)
for param in resnet.parameters():param.requires_grad = False# 将resnet的输出fc(全连接层)替换为本任务所需的格式
# 1000-->256-->relu-->dropout-->29-->softmax
fc_inputs = resnet.fc.in_features
resnet.fc = nn.Sequential(nn.Linear(fc_inputs, 256),nn.ReLU(),nn.Dropout(),nn.Linear(256, 29)
)

进行迁移学习

之后的步骤就跟通常的CNN训练没有区别了,设置好参数按照模板进行训练即可,由于迁移学习的效果比较好,因此这里也不需要特别设置网络训练的参数,保持默认即可

# 定义交叉熵损失函数和Adam优化器(学习率,权重衰减使用默认值)
loss = nn.CrossEntropyLoss()
optimizer = t.optim.Adam(resnet.parameters())def train(net, dataloader, testdataloader, optimizer, criterion, epocs=20):# 以下四个参数分别用于存储训练和测试的损失函数值以及分类准确率train_loss_arr, train_acc_arr, test_loss_arr, test_acc_arr = [], [], [], []for epoc in range(epocs):net.train()TrainLoss, TrainAcc = 0, 0for BatchIdx, (InputData, Labels) in enumerate(dataloader):Outputs = net(InputData)optimizer.zero_grad()loss = criterion(Outputs.squeeze(), Labels)loss.backward()optimizer.step()TrainLoss += loss.item()_, pred = t.max(Outputs.data, 1)TrainAcc += t.mean(pred.eq(Labels.data.view_as(pred)).type(t.FloatTensor)).item() * len(InputData)if BatchIdx % 10 == 0 and BatchIdx > 0:print('Bathch: {}/{}\tLoss: {}\tAcc: {}%'.format(BatchIdx, len(dataloader), round(TrainLoss, 2), round(100*TrainAcc/((BatchIdx+1) * InputData.shape[0]), 2)))train_acc_arr.append(100*TrainAcc/(len(dataloader)*32))train_loss_arr.append(TrainLoss)TestLoss, TestAcc = 0, 0with t.no_grad():net.eval()for BatchIdx, (InputData, Labels) in enumerate(testdataloader):Outputs = net(InputData)loss = criterion(Outputs.squeeze(), Labels)TestLoss += loss.item()_, pred = t.max(Outputs.data, 1)TestAcc += t.mean(pred.eq(Labels.data.view_as(pred)).type(t.FloatTensor)).item() * len(InputData)print('Loss: {}\tAcc: {}%'.format(round(TrainLoss, 2),round(100*TestAcc/(len(testdataloader) * 32), 2)))print('-'*60)  test_acc_arr.append(100*TestAcc/(len(testdataloader)*32))test_loss_arr.append(TestLoss)return train_loss_arr, train_acc_arr, test_loss_arr, test_acc_arr# 进行训练并绘制训练曲线
train_loss_arr, train_acc_arr, test_loss_arr, test_acc_arr = train(resnet, TrainLoader, TestLoader, optimizer, loss)
fig = plt.figure()
ax1 = fig.add_subplot(121)
ax1.plot(train_loss_arr, label='train loss')
ax1.plot(test_loss_arr, label='test loss')
ax1.legend()
ax1.set_title('Loss Curve')
ax1.set_xlabel('epocs')
ax1.set_ylabel('loss')
ax2 = fig.add_subplot(122)
ax2.plot(train_acc_arr, label='train acc')
ax2.plot(test_acc_arr, label='test acc')
ax2.legend()
ax2.set_title('Accuracy Curve')
ax2.set_xlabel('epocs')
ax2.set_ylabel('loss')
plt.show()# 打印测试集的真实/预测结果
for InputData, Labels in enumerate(TestSet):Outputs = resnet(Labels[0].unsqueeze(0))_, pred = t.max(Outputs.data, 1)pred_name = namedict[str(pred.item())]real_name = namedict[str(Labels[1].item())]print('real name: {}\t\t\t\tpredict name: {}'.format(real_name, pred_name))
t.save(resnet, r'face+\resnet.pth')

模型分类结果

训练完成后模型的分类准确率训练集上差不多99%,测试集上最高可以到90%,还是比较符合预期了,毕竟整个网络其实没有进行太多的调整

lfw_test中的8张人脸照片进行测试,其中6张正确,2张错误,看了下分类错误的两张之一

左边是Jean Chretien(加拿大前总理),右边是大名鼎鼎的贝克汉姆,网络把总理的人脸照片错误识别成了贝克汉姆。讲道理,有一说一,我觉得没啥毛病,总理也挺帅的

基于ResNet迁移学习的LFW人脸识别分类相关推荐

  1. 蚂蚁金服张洁:基于深度学习的支付宝人脸识别技术解秘-1

    蚂蚁金服张洁:基于深度学习的支付宝人脸识别技术解秘(1) 2015-08-13 10:22 于雪 51CTO 字号:T | T 用户身份认证是互联网金融发展的基石.今年三月,在德国汉诺威举办的IT展览 ...

  2. 基于 CNN 和迁移学习的农作物病害识别方法研究

    基于 CNN 和迁移学习的农作物病害识别方法研究 1.研究思路 采用互联网公开的 ImageNet 图像大数据集和PlantVillage 植物病害公共数据集, 以实验室的黄瓜和水稻病害数据集 AES ...

  3. 学习笔记TF058:人脸识别

    人脸识别,基于人脸部特征信息识别身份的生物识别技术.摄像机.摄像头采集人脸图像或视频流,自动检测.跟踪图像中人脸,做脸部相关技术处理,人脸检测.人脸关键点检测.人脸验证等.<麻省理工科技评论&g ...

  4. 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战

    基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕 前言 运行python环境 ...

  5. 【深度学习】DIY 人脸识别技术的探索(二)

    [深度学习]DIY 人脸识别技术的探索(二) 文章目录 训练模型 工具 结果展示 问题二的模型建立与求解 基于 KNN 的人脸识别模型 训练模型 MTCNN 可以并行训练(3 个网络同时训练,前提是内 ...

  6. 【项目实战课】基于Pytorch的MTCNN与Centerloss人脸识别实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的MTCNN与Centerloss人脸识别实战>. 所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个 ...

  7. python人脸识别库_基于Python的face_recognition库实现人脸识别

    Python Python开发 Python语言 基于Python的face_recognition库实现人脸识别 一.face_recognition库简介 face_recognition是Pyt ...

  8. 深度学习下的人脸识别技术:从“后真相”到“无隐私”

    2019-06-17 14:27:08 图片来源@视觉中国 文|五矩研究社,作者|劫镖 2018年7月,<大西洋月刊>曾发表过一篇人脸识别的文章,名字叫做<开启假视频时代>,文 ...

  9. 【毕业设计_课程设计】基于opencv、dilb的员工人脸识别考勤系统

    文章目录 0 项目说明 1 需求分析 2 总体设计 3 详细设计 4 效果展示 5 实验心得 6 项目源码 7 最后 0 项目说明 基于opencv.dilb的员工人脸识别考勤系统 提示:适合用于课程 ...

  10. 深度学习之视频人脸识别系列一:介绍

    作者 | 东田应子 [导读]本文是深度学习之视频人脸识别系列的第一篇文章,介绍了人脸识别领域的一些基本概念,分析了深度学习在人脸识别的基本流程,并总结了近年来科研领域的研究进展,最后分析了静态数据与视 ...

最新文章

  1. 客户资料查询传递数据格式
  2. Groovy正则表达式复杂逻辑判断实例
  3. 如何用Word 2007写Blog
  4. C++细节系列(零):零散记录
  5. Linux文件属性1——文件类型
  6. ioc spring技术手册学习
  7. python如何输出整数逆序_python字符串类型及操作
  8. FFmpeg之获取音视频信息(二十八)
  9. python连接SQLServer数据库创建数据表同时为每个字段加上对应的中文注释信息
  10. Spring整合JsonRpc
  11. python攻击校园网_python爬虫 模拟登陆校园网-初级
  12. 如何帮助空降经理人成功?
  13. requests库入门-14-HTTP基本认证
  14. 算法自学笔记:Convex Hull问题
  15. C语言 char 和 signed char的区别
  16. redis中的incr和incrBy
  17. emWin 2天速成实例教程012_基于STM32单片机的全键盘中文汉字拼音输入法
  18. js面向对象prototype
  19. <<计算机操作系统(慕课版)>>第五章参考答案
  20. 关于模拟题的一些弱鸡总结

热门文章

  1. GB35114视频流处理
  2. protues7 使用笔记
  3. Python实战RBF神经网络
  4. PreferenceScreen1
  5. DolbyAudio访问杜比音效驱动程序时发生问题,请重新启动计算机或......
  6. 软件项目管理 2.2.项目招投标流程
  7. 简单理解t检验与秩和检验
  8. WEB专用服务器的安全设置
  9. 局域网不同网段远程桌面_自动化已非原来的自动化:看虚拟局域网技术应用到罗克韦尔的DCS...
  10. * Redis —— Scan、SScan、HScan、ZScan