深度卷积神经网络模型由于其层数多,需要训练的参数多,导致从零开始训练很深的卷积神经网络非常困难,同时训练很深的网络通畅需要大量的数据集,这对于设备算力不够的使用者非常不友好。幸运的是Pytorch已经提供了使用ImageNet数据集与与训练好的流行的深度学习网络,我们可以针对自己的需求,对与训练好的网络进行微调,从而快速完成自己的任务。

下面将会基于与训练好的VGG16网络,对其网络结构进行微调,使用自己的分类数据集,训练一个图像分类器。使用的数据集来自kaggle数据集中的10类猴子数据库,数据地址为https://www.kaggle.com/slothkong/10-money-species。在该数据集中包含训练数据集合验证数据集,其中训练数据集中每类约140张RGB图像,验证数据集中每类30张图像。针对该数据集使用VGG16的卷积层和池化层的预训练好的权重,提取数据特征,然后定义新的全连接层,用于图像的分类。

首先导入所需要的库和模块。

# import numpy as np
# import pandas as pd
# from sklearn.metrics import accuracy_score,confusion_matrix,classification_report
# import matplotlib.pyplot as plt
# import seaborn as sns
# import hiddenlayer as hl
# import torch
import torch.nn as nn
# from torch.optim import SGD,Adam
# import torch.utils.data as Data
from torchvision import models
# from torchvision import transforms
# from torchvision.datasets import ImageFolder

对于已经训练好的VGG16网络,需要首先导入网络。

vgg16=models.vgg16(pretrained=True)
vgg=vgg16.features
# for param in vgg.parameters():
#     param.requires_grad_(False)

在上面的程序中,使用models.vgg16(pretrained=True)导入网络,其中参数pretrained=True表示导入的网络是使用ImageNet数据集预训练好的网络(如果第一次使用该程序,需要一定时间从网络上下载模型)。在得到的VGG16网络中,使用vgg16.features获取VGG16网络的特征提取模块,即前面的卷积池化层,不包括全连接层。为了提升网络的训练速度,只是用VGG16提取图像的特征,需要将VGG16的特征提取层参数冻结,不更新其权重,通过for循环和param.requires_grad_(False)即可。

VGG特征提取层预处理结束后,可在VGG16特征提取层之后添加新的全连接层,用于图像分类,程序定义网络结构如下:

class MyVggModel(nn.Module):def __init__(self):super(MyVggModel,self).__init__()self.vgg=vggself.classifier=nn.Sequential(nn.Linear(25088,512),nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(512,256),nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(256,10),nn.Softmax(dim=1))def forward(self,x):x=self.vgg(x)x=x.view(x.size(0),-1)output=self.classifier(x)return output

在上面的程序中,定义了一个卷积神经网络类MyVggModel,在该网络中,包含两个大的结构,一个是self.vgg,使用预训练好的VGG16的特征提取并且其参数的权重已经冻结;另一个是self.classifier,由三个全连接层组成,并且神经元的个数分别为512,256,和10.在全连接层中使用ReLU函数作为激活函数,并通过nn.Dropout()层防止过拟合。在网络的前向传播函数中,有self.classifier得到输出。

可以通过下面的程序查看网络的详细结构。

Myvggc=MyVggModel()
print(Myvggc)

输出结果为:

在定义好卷积神经网络Myvggc后,下面需要对数据集进行准备。首先定义训练集和验证集的预处

理过程,程序如下:

train_data_transforms=transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
val_data_transforms=transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

上面的程序定义了对训练集的预处理过程train_data_transforms,从而对训练集进行数据增强,对验证集的预处理过程val_data_trabsforms与train_data_transforms会有一些差异,其不需要对图像进行随机翻转与随机裁剪操作。在对读入的单张图像进行预处理时,通过RandomResizedCrop()对图像进行随机裁剪,使用RandomHorizontalFlip()将图像依概率p=0.5水平翻转,通过Resize()充值图像分辨率,通过CenterCrop()将图像按照给定的尺寸从中心裁剪,通过Normalize()将他徐昂的像素值进行标准化处理等。

因为每类图像都分别保存在一个单独的文件夹中,所以可以使用ImageFolder()函数从文件中读取训练集和验证集,数据读取的程序如下:

train_data_dir="data/chap6/10-monkey-species/training"
train_data=ImageFolder(train_data_dir,transform=train_data_transforms)
train_data_loader=Data.DataLoader(train_data,batch_size=32,shuffle=True,num_workers=2)
val_data_dir="data/chap6/10-monkey-species/validation"
val_data=ImageFolder(val_data_dir,transform=val_data_transforms)
val_data_loader=Data.DataLoader(val_data,batch_size=32,shuffle=True,num_workers=2)
print("训练集样本数:",len(train_data.targets))
print("验证集样本数:",len(val_data.targets))

输出结果如下:

 上面的程序在读取图像后,分别使用Data.DataLoader()函数,将训练集和测试集处理为数据加载起train_data_loader和val_data_loader,并且每个batch包含32张图像。从输出结果发现,训练集有1097个样本,验证集有272个样本。下面我们获取训练集的一个batch图像,然后将获取的32张图像进行可视化,观察数据中图像的内容。

    for step,(b_x,b_y) in enumerate(train_data_loader):if step>0:breakmean=np.array([0.485,0.456,0.406])std=np.array([0.229,0.224,0.225])plt.figure(figsize=(12,6))for ii in np.arange(len(b_y)):plt.subplot(4,8,ii+1)image=b_x[ii,:,:,:].numpy().transpose((1,2,0))image=std*image+meanimage=np.clip(image,0,1)plt.imshow(image)plt.title(b_y[ii].data.numpy())plt.axis("off")plt.subplots_adjust(hspace=0.3)plt.show()

上面的程序在获取了一个batch图像后,再可视化前,需要将图像每个通道的像素值乘以对应的标准差并加上对应的均值。最后的图像如下:

卷积神经网络(CNN)(下)相关推荐

  1. 文本分类(下) | 卷积神经网络(CNN)在文本分类上的应用

    正文共3758张图,4张图,预计阅读时间18分钟. 1.简介 原先写过两篇文章,分别介绍了传统机器学习方法在文本分类上的应用以及CNN原理,然后本篇文章结合两篇论文展开,主要讲述下CNN在文本分类上的 ...

  2. 卷积神经网络(CNN,ConvNet)

    卷积神经网络(CNN,ConvNet) 卷积神经网络(CNN,有时被称为 ConvNet)是很吸引人的.在短时间内,变成了一种颠覆性的技术,打破了从文本.视频到语音等多个领域所有最先进的算法,远远超出 ...

  3. 卷积神经网络(CNN)的简单实现(MNIST)

    卷积神经网络(CNN)的基础介绍见http://blog.csdn.net/fengbingchun/article/details/50529500,这里主要以代码实现为主. CNN是一个多层的神经 ...

  4. 一文看懂卷积神经网络-CNN(基本原理+独特价值+实际应用)

    http://blog.itpub.net/29829936/viewspace-2648775/ 2019-06-25 21:31:18 卷积神经网络 – CNN 最擅长的就是图片的处理.它受到人类 ...

  5. 卷积神经网络(CNN)前向传播算法

    在卷积神经网络(CNN)模型结构中,我们对CNN的模型结构做了总结,这里我们就在CNN的模型基础上,看看CNN的前向传播算法是什么样子的.重点会和传统的DNN比较讨论. 1. 回顾CNN的结构 在上一 ...

  6. 深度学习(DL)与卷积神经网络(CNN)学习笔记随笔-03-基于Python的LeNet之LR

    原地址可以查看更多信息 本文主要参考于:Classifying MNIST digits using Logistic Regression  python源代码(GitHub下载 CSDN免费下载) ...

  7. 卷积神经网络CNN总结

    从神经网络到卷积神经网络(CNN) 我们知道神经网络的结构是这样的: 那卷积神经网络跟它是什么关系呢? 其实卷积神经网络依旧是层级网络,只是层的功能和形式做了变化,可以说是传统神经网络的一个改进.比如 ...

  8. 深度学习方法(五):卷积神经网络CNN经典模型整理Lenet,Alexnet,Googlenet,VGG,Deep Residual Learning

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld.  技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 关于卷积神经网络CNN,网络和文献 ...

  9. python卷积神经网络cnn的训练算法_【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理...

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  10. keras中文文档_【DL项目实战02】图像识别分类——Keras框架+卷积神经网络CNN(使用VGGNet)

    版权声明:小博主水平有限,希望大家多多指导. 目录: [使用传统DNN] BG大龍:[DL项目实战02]图像分类--Keras框架+使用传统神经网络DNN​zhuanlan.zhihu.com [使用 ...

最新文章

  1. 关于C语言中的预处理器的简单笔记
  2. cakephp 安装mysql_CakePHP的安装的简单方法
  3. 四边形可以分为几类_学习知识:四边形有几种类型
  4. html5--3.16 button元素
  5. Python爬虫 ---(1)爬虫基础知识
  6. python __setattr__
  7. Ambari--主机管理
  8. java 自定义形状按钮_制作自定义背景Button按钮、自定义形状Button的全攻略
  9. MEF: MSDN 杂志上的文章(6) 一个部件可以有多个导出 !!!
  10. 拓端tecdat|R语言空间可视化:绘制英国脱欧投票地图
  11. expdp导出表结构_Oracle用exp导出部分表和expdp
  12. python3实现校园网认证客户端
  13. 芒果移动广告优化平台
  14. 前端pdf禁止下载功能
  15. 众为兴SCARA四轴机械臂编程(二)——基于硬接线替代Modbus通讯
  16. 一篇工作调动时的旧文
  17. JavaWeb开发 —— Maven
  18. 为什么重写HashCode和Equals
  19. 燕十八ajax笔记,燕十八商城笔记资料.doc
  20. Cocos Creater 敏捷开发之插件开发

热门文章

  1. 如何使Android录音实现内录功能
  2. JAVA Json-Schema接口校验利器
  3. 最大比例(辗转相除)
  4. 2022年最新宠物十二生肖区块系统源码
  5. [译] Python 的打包现状(写于 2019 年)
  6. Python爬虫实战(四) :下载煎蛋网所有妹子照片
  7. 人工智能-三连子游戏设计和实现
  8. 舔狗舔到最后一无所有 (线性dp)
  9. 力创eda 画布大小_立创EDA快速上手经验指南
  10. 部署 Elastic