在完成数据准备之后,便可以使用 PyTorch 深度学习框架,实现卷积神经网络的定义、训练和预测。

一、模型搭建与训练

得到了数据之后,接下来咱们使用 PyTorch 这个框架来进行模型的训练。整个训练流程包括数据接口准备、模型定义、结果保存与分析。

1.1 数据接口准备

PyTorch 图像分类直接利用文件夹作为输入,只需要把不同类的数据放到不同的文件夹中。数据读取的完整代码如下:

data_transforms = {'train': transforms.Compose([transforms.Scale(64),transforms.RandomSizedCrop(48),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])]),'val': transforms.Compose([transforms.Scale(64),transforms.CenterCrop(48),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])]),
}data_dir = './train_val_data/'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x]) for x in ['train', 'val']}
dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],batch_size=16,shuffle=True,num_workers=4) for x in ['train', 'val']}

上面脚本中的函数,输入一个文件夹,输出图片路径以及标签,在开始训练之前需要将数据集进行拆分,拆分成训练集(train)和验证集(val),训练集和测试集的比例为9:1train_val_data文件结构如下所示,其中 0 代表 none、 1 代表pouting、2 代表 smile、3 代表 openmouth:

到此,数据接口就定义完毕了,接下来在训练代码中看如何使用迭代器进行数据读取就可以了。

1.2 模型定义

创建数据接⼝后,我们开始定义⼀个⽹络 simpleconv3

import torch.nn as nn
import torch.nn.functional as Fclass simpleconv3(nn.Module):def __init__(self):super(simpleconv3,self).__init__()self.conv1 = nn.Conv2d(3, 12, 3, 2)self.bn1 = nn.BatchNorm2d(12)self.conv2 = nn.Conv2d(12, 24, 3, 2)self.bn2 = nn.BatchNorm2d(24)self.conv3 = nn.Conv2d(24, 48, 3, 2)self.bn3 = nn.BatchNorm2d(48)self.fc1 = nn.Linear(48 * 5 * 5 , 1200)self.fc2 = nn.Linear(1200 , 128)self.fc3 = nn.Linear(128 , 4)def forward(self , x):x = F.relu(self.bn1(self.conv1(x)))#print "bn1 shape",x.shapex = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = x.view(-1 , 48 * 5 * 5) x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x

上面就是我们定义的网络,是一个简单的 3 层卷积。在 torch.nn 下,有各种网络层,这里就用到了 nn.Conv2d,nn.BatchNorm2d 和 nn.Linear,分别是卷积层,BN 层和全连接层。我们以一个卷积层为例:

conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=2)
bn1 = nn.BatchNorm2d(num_features=12)
  • in_channels:输入通道数
  • out_channels:输出通道数
  • kernel_size:卷积核的大小
  • stride:卷积核的移动步长

更全面的参数,请自查 API:PyTorch

1.3  模型训练

(深度学习一般使用 GPU 进行训练)

#coding:utf8
from __future__ import print_function, divisionimport torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import time
import os
from tensorboardX import SummaryWriter
import torch.nn.functional as F
import numpy as npimport warningswarnings.filterwarnings('ignore')writer = SummaryWriter()def train_model(model, criterion, optimizer, scheduler, num_epochs=25):for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))for phase in ['train', 'val']:if phase == 'train':scheduler.step()model.train(True)  # Set model to training modeelse:model.train(False)  # Set model to evaluate moderunning_loss = 0.0running_corrects = 0.0for data in dataloders[phase]:inputs, labels = dataif use_gpu:inputs = Variable(inputs.cuda())labels = Variable(labels.cuda())else:inputs, labels = Variable(inputs), Variable(labels)optimizer.zero_grad()outputs = model(inputs)_, preds = torch.max(outputs.data, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.data.item()running_corrects += torch.sum(preds == labels).item()epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects / dataset_sizes[phase]if phase == 'train':writer.add_scalar('data/trainloss', epoch_loss, epoch)writer.add_scalar('data/trainacc', epoch_acc, epoch)else:writer.add_scalar('data/valloss', epoch_loss, epoch)writer.add_scalar('data/valacc', epoch_acc, epoch)print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))writer.export_scalars_to_json("./all_scalars.json")writer.close()return modelif __name__ == '__main__':data_transforms = {'train': transforms.Compose([transforms.Scale(64),transforms.RandomSizedCrop(48),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])]),'val': transforms.Compose([transforms.Scale(64),transforms.CenterCrop(48),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])]),}data_dir = './Emotion_Recognition_File/train_val_data/' # 数据集所在的位置image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x]) for x in ['train', 'val']}dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],batch_size=64,shuffle=True if x=="train" else False,num_workers=8) for x in ['train', 'val']}dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}use_gpu = torch.cuda.is_available()print("是否使用 GPU", use_gpu)modelclc = simpleconv3()print(modelclc)if use_gpu:modelclc = modelclc.cuda()criterion = nn.CrossEntropyLoss()optimizer_ft = optim.SGD(modelclc.parameters(), lr=0.1, momentum=0.9)exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=100, gamma=0.1)modelclc = train_model(model=modelclc,criterion=criterion,optimizer=optimizer_ft,scheduler=exp_lr_scheduler,num_epochs=10)  # 这里可以调节训练的轮次if not os.path.exists("models"):os.mkdir('models')torch.save(modelclc.state_dict(),'models/model.ckpt')

训练的过程需要注意几个参数,第一个是数据加载器(dataloders)中的 batch_size,这个代表的含义是每次送入模型训练的图片数量,这个需要根据GPU的显存来设置,显存越大,可以设置越大,这个数一般设置为 2 的整数次幂(如 4、8、16、32 等)

dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],batch_size=64,shuffle=True if x=="train" else False,num_workers=8) for x in ['train', 'val']}

第二个需要注意的参数是训练函数的 num_epochs,这个参数代表的意义是,模型训练的轮次。

modelclc = train_model(model=modelclc,criterion=criterion,optimizer=optimizer_ft,scheduler=exp_lr_scheduler,num_epochs=10)  # 这里可以调节训练的轮次

模型测试

上⾯已经训练好了模型,我们接下来的⽬标,就是要⽤它来做推理,真正把模型⽤起来,下⾯我们载⼊⼀个图⽚,⽤模型进⾏测试。 结果在 results 文件夹中

import sys
import numpy as np
import cv2
import os
import dlibimport torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import time
from PIL import Image
import torch.nn.functional as Fimport matplotlib.pyplot as plt
import warningswarnings.filterwarnings('ignore')PREDICTOR_PATH = "./Emotion_Recognition_File/face_detect_model/shape_predictor_68_face_landmarks.dat"
predictor = dlib.shape_predictor(PREDICTOR_PATH)
cascade_path = './Emotion_Recognition_File/face_detect_model/haarcascade_frontalface_default.xml'
cascade = cv2.CascadeClassifier(cascade_path)if not os.path.exists("results"):os.mkdir("results")def standardization(data):mu = np.mean(data, axis=0)sigma = np.std(data, axis=0)return (data - mu) / sigmadef get_landmarks(im):rects = cascade.detectMultiScale(im, 1.3, 5)x, y, w, h = rects[0]rect = dlib.rectangle(int(x), int(y), int(x + w), int(y + h))return np.matrix([[p.x, p.y] for p in predictor(im, rect).parts()])def annotate_landmarks(im, landmarks):im = im.copy()for idx, point in enumerate(landmarks):pos = (point[0, 0], point[0, 1])cv2.putText(im,str(idx),pos,fontFace=cv2.FONT_HERSHEY_SCRIPT_SIMPLEX,fontScale=0.4,color=(0, 0, 255))cv2.circle(im, pos, 3, color=(0, 255, 255))return imtestsize = 48  # 测试图大小data_transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
net = simpleconv3()
net.eval()
modelpath = "./models/model.ckpt"  # 模型路径
net.load_state_dict(torch.load(modelpath, map_location=lambda storage, loc: storage))# 一次测试一个文件
img_path = "./Emotion_Recognition_File/find_face_img/"
imagepaths = os.listdir(img_path)  # 图像文件夹
for imagepath in imagepaths:im = cv2.imread(os.path.join(img_path, imagepath), 1)try:rects = cascade.detectMultiScale(im, 1.3, 5)x, y, w, h = rects[0]rect = dlib.rectangle(int(x), int(y), int(x + w), int(y + h))landmarks = np.matrix([[p.x, p.y]for p in predictor(im, rect).parts()])except:
#         print("没有检测到人脸")continue  # 没有检测到人脸xmin = 10000xmax = 0ymin = 10000ymax = 0for i in range(48, 67):x = landmarks[i, 0]y = landmarks[i, 1]if x < xmin:xmin = xif x > xmax:xmax = xif y < ymin:ymin = yif y > ymax:ymax = yroiwidth = xmax - xminroiheight = ymax - yminroi = im[ymin:ymax, xmin:xmax, 0:3]if roiwidth > roiheight:dstlen = 1.5 * roiwidthelse:dstlen = 1.5 * roiheightdiff_xlen = dstlen - roiwidthdiff_ylen = dstlen - roiheightnewx = xminnewy = yminimagerows, imagecols, channel = im.shapeif newx >= diff_xlen / 2 and newx + roiwidth + diff_xlen / 2 < imagecols:newx = newx - diff_xlen / 2elif newx < diff_xlen / 2:newx = 0else:newx = imagecols - dstlenif newy >= diff_ylen / 2 and newy + roiheight + diff_ylen / 2 < imagerows:newy = newy - diff_ylen / 2elif newy < diff_ylen / 2:newy = 0else:newy = imagerows - dstlenroi = im[int(newy):int(newy + dstlen), int(newx):int(newx + dstlen), 0:3]roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)roiresized = cv2.resize(roi,(testsize, testsize)).astype(np.float32) / 255.0imgblob = data_transforms(roiresized).unsqueeze(0)imgblob.requires_grad = Falseimgblob = Variable(imgblob)torch.no_grad()predict = F.softmax(net(imgblob))print(predict)index = np.argmax(predict.detach().numpy())im_show = cv2.imread(os.path.join(img_path, imagepath), 1)im_h, im_w, im_c = im_show.shapepos_x = int(newx + dstlen)pos_y = int(newy + dstlen)font = cv2.FONT_HERSHEY_SIMPLEXcv2.rectangle(im_show, (int(newx), int(newy)),(int(newx + dstlen), int(newy + dstlen)), (0, 255, 255), 2)if index == 0:cv2.putText(im_show, 'none', (pos_x, pos_y), font, 1.5, (0, 0, 255), 2)if index == 1:cv2.putText(im_show, 'pout', (pos_x, pos_y), font, 1.5, (0, 0, 255), 2)if index == 2:cv2.putText(im_show, 'smile', (pos_x, pos_y), font, 1.5, (0, 0, 255), 2)if index == 3:cv2.putText(im_show, 'open', (pos_x, pos_y), font, 1.5, (0, 0, 255), 2)
#     cv2.namedWindow('result', 0)
#     cv2.imshow('result', im_show)cv2.imwrite(os.path.join('results', imagepath), im_show)
#     print(os.path.join('results', imagepath))plt.imshow(im_show[:, :, ::-1])  # 这里需要交换通道,因为 matplotlib 保存图片的通道顺序是 RGB,而在 OpenCV 中是 BGRplt.show()
#     cv2.waitKey(0)
# cv2.destroyAllWindows()

再次说明:0 代表 none、 1 代表pouting、2 代表 smile、3 代表 openmouth

上面展示的图片上方会有一个输出,如:tensor([[8.1330e-03, 6.7033e-04, 9.8497e-01, 6.2311e-03]])

这个代表的含义是,该图片在这个模型预测下,是该类别的可能性,比如上面这个例子 9.8497e-01 是四个值最大的,它的索引是 2(从 0 开始算),所以预测该图片为 smile

天池训练营——基于人脸的常见表情识别(3)——模型搭建、训练与测试相关推荐

  1. 基于人脸的常见表情识别(3)——模型搭建、训练与测试

    基于人脸的常见表情识别(3)--模型搭建.训练与测试 模型搭建与训练 1. 数据接口准备 2. 模型定义 3. 模型训练 模型测试 本 Task 是『基于人脸的常见表情识别』训练营的第 3 课,如果你 ...

  2. 基于人脸的常见表情识别(1)——深度学习基础知识

    基于人脸的常见表情识别(1)--深度学习基础知识 神经网络 1. 感知机 2. 多层感知机与反向传播 卷积神经网络 1. 全连接神经网络的2大缺陷 2. 卷积神经网络的崛起 卷积神经网络的基本网络层 ...

  3. 基于人脸的常见表情识别(2)——数据获取与整理

    基于人脸的常见表情识别(2)--数据获取与整理 项目背景 数据获取 2.1 数据爬取 数据整理 3.1 图片格式统一 3.2 数据清洗 3.3 提取嘴唇区域 该 Task 就是本训练营的实战部分了,这 ...

  4. 基于人脸的常见表情识别 Task03笔记

    基于人脸的常见表情识别--模型搭建.训练与测试 模型搭建与训练 得到了数据之后,接下来咱们使用 PyTorch 这个框架来进行模型的训练.整个训练流程包括数据接口准备.模型定义.结果保存与分析. 1. ...

  5. 深度学习基础:基于人脸的常见表情识别(2)—数据获取与整理

    项目背景 数据获取 2.1 数据爬取 数据整理 3.1 图片格式统一 3.2 数据清洗 3.3 提取嘴唇区域 该 Task 就是本训练营的实战部分了,这一部分我们会讲解如何获取数据集,并对数据集进行整 ...

  6. 基于人脸的常见表情识别——模型搭建、训练与测试¶

    整个训练流程包括数据接口准备.模型定义.结果保存与分析. 数据接口一般使用torchvision.Dataset定义数据的读取.torch.utils.data.Dataloader定义数据的加载. ...

  7. Opencv基于改进VGG19的表情识别系统(源码&Fer2013&教程)

    1.研究背景 在深度学习中,传统的卷积神经网络对面部表情特征的提取不充分以及计算参数量较大的问题,导致分类准确率偏低.因此,提出了一种基于改进的VGG19网络的人脸表情识别算法.首先,对数据进行增强如 ...

  8. [实训题目EmoProfo]基于深度学习的表情识别服务搭建(一)

    基于深度学习的表情识别服务搭建(一) 文章目录 基于深度学习的表情识别服务搭建(一) 背景 识别服务设计 实现方式的选择 dlib性能验证 功能实现 小结 背景 之前我完成了终端和服务端之间交流的全部 ...

  9. python+opencv+dlib实现人脸检测与表情识别

    python+opencv+dlib实现人脸检测与表情识别 一,dlib简单介绍:Dlib包含广泛的机器学习算法.所有的设计都是高度模块化的,快速执行,并且通过一个干净而现代的C ++ API,使用起 ...

最新文章

  1. 【习题3】数字和数学计算【第4天】
  2. 利用Spring-Boot解析Excel、用Java分析Excel、告别手动输入用程序读取Excel
  3. HDU 1394 Minimum Inversion Number(线段树的单点更新)
  4. JAVA拾遗--关于SPI机制
  5. 笔记-信息化与系统集成技术-信息系统的特点
  6. 关于静态局部全局变量
  7. 在Python中升级灰度图像
  8. 百度声明:从未答应屏蔽三鹿负面
  9. 自学it18大数据笔记-第二阶段hadoop-day11——会持续更新……
  10. markdown方式测试图片2
  11. 将房子卖了五百万,存在银行,靠利息能过日子吗?
  12. 借助邮件网关,为企业实现节源开流
  13. Java后台开发一:环境搭建
  14. VMware 配置虚拟机固定IP指南
  15. mysql的软件卸载不了,控制台也卸载不了的问题
  16. 【机器学习】生成模型与判别模型详解
  17. 苹果系统忘记登录密码
  18. python安装卸载及查看python版本/第三方包版本
  19. mac-怎么查询mac苹果电脑的保修期
  20. Web-网上在线支付

热门文章

  1. 基于QT开发的开源局域网联机UNO卡牌游戏报告(附github仓库地址)
  2. NOIp2016 题解
  3. TCP原理,Socket与网络编程入门
  4. RuntimeWarning: Glyph 19979 missing from current font.
  5. 对Request.parameter中参数进行添加或修改
  6. elasticsearch搭建遇到的问题整理合集1
  7. 来自GDPR的警示:是时候关注您的客户数据了
  8. 一线OA品牌之心通达OA综合指标
  9. 中职教资证计算机应用,中职计算机教师资格证只能教中职学校的吗
  10. 十万条评论告诉你,给《流浪地球》评1星的都是什么心态? | Alfred数据室