• 数据库:MNIST,与这里对比
  • tf.nn.depthwise_conv2d的理解看这里,主要是对卷积核参数的理解,即(高度,宽度,输入通道,每个通道得到的输出通道数)
  • 训练速度慢,收敛也慢,刚开始就像没训练的样子,只将一个卷积层改成深度可分离卷积就增加了12次迭代
  • 20211213:将第一个卷积层改成深度可分离卷积后,训练始终不收敛。所以改成下列代码,多训练了16次,模型减小了6.85%。使用sp_conv也会不收敛,可能是batch_normalization的原因
  • 20211214:破案了,是relu的锅,注释以后,多训练了28次,看网上说好像要减小学习率
import tensorflow as tf
import numpy as np
import random
import cv2,sys,os
import MyDatadef sp_conv(name, data, kernel_size, input_num, output_num, padding, data_format='NHWC'):with tf.variable_scope(name):weight = tf.get_variable(name='weight', dtype=tf.float32, trainable=True, shape=[kernel_size,kernel_size,input_num,1], initializer=tf.random_normal_initializer(stddev=0.01))conv = tf.nn.depthwise_conv2d(data, weight, [1,1,1,1], padding, data_format=data_format)conv = tf.layers.batch_normalization(conv, momentum=0.9)#conv = tf.nn.leaky_relu(conv, alpha=0.1)point_weight = tf.get_variable(name='point_weight', dtype=tf.float32, trainable=True, shape=[1,1,input_num,output_num], initializer=tf.random_normal_initializer(stddev=0.01))conv = tf.nn.conv2d(conv, point_weight, [1,1,1,1], padding, data_format=data_format)conv = tf.layers.batch_normalization(conv, momentum=0.9)#conv = tf.nn.leaky_relu(conv, alpha=0.1)return convdata=tf.placeholder(tf.float32, [None, 28, 28, 3],name='data')
label=tf.placeholder(tf.float32, [None, 10], name='label')with tf.variable_scope('conv1'): # output is 28x28weight = tf.get_variable(name='weight', dtype=tf.float32, trainable=True, shape=[5,5,1,6], initializer=tf.random_normal_initializer(stddev=0.01))conv1 = tf.nn.conv2d(data, weight, [1,1,1,1], 'SAME')bias = tf.get_variable(name='bias', shape=6, trainable=True, dtype=tf.float32, initializer=tf.constant_initializer(0.0))conv1 = tf.nn.bias_add(conv1, bias)
with tf.variable_scope('pool1'):pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
with tf.variable_scope('conv2'): # output is 10x10weight2 = tf.get_variable(name='weight', dtype=tf.float32, trainable=True, shape=[5,5,6,1], initializer=tf.random_normal_initializer(stddev=0.01))conv2 = tf.nn.depthwise_conv2d(pool1, weight2, [1,1,1,1], 'VALID')point_weight2 = tf.get_variable(name='point_weight', dtype=tf.float32, trainable=True, shape=[1,1,6,16], initializer=tf.random_normal_initializer(stddev=0.01))conv2 = tf.nn.conv2d(conv2, point_weight2, [1,1,1,1], 'VALID')bias2 = tf.get_variable(name='bias', shape=16, trainable=True, dtype=tf.float32, initializer=tf.constant_initializer(0.0))conv2 = tf.nn.bias_add(conv2, bias2)
with tf.variable_scope('pool2'):pool2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
with tf.variable_scope('dense1'):flat = tf.reshape(pool2, [-1, 5*5*16])dense1 = tf.layers.dense(inputs=flat, units=80, activation=tf.nn.relu, use_bias=True)
with tf.variable_scope('dense2'):dense2 = tf.layers.dense(inputs=dense1, units=10, activation=None, use_bias=True)
y = tf.nn.softmax(dense2)
# loss
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=dense2)
loss=tf.reduce_sum(cross_entropy)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):train_step = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
# accuracy
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.1
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
input_data = MyData.Dataset('/home/lwd/data/mnist/train.txt', True, 32)
test_data = MyData.Dataset('/home/lwd/data/mnist/test.txt', False, 32)
saver = tf.train.Saver()
summary_writer = tf.summary.FileWriter('./log/', sess.graph)for i in range(100000):total = 0cnt = 0tl = 0for item in input_data:_, acc, lo = sess.run([train_step, accuracy, loss], feed_dict={data:item[0], label:item[1]})total += acccnt += 1.0tl += loprint(i, total/cnt, tl / cnt)if total/cnt > 0.88:saver.save(sess, './checkpoint/mb')xh = 0acc = 0for item in test_data:yy = sess.run(y, feed_dict={data:item[0]})for k in range(yy.shape[0]):if(np.argmax(yy[k]) == np.argmax(item[1][k])) : acc += 1xh += 1print(acc * 1.0 / xh)sys.exit(0)

tensorflow实现深度可分离卷积相关推荐

  1. 深度可分离卷积(Xception 与 MobileNet)

    前言 从卷积神经网络登上历史舞台开始,经过不断的改进和优化,卷积早已不是当年的卷积,诞生了分组卷积(Group convolution).空洞卷积(Dilated convolution 或 À tr ...

  2. 【Tensorflow】tf.nn.depthwise_conv2d如何实现深度卷积?+深度可分离卷积详解

    目录 常规卷积操作 深度可分离卷积 = 逐通道卷积+逐点卷积 1.逐通道卷积 2.逐点卷积 参数对比 介绍 实验 代码清单 一些轻量级的网络,如mobilenet中,会有深度可分离卷积depthwis ...

  3. Lesson 16.1016.1116.1216.13 卷积层的参数量计算,1x1卷积核分组卷积与深度可分离卷积全连接层 nn.Sequential全局平均池化,NiN网络复现

    二 架构对参数量/计算量的影响 在自建架构的时候,除了模型效果之外,我们还需要关注模型整体的计算效率.深度学习模型天生就需要大量数据进行训练,因此每次训练中的参数量和计算量就格外关键,因此在设计卷积网 ...

  4. 【CV】MobileNet:使用深度可分离卷积实现用于嵌入式设备的 CNN 架构

    论文名称:MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications 论文下载:https:/ ...

  5. 深度学习中的depthwise convolution,pointwise convolution,SeparableConv2D深度可分离卷积

    DepthwiseConv2D深度方向的空间卷积 pointwise convolution, SeparableConv2D深度可分离卷积 SeparableConv2D实现整个深度分离卷积过程,即 ...

  6. Xception深度可分离卷积-论文笔记

    Xception Xception: Deep Learning with Depthwise Separable Convolutions 角度:卷积的空间相关性和通道相关性 . 笔记还是手写好,都 ...

  7. 【深度学习】利用深度可分离卷积减小计算量及提升网络性能

    [深度学习]利用深度可分离卷积减小计算量及提升网络性能 文章目录 1 深度可分离卷积 2 一个深度可分离卷积层的代码示例(keras) 3 优势与创新3.1 Depthwise 过程3.2 Point ...

  8. 2d 蓝图_“蓝图”卷积--对深度可分离卷积的再思考

    论文:Rethinking Depthwise Separable Convolutions: How Intra-Kernel Correlations Lead to Improved Mobil ...

  9. 深度可分离卷积Depthwise Separable Convolution

    从卷积神经网络登上历史舞台开始,经过不断的改进和优化,卷积早已不是当年的卷积,诞生了分组卷积(Group convolution).空洞卷积(Dilated convolution 或 À trous ...

  10. 深度学习自学(十九):caffe添加深度可分离卷积

    下面是两种不同的深度可分离卷积的实现方式,自己在训练关键点模型,采用MobileNet 添加深度可分离卷积,发现有两种不同的可分离卷积的实现,名字不相同,但是内部都是深度可分离.DepthwiseCo ...

最新文章

  1. Java常用命令及Java Dump
  2. 给Python代码加上酷炫进度条的几种姿势
  3. 智慧解析第12集:老板心理学
  4. 李松南:智能全真时代的多媒体技术——关于8K、沉浸式和人工智能的思考
  5. SmartTemplate学习入门一
  6. jq ajax提交评论,织梦评论怎么改成自己的jq ajax评论
  7. XCODE---个人常用快捷键整理
  8. 在powerdesigner 中出现Could not Initialize JavaVM! 应该怎么解决
  9. php图片生成缩略图_php实现根据url自动生成缩略图的方法
  10. input输入框大小设置_Qualtrics调查问卷设计1-如何在输入框前后添加辅助文字
  11. 【电驱动】驱动电机系统讲解
  12. php吧输出结果进行分割,[判断题] 呼叫处理程序按照一定的逻辑对呼叫进行处理,对呼叫的处理结果与局数据、用户数据的内容无关。...
  13. 服务器跳过系统自检,win7 64位旗舰版跳过开机自检功能直接进入系统的方法
  14. 【控制篇 / 策略】(5.4) ❀ 03. Explicit Web Proxy 显式web代理 ❀ FortiGate 防火墙
  15. Spring Boot 开发微信公众号
  16. margin-left:-100%理解
  17. 类拼多多砍价业务总结
  18. 【文学文娱】《屌丝逆袭》-出任CEO、迎娶白富美、走上人生巅峰
  19. pandas指定从第一行读取正文数据
  20. Apache MINA简介

热门文章

  1. 【HTML/JS】百度地图javascriptAPI点击地图得到坐标(拾取坐标) 标签: 百度地图坐标
  2. 国内主要安全产品及厂商
  3. 单词毕业设计,微信小程序毕设,小程序毕设源码,单词天天斗 (毕业设计/实战小程序学习/微信小程序完整项目)
  4. ENSP路由交换机配置
  5. CAD选择时会卡一下的解决办法
  6. 使用Animate制作汽车广告动画
  7. 30+的华为,也在乘风破浪
  8. ALSA音频架构 -- aplay播放流程分析
  9. css修改图标字体大小,css-更改AngularJS材质图标的图标大小
  10. 2022年起重机械指挥考试题库及模拟考试