Copyright notice: this is an original article by the blogger, released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.

本文链接:https://blog.csdn.net/bryant_meng/article/details/79260165
![paper figure](https://img-blog.csdnimg.cn/20181119143732887.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2JyeWFudF9tZW5n,size_16,color_FFFFFF,t_70#pic_center)

arXiv-2015
In NIPS Deep Learning Workshop, 2014


文章目录

  • 1 Background and Motivation
  • 2 Conceptual block
  • 3 Knowledge Distilling
    • 3.1 hard target
    • 3.2 soft target
    • 3.3 Difference between learning logits and learning softmax+T
    • 3.4 Advantages of softmax+T over logits
    • 3.5 Cost function
  • 4 Dataset
  • 5 Experiments
  • 6 References
  • 7 Appendix
    • A. How softmax changes with temperature
    • B. Knowledge distillation (MNIST) code
      • B.1 teacher network
      • B.2 student network

This post only covers the classification-related part of "Distilling the Knowledge in a Neural Network"; for more related papers, see 《Paper》.

1 Background and Motivation

A very simple way to improve a model's performance is to
train many different models on the same data and then to average their predictions

Drawbacks

  • making predictions with an ensemble is too cumbersome
  • it may be too computationally expensive to deploy to a large number of users, especially if the individual models are large neural networks

Caruana et al. showed that ensemble-to-single-model transfer is feasible
(they demonstrate convincingly that the knowledge acquired by a large ensemble of models can be transferred to a single small model)

The authors use knowledge distillation (a brand-new compression method) to realize this ensemble-to-single-model transfer.

2 Conceptual block


1) A common misconception about the knowledge a model has learned is to identify it with the trained parameter values. This narrow view once held back the development of distillation, because as soon as the network architecture changes, the so-called knowledge/parameters can no longer be reused. The authors instead propose a more general, more abstract view: knowledge is the mapping the network has learned from input vectors to output vectors.

Understood this way, knowledge is no longer tied to a specific architecture, which makes it possible for a small network to learn from a large one!

2) Another misconception is that the training objective should match the ground-truth labels as closely as possible. In reality, training optimizes the model's performance on the training set, whereas what we actually care about is how well it generalizes to new data. It would obviously be ideal to train a model directly for good generalization, but this is essentially impossible, because the information about how to generalize is not available. With knowledge distillation, however, the generalization ability learned by the large model can be transferred to the small model in a natural way: because the large model is big and generalizes well, a small model trained under its guidance generalizes much better than the same small model trained from scratch.

So how is the large model's generalization ability passed to the small model? Through soft targets: the large network's softmax output (the standard softmax with a temperature) is used as the label, and the small network's softmax output is trained to match it. The hard target, by contrast, is the original dataset label. The advantages of soft targets over hard targets are summarized in the slide above.

Why do soft targets carry information about the model's generalization ability? My understanding is that, compared with hard targets, soft targets encode much more about the relationships between classes.
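A toy illustration of that intuition (the numbers below are invented, purely for illustration): for an image of a handwritten 2, a hard target only says "class 2", while a soft target also records that the teacher finds it somewhat similar to a 3 and a 7.

import numpy as np

# hypothetical targets for one MNIST image of the digit 2
hard_target = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
soft_target = np.array([0.01, 0.01, 0.82, 0.09, 0.01, 0.01, 0.01, 0.03, 0.005, 0.005])

print(soft_target.sum())  # 1.0 -- the small probabilities on 3 and 7 encode class similarity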

3 Knowledge Distilling

3.1 hard target

Let's first look at how the hard target (standard softmax) is computed
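Writing out the formula (the standard softmax, with $z_i$ the logits):

$$q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$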

A more intuitive illustration (from Zhihu)


3.2 soft target

Now let's look at the effect of the soft target (softmax + T)
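With a temperature $T$, the paper's formula is

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

and $T = 1$ recovers the standard softmax; a larger $T$ gives a softer distribution over the classes.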


The x-axis is the temperature T and the y-axis is the soft-target output $q_i$: the higher the temperature, the softer (more uniform) the output distribution.

3.3 Difference between learning logits and learning softmax+T

In the high-temperature limit, matching the teacher's softened output is equivalent to matching its logits directly: with $v_i$ the teacher's logits and $z_i$ the student's, distillation then amounts to minimizing $\frac{1}{2}(z_i - v_i)^2$, whose gradient is

$$\frac{\partial}{\partial z_i}\left[\frac{1}{2}(z_i - v_i)^2\right] = z_i - v_i$$

so learning the raw logits is a special case of learning softmax+T.

3.4 Advantages of softmax+T over logits

Given that matching logits is only a special case of matching softmax+T (the high-temperature limit), what advantages does softmax+T have over matching the logits directly?
The author gives the following summary

  • logits are almost completely unconstrained by the cost function used for training the cumbersome model so they could be very noisy
  • very negative logits may convey useful information about the knowledge acquired by the cumbersome model

3.5 Cost function

The small network's loss function is as follows

When learning the generalization behaviour from the large network, train with a relatively large T (the larger T is, the less confident the outputs; if the network can still tell the classes apart under such low confidence, it will perform even better when tested at T = 1 — much like training with extra weights). When learning from the true labels, use T = 1.

Combining the true labels with the soft targets, using a weighted sum of the two as the training objective, gives better results. The objective then becomes the following, where a value of λ smaller than 1 works better.
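A sketch of this combined objective in formula form (my notation; I am assuming λ weights the hard-label term, which is consistent with the remark above that λ < 1 works better):

$$L = \lambda\, H\!\left(y,\ \sigma(z)\right) + (1-\lambda)\, T^2\, H\!\left(\sigma(v/T),\ \sigma(z/T)\right)$$

where $z$ and $v$ are the student's and teacher's logits, $\sigma$ is the softmax, and $H$ is the cross-entropy; the $T^2$ factor keeps the two terms on a comparable scale (see the note in Appendix B.2).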

4 Dataset

MNIST

5 Experiments

Network architectures:

  • Large network: 2 hidden layers with 1200 units each, trained on the 55000 training samples with dropout.
  • Small network (baseline): 2 hidden layers with 800 units each, no regularization, trained directly in the usual way.
  • Small network (soft): 2 hidden layers with 800 units each, no regularization, distilled from the large network with T = 20.

Number of test errors:

  • Large network: 67
  • Small network (baseline): 146
  • Small network (soft): 74

Generalization experiment
To study the small network's generalization, the authors removed every image of the digit 3 from the transfer set (the data used to train the small network; it can be smaller than the set used to train the large network, and true labels are not strictly required), so the small network never sees a 3 during training. Nevertheless, at test time it reaches 98.6% accuracy on the digit 3. Moreover, even when the transfer set contains only images of 7s and 8s, the small model's error rate is just 13.2%. In other words, the small network inherits generalization ability from the large one!

Q1: In Section 3 of the paper, how should the bias adjustment made when tweaking this experiment be understood?

6 References

【1】【论文导读】Hinton - Distilling the Knowledge in a Neural Network

【2】手打例子一步一步带你看懂softmax函数以及相关求导过程

【3】知识蒸馏(Distillation)相关论文阅读(1)——Distilling the Knowledge in a Neural Network(以及代码复现)

7 Appendix

A. How softmax changes with temperature

import math
import numpy as np
import matplotlib.pyplot as plt
T = np.arange(1,20,1)
y1 = (math.e**(0.9/T)/(math.e**(0.07/T)+math.e**(0.9/T)+math.e**(0.03/T)))
y2 = (math.e**(0.07/T)/(math.e**(0.07/T)+math.e**(0.9/T)+math.e**(0.03/T)))
y3 = (math.e**(0.03/T)/(math.e**(0.07/T)+math.e**(0.9/T)+math.e**(0.03/T)))
plt.plot(T,y1)
plt.plot(T,y2)
plt.plot(T,y3)
plt.legend(["0.9", "0.07", "0.03"])  # legend
plt.grid()  # grid
#plt.savefig('1.png')
plt.show()

B. Knowledge distillation (MNIST) code

  • Code source (TensorFlow version):
    akimach/tensorflow-distillation-examples
    It can also be downloaded here: https://pan.baidu.com/s/1vDud4Iws_xnDxRqRnpyR-g (extraction code: cemy)

  • Background reading:
    《Tensorflow | 莫烦 》learning notes
    【Keras-MLP】MNIST
    【TensorFlow-MLP】MNIST

The MNIST training set has 60000 images, so why are there only 55000 here? The remaining 5000 are held out as validation data.
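A quick way to confirm the split (my own check; it assumes mnist has already been loaded with input_data.read_data_sets as in B.1 below):

print(mnist.train.num_examples)       # 55000
print(mnist.validation.num_examples)  # 5000
print(mnist.test.num_examples)        # 10000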

B.1 teacher network

2 hidden layers with 1200 units each, 55000 training samples, trained with dropout = 0.5.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
%matplotlib inline

random.seed(123)
np.random.seed(123)
tf.set_random_seed(123)

# load the MNIST dataset
mnist = input_data.read_data_sets("/root/userfolder/Experiment/MNIST_data/", one_hot=True)


Hyperparameters and layer definitions

# hyper parameters
n_epochs = 50
batch_size = 50
num_nodes_h1 = 1200
num_nodes_h2 = 1200
learning_rate = 0.001

# number of batches: 55000 // 50 = 1100
n_batches = len(mnist.train.images) // batch_size

# weight variable W
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

# bias variable b
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# softmax with temperature T
def softmax_with_temperature(logits, temp=1.0, axis=1, name=None):
    logits_with_temp = logits / temp
    _softmax = tf.exp(logits_with_temp) / tf.reduce_sum(tf.exp(logits_with_temp), axis=axis, keep_dims=True)
    return _softmax
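A quick sanity check of softmax_with_temperature (my own snippet, not part of the original repo): a larger temperature should give a flatter distribution that still sums to 1.

with tf.Session() as sess:
    test_logits = tf.constant([[2.0, 1.0, 0.1]])
    print(sess.run(softmax_with_temperature(test_logits, temp=1.0)))  # peaked, roughly [0.66 0.24 0.10]
    print(sess.run(softmax_with_temperature(test_logits, temp=5.0)))  # flatter, roughly [0.40 0.33 0.27]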


Network architecture

# data
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)# drop out

# input to hidden layer 1
W_h1 = weight_variable([784, num_nodes_h1])# 784,1200
b_h1 = bias_variable([num_nodes_h1])# 1200
h1 = tf.nn.relu(tf.matmul(x, W_h1) + b_h1) # relu(wx+b)
h1_drop = tf.nn.dropout(h1, keep_prob) # drop out

# hidden layer 1 to hidden layer 2
W_h2 = weight_variable([num_nodes_h1, num_nodes_h2])# 1200,1200
b_h2 = bias_variable([num_nodes_h2])# 1200
h2 = tf.nn.relu(tf.matmul(h1_drop, W_h2) + b_h2)# relu(wx+b)
h2_drop = tf.nn.dropout(h2, keep_prob) # drop out

# hidden layer 2 to output layer
W_output = tf.Variable(tf.zeros([num_nodes_h2, 10]))
b_output = tf.Variable(tf.zeros([10]))
logits = tf.matmul(h2_drop, W_output) + b_output

y = tf.nn.softmax(logits) # hard target
y_soft_target = softmax_with_temperature(logits, temp=2.0) # soft target
loss = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


Train with mini-batches, save the trained model, and record the training loss and the training/test accuracy.

saver = tf.train.Saver()
losses = []
accs = []
test_accs = []

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):  # epochs
        x_shuffle, y_shuffle = shuffle(mnist.train.images, mnist.train.labels)
        for i in range(n_batches):  # batches
            start = i * batch_size
            end = start + batch_size
            batch_x, batch_y = x_shuffle[start:end], y_shuffle[start:end]
            sess.run(train_step, feed_dict={
                x: batch_x, y_: batch_y, keep_prob: 0.5})
        train_loss = sess.run(loss, feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.5})
        train_accuracy = sess.run(accuracy, feed_dict={
            x: batch_x, y_: batch_y, keep_prob: 1.0})
        test_accuracy = sess.run(accuracy, feed_dict={
            x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
        print("Epoch : %i, train loss : %f, Accuracy: %f, Test accuracy: %f" % (
            epoch+1, train_loss, train_accuracy, test_accuracy))
        saver.save(sess, "/root/userfolder/Experiment/tensorflow-distillation-examples/model_teacher/",
                   global_step=epoch+1)  # only the most recent checkpoints are kept
        losses.append(train_loss)
        accs.append(train_accuracy)
        test_accs.append(test_accuracy)
    print("... completed!")


output

Epoch : 1, train loss : 0.737658, Accuracy: 0.880000, Test accuracy: 0.870400
Epoch : 2, train loss : 0.761208, Accuracy: 0.900000, Test accuracy: 0.877700
Epoch : 3, train loss : 0.589437, Accuracy: 0.920000, Test accuracy: 0.890600
Epoch : 4, train loss : 0.643363, Accuracy: 0.900000, Test accuracy: 0.899900
Epoch : 5, train loss : 0.616038, Accuracy: 0.900000, Test accuracy: 0.900900
Epoch : 6, train loss : 0.611822, Accuracy: 0.860000, Test accuracy: 0.907100
Epoch : 7, train loss : 0.644078, Accuracy: 0.860000, Test accuracy: 0.909100
Epoch : 8, train loss : 0.402896, Accuracy: 0.960000, Test accuracy: 0.911100
Epoch : 9, train loss : 0.572901, Accuracy: 0.960000, Test accuracy: 0.907900
Epoch : 10, train loss : 0.517088, Accuracy: 0.900000, Test accuracy: 0.914600
Epoch : 11, train loss : 0.410240, Accuracy: 0.960000, Test accuracy: 0.914300
Epoch : 12, train loss : 0.945823, Accuracy: 0.800000, Test accuracy: 0.916200
Epoch : 13, train loss : 0.579927, Accuracy: 0.900000, Test accuracy: 0.917000
Epoch : 14, train loss : 0.503660, Accuracy: 0.860000, Test accuracy: 0.918300
Epoch : 15, train loss : 0.532867, Accuracy: 0.940000, Test accuracy: 0.918600
Epoch : 16, train loss : 0.430909, Accuracy: 0.940000, Test accuracy: 0.920300
Epoch : 17, train loss : 0.507866, Accuracy: 0.920000, Test accuracy: 0.920600
Epoch : 18, train loss : 0.453426, Accuracy: 0.920000, Test accuracy: 0.925200
Epoch : 19, train loss : 0.689311, Accuracy: 0.920000, Test accuracy: 0.926600
Epoch : 20, train loss : 0.379545, Accuracy: 0.940000, Test accuracy: 0.926100
Epoch : 21, train loss : 0.431786, Accuracy: 0.920000, Test accuracy: 0.926800
Epoch : 22, train loss : 0.401257, Accuracy: 0.960000, Test accuracy: 0.927300
Epoch : 23, train loss : 0.587902, Accuracy: 0.960000, Test accuracy: 0.928600
Epoch : 24, train loss : 0.620417, Accuracy: 0.880000, Test accuracy: 0.927400
Epoch : 25, train loss : 0.365211, Accuracy: 0.940000, Test accuracy: 0.929500
Epoch : 26, train loss : 0.427130, Accuracy: 0.960000, Test accuracy: 0.930300
Epoch : 27, train loss : 0.253452, Accuracy: 0.900000, Test accuracy: 0.930800
Epoch : 28, train loss : 0.427312, Accuracy: 0.920000, Test accuracy: 0.930900
Epoch : 29, train loss : 0.419188, Accuracy: 0.900000, Test accuracy: 0.933100
Epoch : 30, train loss : 0.268312, Accuracy: 0.940000, Test accuracy: 0.933800
Epoch : 31, train loss : 0.346375, Accuracy: 0.920000, Test accuracy: 0.933500
Epoch : 32, train loss : 0.292108, Accuracy: 0.960000, Test accuracy: 0.933000
Epoch : 33, train loss : 0.436444, Accuracy: 0.960000, Test accuracy: 0.935100
Epoch : 34, train loss : 0.278850, Accuracy: 0.940000, Test accuracy: 0.934900
Epoch : 35, train loss : 0.277737, Accuracy: 0.940000, Test accuracy: 0.937300
Epoch : 36, train loss : 0.425431, Accuracy: 0.940000, Test accuracy: 0.937300
Epoch : 37, train loss : 0.359413, Accuracy: 0.940000, Test accuracy: 0.937800
Epoch : 38, train loss : 0.338502, Accuracy: 0.960000, Test accuracy: 0.937600
Epoch : 39, train loss : 0.433313, Accuracy: 0.880000, Test accuracy: 0.937100
Epoch : 40, train loss : 0.529199, Accuracy: 0.860000, Test accuracy: 0.938700
Epoch : 41, train loss : 0.657401, Accuracy: 0.920000, Test accuracy: 0.938500
Epoch : 42, train loss : 0.491150, Accuracy: 0.920000, Test accuracy: 0.938600
Epoch : 43, train loss : 0.334091, Accuracy: 0.940000, Test accuracy: 0.940200
Epoch : 44, train loss : 0.298908, Accuracy: 0.940000, Test accuracy: 0.941000
Epoch : 45, train loss : 0.303939, Accuracy: 0.940000, Test accuracy: 0.939800
Epoch : 46, train loss : 0.378838, Accuracy: 0.940000, Test accuracy: 0.939500
Epoch : 47, train loss : 0.323622, Accuracy: 0.920000, Test accuracy: 0.941700
Epoch : 48, train loss : 0.280403, Accuracy: 0.940000, Test accuracy: 0.943000
Epoch : 49, train loss : 0.390651, Accuracy: 0.920000, Test accuracy: 0.942800
Epoch : 50, train loss : 0.614632, Accuracy: 0.900000, Test accuracy: 0.941700
... completed!

Visualize the training loss

# plot how the training loss changes
plt.title("Loss of teacher")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(range(1, len(losses)+1), losses, label='train_loss')
plt.legend()
plt.show()


Visualize the training and test accuracy

# plot how the training and test accuracy change
plt.title("Accuracy of teacher")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.plot(range(1, len(accs)+1), accs, label='Training')
plt.plot(range(1, len(test_accs)+1), test_accs, label='Test')
plt.legend()
plt.show()

Save the training loss and the training/test accuracy.

# save the training loss/accuracy and the test accuracy
np.save("loss_teacher.npy", np.array(losses))
np.save("acc_train_teacher.npy", np.array(accs))
np.save("acc_test_teacher.npy", np.array(test_accs))

Save the teacher network's soft targets. We pick a checkpoint from an epoch that performs well; below we use epoch 48.

# save the soft targets from the epoch-48 checkpoint
_soft_targets = []
with tf.Session() as sess:
    saver.restore(sess, "./model_teacher/-48")
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        batch_x = mnist.train.images[start:end]
        soft_target = sess.run(y_soft_target, feed_dict={x: batch_x, keep_prob: 1.0})
        _soft_targets.append(soft_target)

Check the shape of _soft_targets and reshape it

np.shape(_soft_targets)  # (1100, 50, 10) = (batches, batch_size, classes)
soft_targets = np.c_[_soft_targets].reshape(55000, 10)  # reshape to (55000, 10)

Compare the soft targets with the hard targets

print(soft_targets[:2])
print(mnist.train.labels[:2])  # compare the labels with the softmax predictions above

output

[[5.2621812e-03 6.1693429e-03 1.5207376e-01 6.1155759e-02 1.4845385e-02
  4.8464271e-03 3.6828788e-03 6.0641229e-01 2.9818511e-02 1.1573344e-01]
 [2.4089564e-03 2.6752956e-03 1.8253580e-02 8.5861373e-01 3.0618338e-04
  1.7423177e-02 9.3506598e-05 3.6187540e-03 8.3464541e-02 1.3142269e-02]]
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]

Save the teacher network's soft targets so the student network can learn from them

np.save('soft-targets.npy', soft_targets)

Check its shape

np.load(file="soft-targets.npy").shape

output

(55000, 10)

B.2 student network

The differences from the teacher network are the hidden-layer size (1200 down to 600; the paper uses 800) and the loss function; everything else is the same.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from sklearn.utils import shuffle
import matplotlib.pyplot as plt

%matplotlib inline

random.seed(123)
np.random.seed(123)
tf.set_random_seed(123)

mnist = input_data.read_data_sets("/root/userfolder/Experiment/MNIST_data/", one_hot=True)


Load the teacher network's soft targets

soft_targets = np.load(file="soft-targets.npy")
print(np.shape(soft_targets))

output

(55000, 10)

Hyperparameter settings and the definitions of W, b, and softmax with temperature

n_epochs = 50
batch_size = 50
num_nodes_h1 = 600 # Before 800
num_nodes_h2 = 600 # Before 800
learning_rate = 0.001

n_batches = len(mnist.train.images) // batch_size

def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def softmax_with_temperature(logits, temp=1.0, axis=1, name=None):
    logits_with_temp = logits / temp
    _softmax = tf.exp(logits_with_temp) / tf.reduce_sum(tf.exp(logits_with_temp),
                                                        axis=axis, keep_dims=True)
    return _softmax


Network design

x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
soft_target_ = tf.placeholder(tf.float32, [None, 10])

keep_prob = tf.placeholder(tf.float32)
T = tf.placeholder(tf.float32)

W_h1 = weight_variable([784, num_nodes_h1])
b_h1 = bias_variable([num_nodes_h1])
h1 = tf.nn.relu(tf.matmul(x, W_h1) + b_h1)
h1_drop = tf.nn.dropout(h1, keep_prob)

W_h2 = weight_variable([num_nodes_h1, num_nodes_h2])
b_h2 = bias_variable([num_nodes_h2])
h2 = tf.nn.relu(tf.matmul(h1_drop, W_h2) + b_h2)
h2_drop = tf.nn.dropout(h2, keep_prob)  # dropout is still used

W_output = tf.Variable(tf.zeros([num_nodes_h2, 10]))
b_output = tf.Variable(tf.zeros([10]))
logits = tf.matmul(h2_drop, W_output) + b_output

y = tf.nn.softmax(logits)
y_soft_target = softmax_with_temperature(logits, temp=T)

loss_hard_target = -tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])
loss_soft_target = -tf.reduce_sum(soft_target_ * tf.log(y_soft_target),
                                  reduction_indices=[1])

loss = tf.reduce_mean(tf.square(T) * loss_hard_target + tf.square(T) * loss_soft_target)

train_step = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
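One note on the T² factor in the loss above, connecting it back to the paper (my reading, not a claim about the repo's intent): Hinton et al. recommend multiplying the soft-target cross-entropy by T², because the gradients coming from the soft targets scale as 1/T². The code here multiplies both terms by T², which is equivalent to scaling the whole loss (and hence the effective learning rate) by T² while keeping the two terms equally weighted:

$$L_{\text{repo}} = T^2\left(H(y,\ \sigma(z)) + H(q^{\text{teacher}},\ \sigma(z/T))\right) \qquad L_{\text{paper}} = H(y,\ \sigma(z)) + T^2\, H(q^{\text{teacher}},\ \sigma(z/T))$$

where $\sigma$ is the softmax, $z$ the student logits, and $H(p, q) = -\sum_i p_i \log q_i$.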


Start training: train at a high temperature, test at a low temperature (T = 1)

saver = tf.train.Saver()
losses = []
accs = []
test_accs = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):
        x_shuffle, y_shuffle, soft_targets_shuffle \
            = shuffle(mnist.train.images, mnist.train.labels, soft_targets)
        for i in range(n_batches):
            start = i * batch_size
            end = start + batch_size
            batch_x, batch_y, batch_soft_targets \
                = x_shuffle[start:end], y_shuffle[start:end], soft_targets_shuffle[start:end]
            sess.run(train_step, feed_dict={
                x: batch_x, y_: batch_y, soft_target_: batch_soft_targets, keep_prob: 0.5, T: 2.0})
        train_loss = sess.run(loss, feed_dict={
            x: batch_x, y_: batch_y, soft_target_: batch_soft_targets, keep_prob: 0.5, T: 2.0})  # train at high temperature
        train_accuracy = sess.run(accuracy, feed_dict={
            x: batch_x, y_: batch_y, keep_prob: 1.0, T: 1.0})
        test_accuracy = sess.run(accuracy, feed_dict={
            x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0, T: 1.0})  # test at low temperature (T = 1)
        print("Epoch : %i, Loss : %f, Accuracy: %f, Test accuracy: %f" % (
            epoch+1, train_loss, train_accuracy, test_accuracy))
        saver.save(sess, "/root/userfolder/Experiment/tensorflow-distillation-examples/model_student/",
                   global_step=epoch+1)
        losses.append(train_loss)
        accs.append(train_accuracy)
        test_accs.append(test_accuracy)
    print("... completed!")

Output. As you can see, the student outdoes its teacher.

Epoch : 1, Loss : 7.137307, Accuracy: 0.860000, Test accuracy: 0.868200
Epoch : 2, Loss : 5.926404, Accuracy: 0.940000, Test accuracy: 0.892200
Epoch : 3, Loss : 5.597841, Accuracy: 0.920000, Test accuracy: 0.901400
Epoch : 4, Loss : 5.938632, Accuracy: 0.920000, Test accuracy: 0.913000
Epoch : 5, Loss : 5.872798, Accuracy: 0.920000, Test accuracy: 0.915800
Epoch : 6, Loss : 5.436497, Accuracy: 0.920000, Test accuracy: 0.919300
Epoch : 7, Loss : 5.455486, Accuracy: 0.880000, Test accuracy: 0.924100
Epoch : 8, Loss : 4.402141, Accuracy: 0.980000, Test accuracy: 0.927100
Epoch : 9, Loss : 5.413333, Accuracy: 0.960000, Test accuracy: 0.929700
Epoch : 10, Loss : 4.503023, Accuracy: 0.960000, Test accuracy: 0.931900
Epoch : 11, Loss : 4.971416, Accuracy: 0.960000, Test accuracy: 0.934800
Epoch : 12, Loss : 6.448879, Accuracy: 0.880000, Test accuracy: 0.937300
Epoch : 13, Loss : 6.164934, Accuracy: 0.920000, Test accuracy: 0.939000
Epoch : 14, Loss : 5.904130, Accuracy: 0.880000, Test accuracy: 0.940200
Epoch : 15, Loss : 5.206109, Accuracy: 0.940000, Test accuracy: 0.941200
Epoch : 16, Loss : 4.704682, Accuracy: 0.960000, Test accuracy: 0.942000
Epoch : 17, Loss : 4.707399, Accuracy: 0.940000, Test accuracy: 0.943000
Epoch : 18, Loss : 4.608377, Accuracy: 0.940000, Test accuracy: 0.944000
Epoch : 19, Loss : 6.394137, Accuracy: 0.900000, Test accuracy: 0.944600
Epoch : 20, Loss : 4.419221, Accuracy: 0.980000, Test accuracy: 0.944900
Epoch : 21, Loss : 4.322970, Accuracy: 0.960000, Test accuracy: 0.946800
Epoch : 22, Loss : 3.958002, Accuracy: 0.960000, Test accuracy: 0.946400
Epoch : 23, Loss : 4.949951, Accuracy: 0.960000, Test accuracy: 0.947600
Epoch : 24, Loss : 5.640293, Accuracy: 0.900000, Test accuracy: 0.947100
Epoch : 25, Loss : 4.615621, Accuracy: 0.940000, Test accuracy: 0.948300
Epoch : 26, Loss : 4.853579, Accuracy: 0.940000, Test accuracy: 0.948600
Epoch : 27, Loss : 4.839081, Accuracy: 0.960000, Test accuracy: 0.949700
Epoch : 28, Loss : 4.525964, Accuracy: 0.940000, Test accuracy: 0.950600
Epoch : 29, Loss : 5.636992, Accuracy: 0.940000, Test accuracy: 0.950700
Epoch : 30, Loss : 4.566214, Accuracy: 0.980000, Test accuracy: 0.951200
Epoch : 31, Loss : 4.846083, Accuracy: 0.960000, Test accuracy: 0.951300
Epoch : 32, Loss : 4.274162, Accuracy: 0.980000, Test accuracy: 0.951700
Epoch : 33, Loss : 4.423202, Accuracy: 0.960000, Test accuracy: 0.951800
Epoch : 34, Loss : 4.516046, Accuracy: 0.940000, Test accuracy: 0.952200
Epoch : 35, Loss : 3.987510, Accuracy: 0.940000, Test accuracy: 0.952900
Epoch : 36, Loss : 4.587525, Accuracy: 0.940000, Test accuracy: 0.953200
Epoch : 37, Loss : 4.149089, Accuracy: 0.960000, Test accuracy: 0.953300
Epoch : 38, Loss : 4.955534, Accuracy: 0.940000, Test accuracy: 0.953900
Epoch : 39, Loss : 5.080862, Accuracy: 0.960000, Test accuracy: 0.954700
Epoch : 40, Loss : 5.033619, Accuracy: 0.900000, Test accuracy: 0.954500
Epoch : 41, Loss : 5.110637, Accuracy: 0.940000, Test accuracy: 0.954100
Epoch : 42, Loss : 5.486012, Accuracy: 0.940000, Test accuracy: 0.954300
Epoch : 43, Loss : 4.117889, Accuracy: 0.980000, Test accuracy: 0.955800
Epoch : 44, Loss : 3.833005, Accuracy: 0.940000, Test accuracy: 0.955900
Epoch : 45, Loss : 4.636988, Accuracy: 0.960000, Test accuracy: 0.954500
Epoch : 46, Loss : 5.074997, Accuracy: 0.940000, Test accuracy: 0.955700
Epoch : 47, Loss : 4.291631, Accuracy: 0.960000, Test accuracy: 0.954800
Epoch : 48, Loss : 4.045475, Accuracy: 0.960000, Test accuracy: 0.956500
Epoch : 49, Loss : 4.960283, Accuracy: 0.920000, Test accuracy: 0.957400
Epoch : 50, Loss : 5.411842, Accuracy: 0.940000, Test accuracy: 0.956300
... completed!

Visualize the training loss

plt.title("Loss of student")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(range(1, len(losses)+1), losses, label='train_loss')
plt.legend()
plt.show()

Visualize the training and test accuracy

plt.title("Accuracy of teacher")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.plot(range(1, len(accs)+1), accs, label='Training')
plt.plot(range(1, len(test_accs)+1), test_accs, label='Test')
plt.legend()
plt.show()


Check the accuracy of one of the saved checkpoints

with tf.Session() as sess:
    saver.restore(sess, "./model_student/-49")
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

output

INFO:tensorflow:Restoring parameters from ./model_student/-49
0.9574

Save the accuracy and loss

np.save("loss_student.npy", np.array(losses))
np.save("acc_student.npy", np.array(accs))
np.save("acc_test_student.npy", np.array(test_accs))
