「@Author:Runsen」

上次微调了Alexnet,这次微调ResNet实现男人和女人图像分类。

ResNet是 Residual Networks 的缩写,是一种经典的神经网络,用作许多计算机视觉任务。

  • ResNet论文参见此处:

https://arxiv.org/abs/1512.03385

该模型是 2015 年 ImageNet 挑战赛的获胜者。ResNet 的根本性突破是它使我们能够成功训练 150 层以上的极深神经网络。

下面是resnet18的整个网络结构:

Resnet 18 是在 ImageNet 数据集上预训练的图像分类模型。

这次使用Resnet 18 实现分类性别数据集,

该性别分类数据集共有58,658 张图像。(train:47,009 / val:11,649)

female

male
  • Dataset: Kaggle Gender Classification Dataset

加载数据集

设置图像目录路径并初始化 PyTorch 数据加载器。和之前一样的模板套路

import torch
import torch.nn as nn
import torch.optim as optimimport torchvision
from torchvision import datasets, models, transformsimport numpy as np
import matplotlib.pyplot as pltimport time
import osdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # device objecttransforms_train = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(), # data augmentationtransforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # normalization
])transforms_val = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])data_dir = './gender_classification_dataset'
train_datasets = datasets.ImageFolder(os.path.join(data_dir, 'Training'), transforms_train)
val_datasets = datasets.ImageFolder(os.path.join(data_dir, 'Validation'), transforms_val)train_dataloader = torch.utils.data.DataLoader(train_datasets, batch_size=16, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_datasets, batch_size=16, shuffle=True, num_workers=4)print('Train dataset size:', len(train_datasets))
print('Validation dataset size:', len(val_datasets))class_names = train_datasets.classes
print('Class names:', class_names)

plt.rcParams['figure.figsize'] = [12, 8]
plt.rcParams['figure.dpi'] = 60
plt.rcParams.update({'font.size': 20})def imshow(input, title):# torch.Tensor => numpyinput = input.numpy().transpose((1, 2, 0))# undo image normalizationmean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])input = std * input + meaninput = np.clip(input, 0, 1)# display imagesplt.imshow(input)plt.title(title)plt.show()# load a batch of train image
iterator = iter(train_dataloader)# visualize a batch of train image
inputs, classes = next(iterator)
out = torchvision.utils.make_grid(inputs[:4])
imshow(out, title=[class_names[x] for x in classes[:4]])

定义模型

我们使用迁移学习方法,只需要修改最后的输出即可。

model = models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2) # binary classification (num_of_class == 2)
model = model.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

训练阶段

由于ResNet18网络非常复杂,深,这里只训练num_epochs = 3

num_epochs = 3
start_time = time.time()for epoch in range(num_epochs):""" Training  """model.train()running_loss = 0.running_corrects = 0# load a batch data of imagesfor i, (inputs, labels) in enumerate(train_dataloader):inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# get loss value and update the network weightsloss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(train_datasets)epoch_acc = running_corrects / len(train_datasets) * 100.print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))""" Validation"""model.eval()with torch.no_grad():running_loss = 0.running_corrects = 0for inputs, labels in val_dataloader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(val_datasets)epoch_acc = running_corrects / len(val_datasets) * 100.print('[Validation #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))

「保存训练好的模型文件」

save_path = 'face_gender_classification_transfer_learning_with_ResNet18.pth'
torch.save(model.state_dict(), save_path)

「训练好的模型文件加载」

model = models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)
model.load_state_dict(torch.load(save_path))
model.to(device)model.eval()
start_time = time.time()with torch.no_grad():running_loss = 0.running_corrects = 0for i, (inputs, labels) in enumerate(val_dataloader):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if i == 0:print('[Prediction Result Examples]')images = torchvision.utils.make_grid(inputs[:4])imshow(images.cpu(), title=[class_names[x] for x in labels[:4]])images = torchvision.utils.make_grid(inputs[4:8])imshow(images.cpu(), title=[class_names[x] for x in labels[4:8]])epoch_loss = running_loss / len(val_datasets)epoch_acc = running_corrects / len(val_datasets) * 100.print('[Validation #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))

在最后的测试结果中,ACC达到了97,但是模型太复杂,运行太慢了,在项目中往往不可取。


往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑黄海广老师《机器学习课程》课件合集
本站qq群851320808,加入微信群请扫码:

【小白学习PyTorch教程】十四、迁移学习:微调ResNet实现男人和女人图像分类相关推荐

  1. 【小白学习PyTorch教程】四、基于nn.Module类实现线性回归模型

    「@Author:Runsen」 上次介绍了顺序模型,但是在大多数情况下,我们基本都是以类的形式实现神经网络. 大多数情况下创建一个继承自 Pytorch 中的 nn.Module 的类,这样可以使用 ...

  2. akka学习教程(十四) akka分布式实战

    akka系列文章目录 akka学习教程(十四) akka分布式实战 akka学习教程(十三) akka分布式 akka学习教程(十二) Spring与Akka的集成 akka学习教程(十一) akka ...

  3. 视觉SLAM十四讲学习笔记-第二讲-开发环境搭建

    专栏系列文章如下: 视觉SLAM十四讲学习笔记-第一讲_goldqiu的博客-CSDN博客 视觉SLAM十四讲学习笔记-第二讲-初识SLAM_goldqiu的博客-CSDN博客 ​​​​​​​ lin ...

  4. 视觉SLAM十四讲学习笔记-第三讲-旋转矩阵和Eigen库

    专栏系列文章如下: 视觉SLAM十四讲学习笔记-第一讲_goldqiu的博客-CSDN博客 视觉SLAM十四讲学习笔记-第二讲-初识SLAM_goldqiu的博客-CSDN博客 视觉SLAM十四讲学习 ...

  5. [视觉SLAM十四讲]学习笔记1-刚体运动之旋转矩阵与变换矩阵

    [视觉SLAM十四讲]学习笔记1-刚体运动之旋转矩阵与变换矩阵 1点.向量和坐标系 2 坐标系间的欧式变换 2.1 欧式变换之旋转 2.2 欧式变换之平移 3 变换矩阵与齐次坐标 4 Eigen库的简 ...

  6. pytorch与keras_Keras vs PyTorch:如何通过迁移学习区分外星人与掠食者

    pytorch与keras by Patryk Miziuła 通过PatrykMiziuła Keras vs PyTorch:如何通过迁移学习区分外星人与掠食者 (Keras vs PyTorch ...

  7. 虚拟内存——Windows核心编程学习手札之十四

    虚拟内存 --Windows核心编程学习手札之十四 系统信息 有些操作系统的值是根据主机而定的,如页面大小.分配粒度大小等,这些值不用硬编码形式,进程初始化时应检索这些值以使用.函数GetSystem ...

  8. Windows保护模式学习笔记(十四)—— 阶段测试

    Windows保护模式学习笔记(十四)-- 阶段测试 题目一 解题步骤 题目二 解题步骤 题目一 描述:给定一个线性地址,和长度,读取内容 int ReadMemory(OUT BYTE* buffe ...

  9. OpenCV学习笔记(十四):重映射:remap( )

    OpenCV学习笔记(十四):重映射:remap( ) 图像的坐标映射是通过原图像与目标图像之间建立一种映射关系,这种映射关系有两种,一种是计算原图像任意像素在映射后图像的坐标位置,另一种是计算变换后 ...

最新文章

  1. java nodelist 快速排序,【Leetcode】Sort List in java,你绝对想不到我是怎么做的^^我写完过了我自己都觉得好jian~...
  2. 使用docker搭建redis主从
  3. 第7篇:Flowable-Modeler集成之Flowable源码编译
  4. 限定概率抽奖_LOL:欧皇一次抽奖得16个永久皮肤 把老马亏得坐公交啦
  5. 东航期货穿透接口相关资料
  6. python有几种_Python常见的几种算法
  7. postman|接口测试 | pre-request script 场景应用
  8. 小程序设置发送验证码倒计时
  9. linux下dbf是什么文件,dbf是什么文件?dbf文件怎么读取
  10. 福利贴——爬取美女图片的Java爬虫小程序代码
  11. Helm — Chart介绍
  12. android 重力模拟,android的模拟器怎样仿真重力感应器
  13. 编写函数求整形数组a中存储的m个不重复的整数的第k大的整数(其中m=1,1=k=m)很简单的一个思路是酱紫的:管他辣么多干啥,上来一把排序然后直接得答案...
  14. mysql workbench pk_mysql workbench建表时PK,NN,UQ,BIN,UN,ZF,AI_MySQL - numeric
  15. 火狐浏览器截图整个网页截图 截取整个网页
  16. python中一个等于号和两个等于号_python中is与双等于号“==”的区别示例详解
  17. JDK8与JDK9新特性学习
  18. 东田纳西州立大学计算机排名,东田纳西州立大学如何
  19. 用串口(TFTP)给设备升级程序
  20. springBoot 用户头像的修改并及时显示

热门文章

  1. 【Beta阶段】M2事后分析
  2. Ajax请求Session超时的解决办法:拦截器 + 封装jquery的post方法
  3. CISSP的成长之路(七):复习信息安全管理(1)
  4. Android 通过代码改变控件的布局方式
  5. cocos2D(四)---- CCSprite
  6. c# treeView 取消选择事件
  7. GridView使用一些记录
  8. java 广播地址,根据ip地址跟子网掩码获取广播地址的java实现
  9. 临床观察性研究论文如何撰写“方法”?
  10. yabailv 运放_运放入门