本次暑期科研见习,我有机会初步了解了人工智能的深度学习和模型压缩的基本内容,并在移动设备(树莓派3B)上进行了一些简单的深度学习模型训练。在见习结束之际,总结一下这次学习的内容,也期待之后能够继续在相关领域进行更为深入的研究。

一、深度学习的模型剪枝初探

参考:https://jacobgil.github.io/deeplearning/pruning-deep-learning
一般来讲,随着深度学习的神经网络层数越来越多、网络越来越宽,深度学习模型得到的结果会越来越精细。但与此同时,模型的参数量和计算量也会呈现激增的态势。这不仅对硬件的性能形成了挑战,同时一些冗余的模型参数也会影响计算效率。因此,就需要对一些大模型进行压缩。模型压缩大体上分为知识蒸馏、模型剪枝和量化三类方法,这里重点介绍模型剪枝。
模型剪枝的原理就是希望通过剪除对输出结果贡献不大的参数,减小模型的规模、提升运行速度,同时可以保持模型性能基本不变。
基本步骤:首先,根据对结果贡献度(weight)的大小对神经元进行排序;然后,舍去那些贡献度低的神经元,使模型的规模更精简,模型运行速度更快。当然,这里面有几点问题需要进行说明:

1.剪枝后的再训练

进行剪枝的目的之一是希望剪枝带来的模型损失(cost)越小越好.因此,剪枝后的模型需要再进行反复训练直到呈现令人满意的性能。所以,模型的剪枝实际上是一个迭代的过程,这通常称为“迭代式剪枝”;迭代的过程就是剪枝和模型训练两者的交替重复。

2.几种剪枝技术的简单介绍

不同的剪枝技术不仅包括对神经网络卷积层的处理,也包括如何选取模型参数的权重函数。

论文1:《Pruning filters for effecient convnets
地址:https://arxiv.org/abs/1608.08710
本文提出对卷积层进行完全的剪枝。 作者提出了基于量级的裁剪方式,用weight值的大小来评判其重要性,对于一个filter,其中所有weight的绝对值求和,来作为该filter的评价指标,将一层中值低的filter裁掉,可以有效的降低模型的复杂度,并且不会给模型的性能带来很大的损失。
对卷积窗口剪枝的迭代过程中,每一轮迭代会将全部的卷积窗口进行排序(排序指标为卷积核中L1正则化的权重参数),舍弃排序后指标最低的m个卷积窗口以达到剪枝的目的,然后用剪枝后的卷积窗口进行模型训练,再不断地重复这个过程。

论文二:《Structured Pruning of Deep Convolutional Neural Networks》 地址:https://arxiv.org/abs/1512.08571
这篇论文与上一篇类似,不过在**排序上用了更加复杂的方法。论文采用了N个卷积单元过滤器 (Particle Filters)来对相应的N个卷积层进行剪枝操作。**每一个卷积单元会根据其影响模型在验证数据集上的准确率程度而被分配一个分值,分值低的卷积单元会被过滤掉以达到剪枝的目的。不过这种剪枝非常耗时。

论文三:《Pruning Convolutional Neural Networks for Resource Efficient Inference》地址:https://arxiv.org/abs/1611.06440
本文将剪枝问题当作是一个组合优化问题:从众多的权重参数中选择一个最优组合B,使得被剪枝的模型的代价函数损失最小。相应公式如下:

值得注意的是,论文用的是代价函数损失的绝对值,而不是单纯的差值。使用代价函数损失的绝对值作为优化目标,可以保证被剪枝的模型在性能上不会损失过多。

二、树莓派上简单深度学习模型的训练

当然,模型剪枝的一个设想就是能够把模型放在比较小的设备上能够运行。小型移动设备的配置和计算性能相对来讲都略逊一筹,不过作为一个训练小型深度学习模型的载体还是足够的。

树莓派系统和模块的配置

本次使用的是树莓派(Raspberry pi)3B。树莓派是一款基于ARM架构的微型电脑主板,以SD/MicroSD卡为内存硬盘,卡片主板周围有1/2/4个USB接口和一个10/100 以太网接口(A型没有网口),可连接键盘、鼠标和网线,同时拥有视频模拟信号的电视输出接口和HDMI高清视频输出接口,以上部件全部整合在一张仅比信用卡稍大的主板上,具备所有PC的基本功能只需接通显示屏、鼠标和键盘,就能执行一些简单的功能。树莓派自带的Raspbian系统基于Linux,系统默认的python版本是python2.7&3.7 。
同时,训练深度学习模型需要安装pytorch模块。PyTorch是美国互联网巨头Facebook在深度学习框架Torch的基础上使用Python重写的一个全新的深度学习框架,它更像NumPy的替代产物,不仅继承了NumPy的众多优点,还支持GPU计算,在计算效率上要比NumPy有更明显的优势;不仅如此,PyTorch还有许多高级功能,比如拥有丰富的API,可以快速完成深度神经网络模型的搭建和训练。
关于在树莓派上安装pytorch以及相关模块请参考前一篇文章:
https://blog.csdn.net/qq_44635669/article/details/96972336
当然你也可以尝试在anaconda上尝试安装pytorch模块进行模型训练。

opencv目标检测预训练(predict)模型演示

这个模型的基本原理就是把图片中的一些元素框出,和元素库中的标签元素进行比对识别。

import cv2
import time# Pretrained classes in the model
classNames = {0: 'background',1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus',7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant',13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat',18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear',24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag',32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard',37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove',41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle',46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon',51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut',61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed',67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven',80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock',86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush'}def id_class_name(class_id, classes):for key, value in classes.items():if class_id == key:return value# Loading model
time_start=time.time()
model = cv2.dnn.readNetFromTensorflow('models/frozen_inference_graph.pb','models/ssd_mobilenet_v2_coco_2018_03_29.pbtxt')
image = cv2.imread("image.jpeg")model.setInput(cv2.dnn.blobFromImage(image, size=(300, 300), swapRB=True))
output = model.forward()
# print(output[0,0,:,:].shape)for detection in output[0, 0, :, :]:confidence = detection[2]if confidence > .5:class_id = detection[1]class_name=id_class_name(class_id,classNames)print(str(str(class_id) + " " + str(detection[2])  + " " + class_name))box_x = detection[3] * image_widthbox_y = detection[4] * image_heightbox_width = detection[5] * image_widthbox_height = detection[6] * image_heightcv2.rectangle(image, (int(box_x), int(box_y)), (int(box_width), int(box_height)), (23, 230, 210), thickness=1)cv2.putText(image,class_name ,(int(box_x), int(box_y+.05*image_height)),cv2.FONT_HERSHEY_SIMPLEX,(.005*image_width),(0, 0, 255))
time_end=time.time()time_run=time_end-time_start
time_predict=time_end-time_begincv2.imshow('image', image)
# cv2.imwrite("image_box_text.jpg",image)
print('predict time:',time_predict)
print('run time:',time_run)cv2.waitKey(0)
cv2.destroyAllWindows()

最后在jupyter notebook上的运行结果:
1.0 0.6592765 person
18.0 0.8725562 dog
predict time: 0.10614347457885742
run time: 0.526254415512085

分别显示了识别的元素key和value、accuracy rate以及代码运行时间和模型预测时间。
当然,代码运行完毕后会自动弹出一个窗口显示识别的结果。

关于此模型的详细内容请参考:
https://heartbeat.fritz.ai/real-time-object-detection-on-raspberry-pi-using-opencv-dnn-98827255fa60

暑期科研见习总结:移动设备上的深度学习与模型剪枝初探相关推荐

  1. 干货 | 如何使用 CNN 推理机在 IoT 设备上实现深度学习

    作者 | 唐洁 责编 | 何永灿 通过深度学习技术,物联网(IoT)设备能够得以解析非结构化的多媒体数据,智能地响应用户和环境事件,但是却伴随着苛刻的性能和功耗要求.本文作者探讨了两种方式以便将深度学 ...

  2. 如何用TensorFlow在安卓设备上实现深度学习推断

    在 Insight 任职期间,我用 TensorFlow 在安卓上部署了一个预训练的 WaveNet 模型.我的目标是探索将深度学习模型部署到设备上并使之工作的工程挑战!这篇文章简要介绍了如何用 Te ...

  3. 如何使用CNN推理机在IoT设备上实现深度学习

    作者简介: 唐洁,华南理工大学计算机科学与工程学院副教授.主要从事面向无人驾驶和机器人的大数据计算与存储平台.面向人工智能的计算体系架构.面向机器视觉的嵌入式系统研究. 责编:何永灿(heyc@csd ...

  4. 边缘计算 | 在移动设备上部署深度学习模型的思路与注意点

  5. OpenCV在Android设备上运行深度网络

    OpenCV在Android设备上运行深度网络 在Android设备上运行深度网络 介绍 要求 创建一个空的Android Studio项目 添加OpenCV依赖项 做一个样品 在Android设备上 ...

  6. 一种用于人脸检测的设备上的深度神经网络

    欢迎大家前往云加社区,获取更多腾讯海量技术实践干货哦~ 译者:QiqiHe 苹果公司开始在iOS 10中使用深度学习进行人脸检测.随着Vision框架的发布,开发人员现在可以在他们的应用程序中使用这种 ...

  7. 用TVM在硬件平台上部署深度学习工作负载的端到端 IR 堆栈

    用TVM在硬件平台上部署深度学习工作负载的端到端 IR 堆栈 深度学习已变得无处不在,不可或缺.这场革命的一部分是由可扩展的深度学习系统推动的,如滕索弗洛.MXNet.咖啡和皮托奇.大多数现有系统针对 ...

  8. NVIDIA GPUs上深度学习推荐模型的优化

    NVIDIA GPUs上深度学习推荐模型的优化 Optimizing the Deep Learning Recommendation Model on NVIDIA GPUs 推荐系统帮助人在成倍增 ...

  9. 点云上的深度学习及其在三维场景理解中的应用————PointNet(一)

    最近在学3D方向的语义分析. 师兄推荐了一个哔哩大学的将门创投 | 斯坦福大学在读博士生祁芮中台:点云上的深度学习及其在三维场景理解中的应用!的宝藏视频,我会多看几遍,并写下每次观看笔记. 下文的截图 ...

最新文章

  1. 第九十六题(编写strcpy 函数)
  2. MySQL数据库从入门到实战(四)
  3. ACL 2020 | 多编码器是否能够捕获篇章级信息?
  4. pycharm python 模板配置_windows下pycharm安装、创建文件、配置默认模板
  5. 图片标注尺寸_AutoCAD图纸与测量尺寸不一样怎么办
  6. 信息学奥赛C++语言:分糖果
  7. 五.开发记录之ubuntu系统安装各个软件
  8. 执行计算机查错程序,计算机 每次启动过程中总会执行磁盘检查CHKDSK,什么问题???怎么处理??...
  9. Q92:怎么对PLY文件对应的图形进行仿射变换
  10. Python让繁琐工作自动化——chapter16 发送电子邮件和短信
  11. smarty手册 分离php和html
  12. 游戏及相关CG行业知识分享大V全整合
  13. rar x64 5.50 linux,WinRAR 5.50简体中文注册版(已含Key文件和32位、64位)
  14. java 日期格式化工具类
  15. Exp外贸/出口英文商城系统在国际电商贸易中的角色扮演
  16. 计算机为什么不能装win7,i5 8400 cpu能装win7吗?为什么安装不了win7
  17. Oracle中的空值问题
  18. BZOJ 1631==USACO 2007== POJ 3268 Cow Party奶牛派对
  19. Python 内置模块tkinter —— 秒表计时器
  20. 涉案金额600万!微粒贷诈骗团伙被警方一锅端

热门文章

  1. 【四二学堂】认识网络爬虫
  2. 基于LDA的 职位描述JD 匹配
  3. lkwa-blind-rce盲打rce
  4. espressif中的sd库有问题
  5. 关于iphone手机升级与itunes升级包大小不同的问题
  6. js通过URL下载文件
  7. python函数的动态参数之一个星号和两个星号
  8. 本BLOG内所有文章的版权声明
  9. bootstrap富文本编辑器的使用
  10. 智能肉温度计的全球与中国市场2022-2028年:技术、参与者、趋势、市场规模及占有率研究报告