为了探究影响模型运行时间的变量,之前运用了参数量做标准

参数量在TF框架下还是很容易计算的

TF框架下运用

tf.keras.models.Model().summary()

就能一键生成包含模型的layers,output,parameters的报告

为了探究其他标准用于反映模型的运行时间,我们在网上找到了三个标准:

参数量(parameters)
浮点运算次数(FLOPs)
内存访问次数(MAC)

这次我们就来探究FLOPs对模型latency的影响

一上来先踩坑

TF2.X取消了Profiler接口对于FLOPs的统计

即使通过网上给的TF1.X接口再用TF2.X compat.v1接口依然不能成功返回FLOPs的值

于是只能写个程序硬算

FLOPs本质上是模型中的乘法和加法运算,

模型里面的layers有:

Input layer

Zero Padding layer

Conv2D

BatchNormalization

Activation

Depthwise conv2D

Dense

其中因为浮点运算次数很少而可以忽略不记的layers:

Input,

zero padding,

BatchNormalization,

Activation

剩下需要计算的layers就是:

Conv2D,

Depthwise Conv2D,

Dense

Conv2D:

FLOPs=Cin∗K∗K∗H∗W∗CoutFLOPs=Cin*K*K*H*W*CoutFLOPs=Cin∗K∗K∗H∗W∗Cout

Cin是输入的channel

K*K是kernel size

H*W是输出size

Cout是输出channel

Depthwise Conv2D:

FLOPs=Cin∗H∗W∗K∗K/S/SFLOPs=Cin*H*W*K*K/S/SFLOPs=Cin∗H∗W∗K∗K/S/S

Cin是输入的channel

K*K是kernel size

H*W是输入size

S*S是strides

Dense:

FLOPs=2∗Cin∗CoutFLOPs=2*Cin*CoutFLOPs=2∗Cin∗Cout

Cin是输入的channel

Cout是输出channel

基本层定义好了之后就按照模型的结构将运算FLOPs的结构搭起来

每一层输出的数据跟模型的也是一样的

#for conv2D layers in Depthwise Res Block
#padding=same stride=(1,1)
#request input:
# channel out(filter),
# kernel size(k[2]),
# input(cin[3])
#
#output:
# Flops of this layer(conv2dDRB_FLPOs)
# output(cout[3])
def conv2DDRB(filter,kernel,cin):H=cin[0]W=cin[1]conv2dDRB_FLOPs=cin[2]*kernel[0]*kernel[1]*filter*H*Wcout=[H,W,filter]return conv2dDRB_FLOPs,cout#for DepthwiseConv2D layers in Depthwise Res Block
#request input:
# kernel size(kernal[2]),
# stride(stride[2]),
# input(cin[3])
#output:
# FLOPs of this layer(DepConv2dDRB_FLOPs)
# output(out[3])
def DepthwiseConv2DDRB(kernel,stride,cin):out=[0,0,0]if stride[0]==1 :DepConv2dDRB_FLOPs=cin[0]*cin[1]*cin[2]*kernel[0]*kernel[1]out=cinelse:DepConv2dDRB_FLOPs=cin[0]*cin[1]*cin[2]*kernel[0]*kernel[1]/stride[0]/stride[1]out[0]=cin[0]/stride[0]out[1]=cin[1]/stride[1]out[2]=cin[2]return DepConv2dDRB_FLOPs, outdef DRB(cin, filter,kernel,stride,t):exp_channel=cin[2]*talpha=filterblock_counter=0conv_flop,cout=conv2DDRB(exp_channel,(1,1),cin)block_counter+=conv_flopdep_flop,cout=DepthwiseConv2DDRB(kernel,stride,cout)block_counter+=dep_flopconv_flop,cout=conv2DDRB(alpha,(1,1),cout)block_counter+=conv_flopprint('This Block FLOPs:',block_counter,'Output:',cout)return block_counter,coutdef conv(filter,kernel,stride,cin):H=cin[0]/stride[0]W=cin[1]/stride[1]conv_FLOPs=cin[2]*kernel[0]*kernel[1]*filter*H*Wcout=[H,W,filter]return conv_FLOPs, coutdef GolbalAvgPool(cin):features=cin[2]print('Output:',features)return featuresdef den(classes,fin):flop=2*fin*classesprint('This layer FLOPs:',flop,'Output:',classes)return flop, classes
def MNV2(classes):total_count=0img_input=(224,224,3)conv_flop,cout=conv(32,(3,3),(2,2),img_input)total_count+=conv_flopprint('This layer FLOPs:',conv_flop,'Output:',cout)Dep_flop,cout=DepthwiseConv2DDRB((3,3),(1,1),cout)total_count+=Dep_flopprint('This layer FLOPs:', Dep_flop, 'Output:', cout)conv_flop,cout=conv(16,(1,1),(1,1),cout)total_count += conv_flopprint('This layer FLOPs:', conv_flop, 'Output:', cout)DRB_flops,cout=DRB(cout,24,(3,3),(2,2),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,24,(3,3),(1,1),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,32,(3,3),(2,2),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,32,(3,3),(1,1),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,32,(3,3),(1,1),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,64,(3,3),(2,2),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,64,(3,3),(1,1),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,64,(3,3),(1,1),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,64,(3,3),(1,1),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,96,(3,3),(1,1),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,96,(3,3),(1,1),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,96,(3,3),(1,1),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,160,(3,3),(2,2),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,160,(3,3),(1,1),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,160,(3,3),(1,1),6)total_count+=DRB_flopsDRB_flops,cout=DRB(cout,320,(3,3),(1,1),6)total_count+=DRB_flopsconv_flop,cout=conv(1280,(1,1),(1,1),cout)total_count+=conv_flopcout=GolbalAvgPool(cout)Den_flops,classes=den(classes,cout)total_count+=Den_flopsprint('TOTAL FLOPs:',total_count,'Output:',classes)return total_count,classesdef main():MNV2(5)if __name__=='__main__':main()

计算模型的FLOPs相关推荐

  1. pytorch模型参数信息 计算模型的FLOPs

    参考链接:https://blog.csdn.net/tsq292978891/article/details/87918244 打印模型参数信息 在python3环境下安装torchsummary ...

  2. Keras | 计算模型的FLOPs、MACCs

    FLOPs全称是floating point operations的缩写,翻译过来是浮点运算数,理解为计算量,常用来衡量算法或深度学习模型的计算复杂度. 关于计算FLOPs值的函数,网上相关的博客很多 ...

  3. 使用thop库对yolo等深度学习模型的FLOPS进行计算

    据说yolov5原来的FLOPS计算脚本有bug,因此这个大神推荐使用thop库进行计算,代码如下: input = torch.randn(1, 3, 416, 416) flops, params ...

  4. 华为高级研究员谢凌曦:下一代人工智能计算模型探索

    2020-04-30 23:30:04 导读:下一代人工智能计算模型,主要是使用一些自动化技术帮助我们设计更好的深度学习网络结构,并在任务中提升性能.在深度学习如火如荼的当下,如何设计高效的神经网络架 ...

  5. pytorch计算模型参数量

    1. 安装 thop 1.1 常规安装 pip install thop 1.2 若上述安装方式错误,可以参考以下方式: pip install thop-i http://pypi.douban.c ...

  6. 测试方法介绍-计算模型复杂度(GMac)、模型大小(M)、计算速度(FPS)

    PRNet-V 计算复杂度为 48.76GMac 参数数量为34.73M (PRNet测试结果)(IEO在12345层) 参数数量为27.57M (PRNet测试结果)(IEO在345层) 计算图片读 ...

  7. (3)tesorflow 计算模型复杂度

    目录 1. 计算模型复杂度的衡量 2 . 典型层的复杂性计算原理 2.1 全连接层的复杂性计算 2.2 卷积层的复杂性计算 3. 全连接Tensorflow实现 4. GraphDef 5. Free ...

  8. Pytorch 获取模型 Params/FLOPS

    Note! 查看 model 参数值 model.state_dict() 1.自定义 Params Pytorch依据其内建接口自己写代码获取模型参数情况,大家可以参考Pytorch提供的model ...

  9. ECCV2020最佳论文解读之递归全对场变换(RAFT)光流计算模型

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 计算机视觉三大国际顶级会议之一的 ECCV 2020 已经召开.今年 ECCV 共收到有效投稿 502 ...

最新文章

  1. 13.Zookeeper的java客户端API使用方法
  2. Powershell 命令行泄漏下一个 Windows 10 更新内容
  3. pycharm 报错 out of memory 解决方法
  4. 在linux上使用yum安装JDK
  5. 【译】zkSNARKs in a nutshell
  6. MySQL双向主从复制
  7. ViewPager子类与父类滑动冲突的情况
  8. 导出excel写入公式_乱码、公式出错、效率低,这些excel“事故”的解决办法来了...
  9. Oracle插入时间
  10. mysql 数据增量备份_mysqlmysqldump数据备份和增量备份
  11. matlab常见函数表达式,MATLAB常用函数简介
  12. DHTMLXGantt and DHTMLXGantt pro
  13. IOS11.03越狱
  14. 小飞鱼通达二开 致远OA A8+ 设计工作流实例初体验(图文)
  15. php中调整图片大小,php 调整图片尺寸的简单示例
  16. 海外权威媒体好评连连,一加5T中国11月28号发布
  17. Linux下手动查杀木马
  18. java发送图片_Java发送邮件(图片、附件、HTML)
  19. ISLR读书笔记(1)统计学习简介
  20. 【leetcode】另一棵树的子树 c++

热门文章

  1. ArcGIS之字段计算器
  2. 《海外社交媒体营销》一一2.5 选择正确的工具和软件
  3. Linux网络设备驱动分析,以W5300以太网驱动为例
  4. [转]寒冬悟道者马云:阿里巴巴逢单出击
  5. APMServ 在 Win7 下出现“APMServ-Apache 服务因 函数不正确。 服务特定错误而停止。”
  6. 最佳队形(动态规划)
  7. SSH登录虚拟机慢的问题(等待很久才提示输入密码)
  8. System.UnauthorizedAccessException: 对路径“.......”的访问被拒绝的解决办法
  9. 阿里巴巴校招内推简历筛选方案(总结篇)
  10. 浅析平衡二叉树的四种旋转