前言

最近打算重新跟着官方教程学习一下caffe,顺便也自己翻译了一下官方的文档。自己也做了一些标注,都用斜体标记出来了。中间可能额外还加了自己遇到的问题或是运行结果之类的。欢迎交流指正,拒绝喷子!
官方教程的原文链接:
http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/pascal-multilabel-with-datalayer.ipynb

事先提醒一下,这个例子最好还是用GPU来跑,像我用笔记本CPU跑花了6个多小时,说多了都是泪。整个例子都是围绕着多标签问题进行的,概念不多,不难理解。
另外,实验中还会使用到PASCAL VOC2012数据集,请事先到他们的官网下载好数据集,不是很大,差不多2G。
官网链接:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html

Multilabel classification on PASCAL using python data-layers

在这个教程中我们会在PASCAL VOC2012数据集上做多标签分类任务。
多标签分类是多分类的一个推广,其中每一个实例(图像)可以属于很多类。例如,一幅图片可能同时属于“海滩”类别和“度假图片“类别。另一方面,在多分类中,每幅图像只能属于一个单独的类别。
Caffe通过SigmoidCrossEntropyLoss层支持进行多标签分类,我们将使用Python的data数据层加载数据。当然,数据也可以通过HDF5或者LMDB数据层提供,但是Python的data数据层具有更大的灵活性,这也正是我们选择它的原因。

1.准备

  • 第一,确保你在编译caffe时使用了WITH_PYTHON_LAYER := 1
  • 第二,下载好PASCAL VOC 2012。官网链接:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html
  • 第三,导入模块
import sys
import osimport numpy as np
import os.path as osp
import matplotlib.pyplot as plt
from copy import copy% matplotlib inline
plt.rcParams['figure.figsize'] = (6, 6)# 改成自己的caffe路径
caffe_root = '/home/xhb/caffe/caffe/'  # this file is expected to be in {caffe_root}/examples
sys.path.append(caffe_root + 'python')
import caffe # If you get "No module named _caffe", either you have not built pycaffe or you have the wrong path.from caffe import layers as L, params as P # Shortcuts to define the net prototxt.# 修改一下当前路径,修改到caffe/examples路径下,才能找到pycaffe文件夹
os.chdir(os.path.join(caffe_root, 'examples'))sys.path.append("pycaffe/layers") # the datalayers we will use are in this directory.
sys.path.append("pycaffe") # the tools file is in this folderimport tools #this contains some tools that we need
  • 第四,设置数据集的路径并初始化caffe
# 设置数据集的路径
pascal_root = osp.join(caffe_root, 'data/pascal/VOC2012')# 定义好PASCAL VOC2012数据集中所以的类,后面会用到
classes = np.asarray(['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'])# 确保我们已经下载好权重文件
if not os.path.isfile(caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'):print("Downloading pre-trained CaffeNet model...")!../scripts/download_model_binary.py ../models/bvlc_reference_caffenet# 我们使用CPU模式
caffe.set_mode_cpu()
# # initialize caffe for gpu mode
# caffe.set_mode_gpu()
# caffe.set_device(0)

2.在prototxt文件中定义网络

  • 刚开始,我们用caffe.NetSpec来定义网络。注意一下我们怎么使用SigmoidCrossEntropyLoss层的。还要注意一下数据层是如何定义的。
# 卷积层 + relu单元
def conv_relu(bottom, ks, nout, stride=1, pad=0, group=1):conv = L.Convolution(bottom, kernel_size=ks, stride=stride,num_output=nout, pad=pad, group=group)return conv, L.ReLU(conv, in_place=True)# 全连接层 + relu单元
def fc_relu(bottom, nout):fc = L.InnerProduct(bottom, num_output=nout)return fc, L.ReLU(fc, in_place=True)# 最大池化
def max_pool(bottom, ks, stride=1):return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride)# 主要网络
def caffenet_multilabel(data_layer_params, datalayer):# 设置python数据层n = caffe.NetSpec()n.data, n.label = L.Python(module='pascal_multilabel_datalayers', layer=datalayer,ntop=2, param_str=str(data_layer_params))# 网络结构n.conv1, n.relu1 = conv_relu(n.data, 11, 96, stride=4)n.pool1 = max_pool(n.relu1, 3, stride=2)n.norm1 = L.LRN(n.pool1, local_size=5, alpha=1e-4, beta=0.75)n.conv2, n.relu2 = conv_relu(n.norm1, 5, 256, pad=2, group=2)n.pool2 = max_pool(n.relu2, 3, stride=2)n.norm2 = L.LRN(n.pool2, local_size=5, alpha=1e-4, beta=0.75)n.conv3, n.relu3 = conv_relu(n.norm2, 3, 384, pad=1)n.conv4, n.relu4 = conv_relu(n.relu3, 3, 384, pad=1, group=2)n.conv5, n.relu5 = conv_relu(n.relu4, 3, 256, pad=1, group=2)n.pool5 = max_pool(n.relu5, 3, stride=2)n.fc6, n.relu6 = fc_relu(n.pool5, 4096)n.drop6 = L.Dropout(n.relu6, in_place=True)n.fc7, n.relu7 = fc_relu(n.drop6, 4096)n.drop7 = L.Dropout(n.relu7, in_place=True)n.score = L.InnerProduct(n.drop7, num_output=20)n.loss = L.SigmoidCrossEntropyLoss(n.score, n.label)return str(n.to_proto())

3.编写网络和solver文件

  • 现在我们可以创建网络和solver的prototxt文件了。我们使用前面导入的tools模块中的CaffeSolver类来定义solver。
workdir = './pascal_multilabel_with_datalayer'
if not os.path.isdir(workdir):os.makedirs(workdir)
solverprototxt = tools.CaffeSolver(trainnet_prototxt_path=os.path.join(workdir, "trainnet.prototxt"), testnet_prototxt_path=os.path.join(workdir, "valnet.prototxt"))
solverprototxt.sp['display'] = "1"
solverprototxt.sp['base_lr'] = "0.0001"
solverprototxt.write(os.path.join(workdir, 'solver.prototxt'))# 训练网络
with open(os.path.join(workdir, 'trainnet.prototxt'), 'w') as f:# provide parameters to the data layer as a python dictionary. Easy as pie!data_layer_params = dict(batch_size=128, im_shape=[227,227], split='train', pascal_root=pascal_root)f.write(caffenet_multilabel(data_layer_params, 'PascalMultilabelDataLayerSync'))# 测试网络
with open(os.path.join(workdir, 'valnet.prototxt'), 'w') as f:data_layer_params = dict(batch_size=128, im_shape=[227,227], split='val', pascal_root=pascal_root)f.write(caffenet_multilabel(data_layer_params, 'PascalMultilabelDataLayerSync'))
  • 这个网络使用了Python的数据层PascalMultilabelDataLayerSync,定义在./pycaffe/layers/pascal_multilabel_datalayers.py
  • 再看看代码,它很直白,你也可以很容易地控制标签和数据。
  • 现在我们可以像往常一样导入caffe的solver了。

注:如果在运行时遇到错误请注意:在这一步之前请确保编译caffe时,使用了WITH_PYTHON_LAYER := 1,否则运行后,会报错,提示说找不到layer类型:python;如果是在notebook中运行,并不会提示这些信息,程序会直接停止然后重启python内核。

# print os.path.join(workdir, 'solver.prototxt')
# solver_path = os.path.join(caffe_root, 'examples', 'pascal_multilabel_with_datalayer', 'solver.prototxt')
solver = caffe.SGDSolver(os.path.join(workdir, 'solver.prototxt'))
# solver = caffe.SGDSolver(solver_path)
solver.net.copy_from(caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel')
solver.test_nets[0].share_with(solver.net)
solver.step(1)
BatchLoader initialized with 5717 images
PascalMultilabelDataLayerSync initialized for split: train, with bs: 128, im_shape: [227, 227].
BatchLoader initialized with 5823 images
PascalMultilabelDataLayerSync initialized for split: val, with bs: 128, im_shape: [227, 227].
  • 我们再来看看导入的数据。
transformer = tools.SimpleTransformer()# This is simply to add back the bias, re-shuffle the color channels to RGB, and so on...
image_index = 0
plt.figure()
plt.imshow(transformer.deprocess(copy(solver.net.blobs['data'].data[image_index, ...])))
gtlist = solver.net.blobs['label'].data[image_index, ...].astype(np.int)
plt.title('GT: {}'.format(classes[np.where(gtlist)]))
plt.axis('off')
(-0.5, 226.5, 226.5, -0.5)

注:我们是直接从网络的data层读取出来的图像,由于经过了预处理操作,所以这里图片的分辨率是比原始的PASCAL VOC数据集的图片更低的

4.训练网络

  • 接下来开始训练网络。首先,我们需要一些方法来测量准确率。汉明距离经常用于多标签问题。我们还需要一个循环来测试网络性能。开始敲代码吧。
def hamming_distance(gt, est):return sum([1 for (g, e) in zip(gt, est) if g == e]) / float(len(gt))def check_accuracy(net, num_batches, batch_size=128):acc = 0.0for t in range(num_batches):net.forward()gts = net.blobs['label'].dataests = net.blobs['score'].data > 0for gt, est in zip(gts, ests):#for each ground truth and estimated label vectoracc += hamming_distance(gt, est)return acc / (num_batches * batch_size)
  • 好的,接下来训练一段时间。
for it in range(6):solver.step(100)print 'iter:{:3d}'.format((it+1) * 100), 'accuracy:{0:.4f}'.format(check_accuracy(solver.test_nets[0], 50))
iter:100 accuracy:0.9523
iter:200 accuracy:0.9569
iter:300 accuracy:0.9580
iter:400 accuracy:0.9586
iter:500 accuracy:0.9591
iter:600 accuracy:0.9593
  • 很棒,看起来准确率在增长,而且看起来也很快地收敛了。看起来很奇怪,准确率一开始就如此之高,这是因为,它的ground truth分布很稀疏。PASCAL数据集有20个类,通常每幅图片就只会属于一个或者两个类而已,因此预测所有输出为0的区域会有很高的准确率。下面来确认一下。
def check_baseline_accuracy(net, num_batches, batch_size=128):acc = 0.0for t in range(num_batches):net.forward()gts = net.blobs['label'].dataests = np.zeros((batch_size, len(gts)))for gt, est in zip(gts, ests):acc += hamming_distance(gt, est)return acc / (num_batches * batch_size)print 'Baseline accuracy:{0:.4f}'.format(check_baseline_accuracy(solver.test_nets[0], 5823/128))
Baseline accuracy:0.9240

6.一些预测结果

test_net = solver.test_nets[0]
for image_index in range(5):plt.figure()plt.imshow(transformer.deprocess(copy(test_net.blobs['data'].data[image_index, ...])))gtlist = test_net.blobs['label'].data[image_index, ...].astype(np.int)estlist = test_net.blobs['score'].data[image_index, ...] > 0plt.title('GT: {} \n EST: {}'.format(classes[np.where(gtlist)], classes[np.where(estlist)]))plt.axis('off')

Caffe官方教程翻译(9):Multilabel Classification with Python Data Layer相关推荐

  1. Caffe官方教程翻译(5):Classification: Instant Recognition with Caffe

    前言 最近打算重新跟着官方教程学习一下caffe,顺便也自己翻译了一下官方的文档.自己也做了一些标注,都用斜体标记出来了.中间可能额外还加了自己遇到的问题或是运行结果之类的.欢迎交流指正,拒绝喷子! ...

  2. Caffe官方教程翻译(10):Editing model parameters

    前言 最近打算重新跟着官方教程学习一下caffe,顺便也自己翻译了一下官方的文档.自己也做了一些标注,都用斜体标记出来了.中间可能额外还加了自己遇到的问题或是运行结果之类的.欢迎交流指正,拒绝喷子! ...

  3. Caffe官方教程翻译(8):Brewing Logistic Regression then Going Deeper

    前言 最近打算重新跟着官方教程学习一下caffe,顺便也自己翻译了一下官方的文档.自己也做了一些标注,都用斜体标记出来了.中间可能额外还加了自己遇到的问题或是运行结果之类的.欢迎交流指正,拒绝喷子! ...

  4. Caffe官方教程翻译(7):Fine-tuning for Style Recognition

    前言 最近打算重新跟着官方教程学习一下caffe,顺便也自己翻译了一下官方的文档.自己也做了一些标注,都用斜体标记出来了.中间可能额外还加了自己遇到的问题或是运行结果之类的.欢迎交流指正,拒绝喷子! ...

  5. Caffe官方教程翻译(6):Learning LeNet

    前言 最近打算重新跟着官方教程学习一下caffe,顺便也自己翻译了一下官方的文档.自己也做了一些标注,都用斜体标记出来了.中间可能额外还加了自己遇到的问题或是运行结果之类的.欢迎交流指正,拒绝喷子! ...

  6. Caffe官方教程翻译(4):CIFAR-10 turorial

    前言 最近打算重新跟着官方教程学习一下caffe,顺便也自己翻译了一下官方的文档.自己也做了一些标注,都用斜体标记出来了.中间可能额外还加了自己遇到的问题或是运行结果之类的.欢迎交流指正,拒绝喷子! ...

  7. Caffe官方教程翻译(3):Siamese Network Training with Caffe

    前言 最近打算重新跟着官方教程学习一下caffe,顺便也自己翻译了一下官方的文档.自己也做了一些标注,都用斜体标记出来了.中间可能额外还加了自己遇到的问题或是运行结果之类的.欢迎交流指正,拒绝喷子! ...

  8. Caffe官方教程翻译(2):Web demo

    前言 最近打算重新跟着官方教程学习一下caffe,顺便也自己翻译了一下官方的文档.自己也做了一些标注,都用斜体标记出来了.中间可能额外还加了自己遇到的问题或是运行结果之类的.欢迎交流指正,拒绝喷子! ...

  9. Caffe官方教程翻译(1):LeNet MNIST Tutorial

    前言 最近打算重新跟着官方教程学习一下caffe,顺便也自己翻译了一下官方的文档.自己也做了一些标注,都用斜体标记出来了.中间可能额外还加了自己遇到的问题或是运行结果之类的.欢迎交流指正,拒绝喷子! ...

最新文章

  1. 如何自学Python?这本技术大咖推荐Python书籍,就是你的第一选择
  2. Struts2后期(这框架目前正处于淘汰状态)
  3. 几款最新的解谜单机小游戏
  4. 盘点2020中国IT上市企业100强,贵司上榜了吗?
  5. [BUUCTF-pwn]——ciscn_2019_es_2(内涵peak小知识)
  6. 编程兴趣真的是由“热情”驱动的吗?
  7. Tapioca:linux上同gtalk语音通信
  8. PHP 基于 SW-X 框架,搭建RPC微服务支持
  9. java 调用soapui_利用soapui和jdk API访问webservice
  10. 斐讯N1救砖指南!值得收藏
  11. mac上的android模拟器下载安装,Mac电脑上安装安卓模拟器,Mac如何安装Android模拟器...
  12. 用HTML写一首绝句古诗,唐诗七绝绝句经典50首:唐诗七绝悲伤的句子让人心醉
  13. mediawiki mysql配置_安装MediaWiki
  14. 【概率论基础进阶】随机事件和概率-古典概型与伯努利概型
  15. win10、win7系统64位oracle11g安装教程以及32位plsql连接教程
  16. 拜仁超越自我终成夙愿-记2013欧冠决赛
  17. linux 打开大文件命令,linux下大文件的读取
  18. 『2021语言与智能技术竞赛』-机器阅读理解任务基线系统详解
  19. Dev Board---将摄像机连接到开发板
  20. tomcat jdbc连接池配置属性详解之参数说明

热门文章

  1. powerbuilder 保存图表图像_数据可视化/统计图表循序渐进指南
  2. win7系统升服务器版本,WIN7专业版可update补丁,WIN7旗舰版无法update补丁,WSUS服务器是按windows类型还是版本区别updata的还是其他什么方式...
  3. Dubbo 源码分析 - 集群容错之 Cluster
  4. Java直接内存与非直接内存性能测试
  5. 论面向组合子程序设计方法 之 重构2
  6. 使用强大的 Mockito 测试框架来测试你的代码
  7. JavaWeb学习总结(十三)——使用Session防止表单重复提交
  8. 从Python中readline()函数读取的一行内容中去掉换行符\n
  9. Python numpy生成矩阵、串联矩阵
  10. Python list 数据类型:列表