文章目录

  • 一、简介
  • 二、slim的子模块及功能介绍
  • 三、slim 定义模型
    • 1. slim中定义一个变量的示例
    • 2. slim中实现一个层
    • 3. slim 中的 argscope
  • 四、训练模型
  • 五、读取保存模型变量

一、简介

slim被放在tensorflow.contrib这个库下面,引入的方法如下:

import tensorflow.contrib.slim as slim
或
import tensorflow as tf
slim = tf.contrib.slim

那么什么是 slim ?slim到底有什么用 ?

它可以消除原生tensorflow里面很多重复的模板性的代码,让代码更精确,更容易替换性。另外slim提供了很多计算机视觉方面的著名模型(VGG,AlexNet等),我们不仅可以直接使用,甚至能以各种方式进行扩展。

二、slim的子模块及功能介绍

arg_scope:提供了一个名为arg_scope的新范围,该范围允许用户为该范围内的特定操作定义默认参数。
除了基本的 namescope,variabelscope外,又加了argscope,它是用来控制每一层的替代超参数的。data(数据):包含TF-slim的数据集定义,数据提供程序,parallel_reader和解码实用程序。evaluation(评估):包含评估模型的例程,评估模型的一些方法,用的不多layers(层):包含用于使用张量流构建模型的高层。这个比较重要,slim的核心和精髓,一些复杂层的定义learning(学习):包含用于训练模型的模型loss:包含常用的损失函数。matris(指标):包含流行的评估指标,评估模型的准则标准nets(网络):包含流行的网络定义,例如VGG和AlexNet模型。queues(队列):提供上下文管理器,可轻松安全地启动和关闭QueueRunner。variables(变量):为变量创建和操作提供方便的包装。

三、slim 定义模型

1. slim中定义一个变量的示例

# Model Variables
weights = slim.model_variable('weights',shape=[10, 10, 3 , 3],initializer=tf.truncated_normal_initializer(stddev=0.1),regularizer=slim.l2_regularizer(0.05),device='/CPU:0')
model_variables = slim.get_model_variables()# Regular variables
my_var = slim.variable('my_var',shape=[20, 1],initializer=tf.zeros_initializer())
regular_variables_and_model_variables = slim.get_variables()

如上,变量分为两类:模型变量和局部变量。
局部变量是不作为模型参数保存的,而模型变量会在保存的时候保存下来。这个玩过tensorflow的人都会明白,诸如global_step之类的就是局部变量。
slim中可以写明变量存放的设备,正则和初始化规则。
还有获取变量的函数也需要注意一下,get_variables是返回所有的变量。

2. slim中实现一个层

首先让我们看看tensorflow怎么实现一个层,例如卷积层:

input = ...
with tf.name_scope('conv1_1') as scope:kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 128], dtype=tf.float32,stddev=1e-1), name='weights')conv = tf.nn.conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')biases = tf.Variable(tf.constant(0.0, shape=[128], dtype=tf.float32),trainable=True, name='biases')bias = tf.nn.bias_add(conv, biases)conv1 = tf.nn.relu(bias, name=scope)

然后slim的实现:

input = ...
net = slim.conv2d(input, 128, [3, 3], scope='conv1_1')

但这个不是重要的,因为 tenorflow 目前也有大部分层的简单实现,这里比较吸引人的是 slim 中的 repeat 和 stack 操作:

假设定义三个相同的卷积层:

net = ...
net = slim.conv2d(net, 256, [3, 3], scope='conv3_1')
net = slim.conv2d(net, 256, [3, 3], scope='conv3_2')
net = slim.conv2d(net, 256, [3, 3], scope='conv3_3')
net = slim.max_pool2d(net, [2, 2], scope='pool2')

在slim中的repeat操作可以减少代码量:

net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
net = slim.max_pool2d(net, [2, 2], scope='pool2')

假设定义三层FC:

# Verbose way:
x = slim.fully_connected(x, 32, scope='fc/fc_1')
x = slim.fully_connected(x, 64, scope='fc/fc_2')
x = slim.fully_connected(x, 128, scope='fc/fc_3')

使用stack操作:

slim.stack(x, slim.fully_connected, [32, 64, 128], scope='fc')

同理卷积层也一样:

# 普通方法:
x = slim.conv2d(x, 32, [3, 3], scope='core/core_1')
x = slim.conv2d(x, 32, [1, 1], scope='core/core_2')
x = slim.conv2d(x, 64, [3, 3], scope='core/core_3')
x = slim.conv2d(x, 64, [1, 1], scope='core/core_4')# 简便方法:
slim.stack(x, slim.conv2d, [(32, [3, 3]), (32, [1, 1]), (64, [3, 3]), (64, [1, 1])], scope='core')

3. slim 中的 argscope

如果你的网络有大量相同的参数,如下:

net = slim.conv2d(inputs, 64, [11, 11], 4, padding='SAME',weights_initializer=tf.truncated_normal_initializer(stddev=0.01),weights_regularizer=slim.l2_regularizer(0.0005), scope='conv1')
net = slim.conv2d(net, 128, [11, 11], padding='VALID',weights_initializer=tf.truncated_normal_initializer(stddev=0.01),weights_regularizer=slim.l2_regularizer(0.0005), scope='conv2')
net = slim.conv2d(net, 256, [11, 11], padding='SAME',weights_initializer=tf.truncated_normal_initializer(stddev=0.01),weights_regularizer=slim.l2_regularizer(0.0005), scope='conv3')

然后我们用arg_scope处理一下:

with slim.arg_scope([slim.conv2d], padding='SAME',weights_initializer=tf.truncated_normal_initializer(stddev=0.01)weights_regularizer=slim.l2_regularizer(0.0005)):net = slim.conv2d(inputs, 64, [11, 11], scope='conv1')net = slim.conv2d(net, 128, [11, 11], padding='VALID', scope='conv2')net = slim.conv2d(net, 256, [11, 11], scope='conv3')

是不是一下子就变简洁了?
这里额外说明一点,arg_scope 的作用范围内,是定义了指定层的替代参数,若想特别指定某些层的参数,可以重新赋值(相当于重置),如上倒数第二行代码。
那如果除了卷积层还有其他层呢?那就要如下定义:

with slim.arg_scope([slim.conv2d, slim.fully_connected],activation_fn=tf.nn.relu,weights_initializer=tf.truncated_normal_initializer(stddev=0.01),weights_regularizer=slim.l2_regularizer(0.0005)):with slim.arg_scope([slim.conv2d], stride=1, padding='SAME'):net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')net = slim.conv2d(net, 256, [5, 5],weights_initializer=tf.truncated_normal_initializer(stddev=0.03),scope='conv2')net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc')

写两个arg_scope就行了

采用如上方法,定义一个VGG也就十几行代码的事了

def vgg16(inputs):with slim.arg_scope([slim.conv2d, slim.fully_connected],activation_fn=tf.nn.relu,weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),weights_regularizer=slim.l2_regularizer(0.0005)):net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')net = slim.max_pool2d(net, [2, 2], scope='pool1')net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')net = slim.max_pool2d(net, [2, 2], scope='pool2')net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')net = slim.max_pool2d(net, [2, 2], scope='pool3')net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')net = slim.max_pool2d(net, [2, 2], scope='pool4')net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')net = slim.max_pool2d(net, [2, 2], scope='pool5')net = slim.fully_connected(net, 4096, scope='fc6')net = slim.dropout(net, 0.5, scope='dropout6')net = slim.fully_connected(net, 4096, scope='fc7')net = slim.dropout(net, 0.5, scope='dropout7')net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc8')return net

四、训练模型

这个没什么好说的,说一下直接拿经典网络来训练吧。

import tensorflow as tf
vgg = tf.contrib.slim.nets.vgg# Load the images and labels.
images, labels = ...# Create the model.
predictions, _ = vgg.vgg_16(images)# Define the loss functions and get the total loss.
loss = slim.losses.softmax_cross_entropy(predictions, labels)

不是超级简单?

关于损耗,要说一下定义自己的损耗的方法,以及注意不要忘记加入到slim中让slim看到你的损耗。

还有正则项也是需要手动添加进损当中的,不然最后计算的时候就不优化正则目标了。

# Load the images and labels.
images, scene_labels, depth_labels, pose_labels = ...# Create the model.
scene_predictions, depth_predictions, pose_predictions = CreateMultiTaskModel(images)# Define the loss functions and get the total loss.
classification_loss = slim.losses.softmax_cross_entropy(scene_predictions, scene_labels)
sum_of_squares_loss = slim.losses.sum_of_squares(depth_predictions, depth_labels)
pose_loss = MyCustomLossFunction(pose_predictions, pose_labels)
slim.losses.add_loss(pose_loss) # Letting TF-Slim know about the additional loss.# The following two ways to compute the total loss are equivalent:
regularization_loss = tf.add_n(slim.losses.get_regularization_losses())
total_loss1 = classification_loss + sum_of_squares_loss + pose_loss + regularization_loss# (Regularization Loss is included in the total loss by default).
total_loss2 = slim.losses.get_total_loss()

五、读取保存模型变量

通过以下功能我们可以加载模型的部分变量:

# Create some variables.
v1 = slim.variable(name="v1", ...)
v2 = slim.variable(name="nested/v2", ...)
...# Get list of variables to restore (which contains only 'v2').
variables_to_restore = slim.get_variables_by_name("v2")# Create the saver which will be used to restore the variables.
restorer = tf.train.Saver(variables_to_restore)with tf.Session() as sess:# Restore variables from disk.restorer.restore(sess, "/tmp/model.ckpt")print("Model restored.")

除了这种部分变量加载的方法外,我们甚至还能加载到不同的名字的变量中。

假设我们定义的网络变量是conv1 / weights,而从VGG加载的变量称为vgg16 / conv1 / weights,正常负载肯定会报错(找不到变量名),但是可以这样:

def name_in_checkpoint(var):return 'vgg16/' + var.op.namevariables_to_restore = slim.get_model_variables()
variables_to_restore = {name_in_checkpoint(var):var for var in variables_to_restore}
restorer = tf.train.Saver(variables_to_restore)with tf.Session() as sess:# Restore variables from disk.restorer.restore(sess, "/tmp/model.ckpt")

通过这种方式我们可以加载不同变量名的变量!!

参考:
【1】https://blog.csdn.net/zhoujunr1/article/details/77131605
【2】https://www.2cto.com/kf/201706/649266.html
【3】https://blog.csdn.net/zsWang9/article/details/79965501

TensorFlow 之 slim(TF-Slim)介绍相关推荐

  1. CV之NS之VGG16:基于TF Slim(VGG16)利用七个不同的预训练模型实现快速NS风格

    CV之NS之VGG16:基于TF Slim(VGG16)利用七个不同的预训练模型实现快速NS风格 目录 实现结果 部分实例代码 实现结果 1.本博主,以前几天拍过的东方明珠照片,为例进行快速NS风格 ...

  2. CV之NS之VGG16:基于TF Slim库利用VGG16算法的预训练模型实现七种不同快速图像风格迁移设计(cubist/denoised_starry/mosaic/scream/wave)案例

    CV之NS之VGG16:基于TF Slim库利用VGG16算法的预训练模型实现七种不同快速图像风格迁移设计(cubist/denoised_starry/feathers/mosaic/scream/ ...

  3. TF:tensorflow框架中常用函数介绍—tf.Variable()和tf.get_variable()用法及其区别

    TF:tensorflow框架中常用函数介绍-tf.Variable()和tf.get_variable()用法及其区别 目录 tensorflow框架 tensorflow.Variable()函数 ...

  4. tensorflow兼容处理 tensorflow.compat.v1 tf.contrib

    20201130 问题提出: v1版本中tensorflow中contrib模块十分丰富,但是发展不可控,因此在v2版本中将这个模块集成到其他模块中去了.在学习tensorflow经常碰到tf.con ...

  5. 图片基础与tf.keras介绍

    图片基础与tf.keras介绍 1.1 图像基本知识 回忆:之前在特征抽取中如何将文本处理成数值. 思考:如何将图片文件转换成机器学习算法能够处理的数据? 我们经常接触到的图片有两种,一种是黑白图片( ...

  6. Tensorflow读取数据-tf.data.TFRecordDataset

    tensorflow TFRecords文件的生成和读取方法 文章目录 tensorflow TFRecords文件的生成和读取方法 1. TFRecords说明 2.关键API 2.1 tf.io. ...

  7. php slim 教程,Slim - 超轻量级PHP Restful API构建框架

    下载源码包: http://www.slimframework.com/ 基于Slim的Restful API Sample: require '/darjuan/Slim/Slim.php'; us ...

  8. tensorflow中的tf.summary.image

    tensorflow中的tf.summary.image tf.summary.image(name,#生成的节点的名称.也将作为TensorBoard中的系列名称tensor,#uint8或者flo ...

  9. 记录 之 tensorflow函数:tf.data.Dataset.from_tensor_slices

    tf.data.Dataset.from_tensor_slices(),是常见的数据处理函数,它的作用是将给定的元组(turple).列表(list).张量(tensor)等特征进行特征切片.切片的 ...

  10. tensorflow 启动Session(tf.Session(),tf.InteractivesSession(),tf.train.Supervisor().managed_session() )

    (1)tf.Session() 计算图构造完成后, 才能启动图. 启动图的第一步是创建一个 Session 对象. 示例程序: #coding:utf-8 import tensorflow as t ...

最新文章

  1. Java 事件适配器 Adapter
  2. 前端三十三:表单form
  3. att48数据集最优值10628的解
  4. Myeclipse+mysql出现中文乱码情况
  5. 重写了GD32VF103的启动脚本和链接脚本
  6. JVM的常用配置参数
  7. Thread.yield()和Thread.sleep(0)
  8. 【MySQL】mysql show操作简单示例
  9. android程序的建立,创建第一个Android程序 HelloWorld
  10. Bailian2686 打印完数【暴力】
  11. python之MD5加密
  12. 苹果airplay是什么 苹果手机投屏到电脑
  13. 智能生活雷达应用,微波雷达技术发展,微波雷达感应模块方案
  14. java-家庭作业2
  15. ConcurrenHashMap源码(JDK1.7)
  16. 天猫商城在线购物系统
  17. 视频直播秒开背后的技术与优化经验
  18. imp-00003:oracle error 959 encountered
  19. 新技術讓大數據“看得見”
  20. 企业获客的五种方式解读

热门文章

  1. python math类
  2. scheme语言编写运行
  3. 程序员如何让自己的技术能力突飞猛进?
  4. 彩扩机项目--开关滤波进阶,电机驱动桥,死区,三极管搭建反向电路
  5. 重仓金融股却遭“滑铁卢”
  6. 运筹帷幄之中,决胜千里之外——运筹学1-3章
  7. 致翔OA漏洞复现手册
  8. java课程设计 计算器_Java课程设计-计算器
  9. 分离LZY的字符串(循环)
  10. autojs之语音识别