基于PaddlePaddle构建ResNet18残差神经网络的食物图片分类问题

Introduction

本项目是在李宏毅机器学习课程的作业3进行的工作,任务是手动搭建一个CNN模型进行食物图片分类(11种)。

项目要求

  • 请使用 CNN 搭建 model
  • 不能使用额外 dataset
  • 禁止使用 pre-trained model(只能自己手写CNN)
  • 请不要上网寻找 label

项目传送门

Abstract

本文的主要内容如下:

  • 1 PaddlePaddle的深度学习万能公式介绍:该万能公式其实是进行项目研究的常规性方法步骤,本项目就是按照该步骤进行的,能够很好的作为项目开展的DEMO。
  • 2 手动搭建基于标准ResNet18残差神经网络模型并改变其内部参数设置,使得提取的特征增多。最终在训练集和验证集上表现挺好,但模型存在过拟合现象,还需要进一步学习得到更好的模型。
  • 3 在测试集上全部进行了预测,由于没有标签,批量展示了部分预测,预测准确率挺高的。
  • 4 最后给出了简单的残差神经网络搭建的方法,通过调整模块内部Residual的数量和配置实现不同的ResNet网络。
    想看预测结果直接拉到快文章末尾处……

目录

  • 深度学习万能公式——PaddlePaddle
  • 1 问题定义
  • 2 数据准备
  • 3 模型选择和开发
  • 4 模型训练
  • 5 模型评估和测试
  • 6 模型部署
  • 7 残差神经网络搭建的方法
  • 8参考文献&文章&代码
  • 作者介绍
  • 附录

深度学习万能公式——PaddlePaddle

  • 1 问题定义
  • 2 数据准备
  • 3 模型选择和开发
  • 4 模型训练和调优
  • 5 模型评估和测试
  • 6 部署上线

1 问题定义

根据项目要求,搭建一个CNN模型实现11类食物图片的分类,属于分类问题。

2 数据准备

数据格式
下载 zip 档后解压缩会有三个资料夹,分别为training、validation 以及 testing
training 以及 validation 中的照片名称格式为 [类别]_[编号].jpg,例如 3_100.jpg 即为类别 3 的照片(编号不重要)

2.1 解压缩数据集

!unzip -d work data/data57075/food-11.zip # 解压缩food-11数据集
  inflating: work/food-11/training/6_12.jpg

2.2 数据标注

我们先看一下解压缩后的数据集长成什么样子。

.
├── training:
│   [类别]_[编号].jpg
│     .
│     .
│     .
│
├── validation:
│   [类别]_[编号].jpg
│     .
│     .
│     .
│
├── testing:
│   [编号].jpg
│     .
│     .
│     .
│

数据集共有三个资料夹,分别为training、validation 以及 testing。这三个文件夹里直接存放着照片,照片名称格式为 [类别]_[编号].jpg,例如 3_100.jpg 即为类别 3 的照片(编号不重要),每个文件夹里都有11类。对这些样本进行一个标注处理,最终生成train.txt/valid.txt/test.txt三个数据标注文件。

import io
import os
from PIL import Image
from config import get  # 配置函数文件包括了多种参数的设置,详细代码见附录
# 数据集根目录
DATA_ROOT = 'work/food-11'# 标注生成函数
def generate_annotation(mode):# 建立标注文件with open('{}/{}.txt'.format(DATA_ROOT, mode), 'w') as f:# 对应每个用途的数据文件夹,train/valid/testtrain_dir = '{}/{}'.format(DATA_ROOT, mode)# train_dir = work/food-11/training# 图像样本所在的路径image_path = '{}'.format(train_dir) # image_path = #'work/food-11/training'# 遍历所有图像for image in os.listdir(image_path):# 图像完整路径和名称image_file = '{}/{}'.format(image_path, image)for k in image:if k=='_':   # 如果图片名称有下划线‘—’stop = image.index(k)   # 下划线所在索引label_index = image[0:stop] # image的索引从0——下划线前的数字为为图片的标签label_index =int(label_index)            try:# 验证图片格式是否okwith open(image_file, 'rb') as f_img:image = Image.open(io.BytesIO(f_img.read()))image.load()if image.mode == 'RGB':f.write('{}\t{}\n'.format(image_file, label_index))except:continuegenerate_annotation('training')  # 生成训练集标注文件
generate_annotation('validation')  # 生成验证集标注文件

训练集和验证集标注文件


由于测试数据集没有标签,所以只生成其数据集的路径文件

# 数据集根目录
DATA_ROOT = 'work/food-11'def generate_annotation(mode):with open('{}/{}.txt'.format(DATA_ROOT, mode), 'w') as f:# 对应每个用途的数据文件夹,train/valid/testtrain_dir = '{}/{}'.format(DATA_ROOT, mode)# train_dir = work/food-11/training# 图像样本所在的路径image_path = '{}'.format(train_dir) # image_path = #'work/food-11/training'# 遍历所有图像for image in os.listdir(image_path):# 图像完整路径和名称image_file = '{}/{}'.format(image_path, image)try:# 验证图片格式是否okwith open(image_file, 'rb') as f_img:image = Image.open(io.BytesIO(f_img.read()))image.load()if image.mode == 'RGB':f.write('{}\n'.format(image_file))except:continue# 生成测试集
generate_annotation('testing')

测试集路径文件

2.3 数据集定义

接下来我们使用标注好的文件进行数据集类的定义,方便后续模型训练使用。

2.3.1 导入相关库

import paddle
import numpy as np
from config import get
print(paddle.__version__)
2.0.1

我们数据集的代码实现是在dataset.py中。

# data.py 文件包括了图片数据的预处理,详细代码见附录
from dataset import ZodiacDataset

2.3.2 实例化数据集类

根据所使用的数据集需求实例化数据集类,并查看总样本量。

training_dataset = ZodiacDataset(mode='training')
validation_dataset = ZodiacDataset(mode='validation')
print('训练数据集:{}张; 验证数据集:{}张'.format(len(training_dataset),len(validation_dataset)))
训练数据集:9866张; 验证数据集:3430张

2.3.3 数据集查看

print('图片:')
print(type(training_dataset[1][0]))
print(training_dataset[1][0])
print('标签:')
print(type(training_dataset[1][1]))
print(training_dataset[1][1])
图片:
<class 'paddle.VarBase'>
Tensor(shape=[3, 224, 224], dtype=float32, place=CPUPlace, stop_gradient=True,[[[-2.11790395, -2.11790395, -2.11790395, ..., -2.11790395, -2.11790395, -2.11790395],[-2.11790395, -2.11790395, -2.11790395, ..., -2.11790395, -2.11790395, -2.11790395],[-2.11790395, -2.11790395, -2.11790395, ..., -2.11790395, -2.11790395, -2.11790395],...,[-2.11790395, -2.11790395, -2.11790395, ..., -2.11790395, -2.11790395, -2.11790395],[-2.11790395, -2.11790395, -2.11790395, ..., -2.11790395, -2.11790395, -2.11790395],[-2.11790395, -2.11790395, -2.11790395, ..., -2.11790395, -2.11790395, -2.11790395]],[[-2.03571415, -2.03571415, -2.03571415, ..., -2.03571415, -2.03571415, -2.03571415],[-2.03571415, -2.03571415, -2.03571415, ..., -2.03571415, -2.03571415, -2.03571415],[-2.03571415, -2.03571415, -2.03571415, ..., -2.03571415, -2.03571415, -2.03571415],...,[-2.03571415, -2.03571415, -2.03571415, ..., -2.03571415, -2.03571415, -2.03571415],[-2.03571415, -2.03571415, -2.03571415, ..., -2.03571415, -2.03571415, -2.03571415],[-2.03571415, -2.03571415, -2.03571415, ..., -2.03571415, -2.03571415, -2.03571415]],[[-1.80444443, -1.80444443, -1.80444443, ..., -1.80444443, -1.80444443, -1.80444443],[-1.80444443, -1.80444443, -1.80444443, ..., -1.80444443, -1.80444443, -1.80444443],[-1.80444443, -1.80444443, -1.80444443, ..., -1.80444443, -1.80444443, -1.80444443],...,[-1.80444443, -1.80444443, -1.80444443, ..., -1.80444443, -1.80444443, -1.80444443],[-1.80444443, -1.80444443, -1.80444443, ..., -1.80444443, -1.80444443, -1.80444443],[-1.80444443, -1.80444443, -1.80444443, ..., -1.80444443, -1.80444443, -1.80444443]]])
标签:
<class 'numpy.ndarray'>
1

3 模型选择和开发

根据题目要求使用 CNN 搭建 model并且禁止使用 pre-trained model(只能自己手写CNN)。值得一提的是,模型组网一般共有三组方法,以PaddlePaddle框架为例:

  • (1)Sequential 组网
    顺序容器。子Layer将按构造函数参数的顺序添加到此容器中。传递给构造函数的参数可以Layers或可迭代的name Layer元组。
  • (2)SubClass 组网
    针对一些比较复杂的网络结构,就可以使用Layer子类定义的方式来进行模型代码编写,在__init__构造函数中进行组网Layer的声明,在forward中使用声明的Layer变量进行前向计算。
  • (3)飞桨框架内置模型
    飞桨框架内置的模型,路径为 paddle.vision.models。那么根据要求,只能使用前两种方法来搭建模型

3.1 网络构建

由与本次分类的类别较多,训练的数据为分辨率较大的彩色图片,因此选择SubClass 组网方法来搭建Resnet18网络来完成分类任务。

3.1.1深度残差网络介绍

2015 年,微软亚洲研究院何恺明等人发表了基于 Skip Connection 的深度残差网络(Residual Neural Network,简称 ResNet)算法,并提出了18层、34 层、50层、101层、152层的 ResNet-18、ResNet-34、ResNet-50、ResNet-101 和 ResNet-152 等模型,如表1所示,甚至成功训练出层数达到 1202 层的极深层神经网络。ResNet 在 ILSVRC 2015挑战赛ImageNet数据集上的分类、检测等任务上面均获得了最好性能,ResNet 论文至今已经获得超 25000的引用量,可见 ResNet 在人工智能行业的影响力。ResNet 通过在卷积层的输入和输出之间添加 Skip Connection 实现层数回退机制,如下图1所示,输入

基于PaddlePaddle构建ResNet18残差神经网络的食物图片分类问题相关推荐

  1. 基于双向 lstm 和残差神经网络的 rna 二级结构预测方法

    目录 结果 结论 背景 结果 学习结果和演示 预测结果和比较 讨论 结论 材料和方法 数据收集和处理 摘要 背景: 研究表明,rna 二级结构是由配对碱基构成的平面结构,在基本生命活动和复杂疾病中发挥 ...

  2. Paddle实现食物图片分类

    Paddle实现食物图片分类 食物图片分类 项目描述 数据集介绍 思路方法 读取文件 卷积神经网络示意图 作者简介 项目说明,本项目是李宏毅老师在飞桨授权课程的作业解析 课程 传送门 该项目AiStu ...

  3. 基于卷积神经网络CNN的图片分类实现——附代码

    目录 摘要: 1.卷积神经网络介绍: 2.卷积神经网络(CNN)构建与训练: 2.1 CNN的输入图像 2.2 构建CNN网络 2.3 训练CNN网络 3.卷积神经网络(CNN)的实际分类测试: 4. ...

  4. 李宏毅2020机器学习作业3-CNN:食物图片分类

    更多作业,请查看 李宏毅2020机器学习资料汇总 文章目录 0 作业链接 1 作业说明 环境 任务说明 任务要求 数据说明 作业概述 2 原始代码 导入需要的库 读取图片 定义Dataset 定义模型 ...

  5. 搭建神经网络实现简单图片分类

    数据来源于吴恩达L2HW3(SIGNS 数据集),训练集包含1800张64*64像素的彩色图片,图片内容为手势,表示从0到5的数字,所要做的是搭建较深的神经网络,以实现图片分类.测试集包含120张图片 ...

  6. 基于paddlepaddle构建resnet神经网络的蝴蝶分类

    一.序言 使用百度飞浆提供的paddle框架实现蝴蝶分类,环境:paddle 2.0.2,opencv 4.5.4.58,pycharm编译器. 目录结构: Butterfly20里有20个文件夹,分 ...

  7. 基于PaddlePaddle框架的BP神经网络的鲍鱼年龄的预测

    # 经典的线性回归模型主要用来预测一些存在着线性关系的数据集.回归模型可以理解为:存在一个点集,用一条曲线去拟合它分布的过程.如果拟合曲线是一条直线,则称为线性回归.如果是一条二次曲线,则被称为二次回 ...

  8. 【深度学习】利用tensorflow2.0卷积神经网络进行卫星图片分类实例操作详解

    本文的应用场景是对于卫星图片数据的分类,图片总共1400张,分为airplane和lake两类,也就是一个二分类的问题,所有的图片已经分别放置在2_class文件夹下的两个子文件夹中.下面将从这个实例 ...

  9. 使用Pytorch搭建CNN模型完成食物图片分类(李宏毅视频课2020作业3,附超详细代码讲解)

    文章目录 0 前言 1 任务描述 1.1 数据描述 1.2 作业提交 1.3 数据下载 1.3.1 完整数据集 1.3.2 部分数据集 2 过程讲解 2.1 读取数据 2.2 数据预处理 2.3 模型 ...

最新文章

  1. 你想知道的关于JavaScript作用域的一切(译)
  2. AMD 5XXX 系列显卡的 peak bandwidth计算
  3. SpringBoot 封装返回类以及session 添加获取
  4. Linux能适应不同的指令集,(转)linux常用指令集
  5. 对JavaScript解析JSON格式数据的理解
  6. 爬虫实例4 爬取网络小说
  7. VBA操作WORD(二):替换字符(含空格、全角字符、换行符等)
  8. 浅谈Java垃圾回收
  9. 2019年管理类MBA/MEM联考英语小作文范文
  10. Hexo NexT 评论系统 Valine 的使用
  11. 系统管理中的三大利刃
  12. Flutter集成个推推送-安卓原生篇
  13. Win10任务栏图标无法右键/取消固定
  14. 题目3:一个整数,它加上100后是一个完全平方数,再加上268又是一个完全平方数,请问该数是多少?
  15. java查询学号数据库_数据库SQL查询语句练习题 PDF 下载
  16. Web服务器的配置与应用
  17. Python判断奇偶的方法
  18. java 朗读_java下载安装 用Java实现简单的语音朗读
  19. 景区手绘地图的绘制流程
  20. 海康录像机RTSP取流路径

热门文章

  1. 群晖NAS与阿里云盘同步的方法
  2. 2021年中国智能仓储行业情况分析:电商快速发展,促进智能物流及仓储快速发展[图]
  3. 绝缘监测仪原理及各原理特点
  4. 【HDU】4411 Arrest 费用流
  5. 黑科技还是流氓应用?有些App,通知关不掉!
  6. Python numpy 开N次方
  7. 基于非合作博弈的风-光-氢微电网容量优化配置(Matlab代码实现)
  8. php管理系统申请著作权,php管理系统申请著作权-我有一套PHP源码系统,想修改网站底部版权信息,可......
  9. 回声消除(Echo Cancellation)理解
  10. 设备智能化开发,软硬件技术如何选型及上位机开发的注意事项