Distilling the Knowledge in a Neural Network: Paper Notes on Distillation
</div><link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/ck_htmledit_views-833878f763.css"><div id="content_views" class="markdown_views prism-github-gist"><!-- flowchart arrow icon, do not delete --><svg xmlns="http://www.w3.org/2000/svg" style="display: none;"><path stroke-linecap="round" d="M5,0 0,2.5 5,5z" id="raphael-marker-block" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path></svg><p><img src="https://img-blog.csdnimg.cn/20181119143732887.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2JyeWFudF9tZW5n,size_16,color_FFFFFF,t_70#pic_center" alt="Image description"></p>
arXiv-2015
In NIPS Deep Learning Workshop, 2014
Contents
- 1 Background and Motivation
- 2 Conceptual block
- 3 Knowledge Distilling
- 3.1 hard target
- 3.2 soft target
- 3.3 Difference between learning logits and learning softmax + T
- 3.4 Advantages of softmax + T over logits
- 3.5 Cost function
- 4 Dataset
- 5 Experiments
- 6 References
- 7 Appendix
- A. How softmax changes when a temperature is added
- B. Knowledge distilling (MNIST) code
- B.1 teacher network
- B.2 student network
These notes cover only the classification-related parts of "Distilling the Knowledge in a Neural Network"; for more related papers, see "Paper".
1 Background and Motivation
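A simple way to improve a model's performance is to
train many different models on the same data and then to average their predictions
Drawbacks
- Making predictions with an ensemble is too cumbersome
- It may be too computationally expensive to deploy to a large number of users, especially if the individual models are large neural networks
Caruana et al. showed that transferring an ensemble into a single model is feasible
(demonstrate convincingly that the knowledge acquired by a large ensemble of models can be transferred to a single small model)
The authors use knowledge distilling (a new compression method) to carry out this transfer from an ensemble to a single model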
2 Conceptual block
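1) One misconception concerns what a model's knowledge is: it is often identified with the trained parameter values. This narrow view held back this kind of knowledge transfer for a long time, because once the network architecture changes, the so-called knowledge (the parameters) can no longer be reused. The authors propose a more abstract view: knowledge is the learned mapping from input vectors to output vectors.
Seen this way, knowledge is no longer tied to a particular architecture, which makes it possible for a small network to learn from a large one!
2) The other misconception is that the training objective should match the true targets as closely as possible. Training optimizes performance on the training set, whereas what we actually care about is how well the model generalizes to new data. It would be ideal to train the model to generalize well directly, but that is rarely possible because information about the correct way to generalize is hard to obtain. During distillation, however, the generalization ability learned by the large model can be passed on to the small model quite naturally: because the large model generalizes well, a small model trained to imitate it typically generalizes better than the same small model trained from scratch.
How is the large model's generalization passed on to the small model? Through soft targets. The large network's softmax output (the usual softmax with a temperature added) is used as the label; this is the soft target, and the small network's softmax output is trained to approximate it. The corresponding hard target is the original label from the dataset. The advantages of soft targets over hard targets are listed in the slide above.
Why do soft targets carry information about how the model generalizes? My understanding is that, compared with hard targets, soft targets encode much more about the relationships between classes.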
3 Knowledge Distilling
3.1 hard target
First, consider how a hard target (plain softmax) is computed
A more intuitive illustration (from Zhihu)
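As a minimal sketch of the softmax computation at T = 1, the following turns a logit vector into probabilities and contrasts them with a one-hot hard target. The logits and the label index are made-up values for illustration, not taken from the post.

```python
import numpy as np

def softmax(logits):
    """Plain softmax (temperature T = 1) over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)  # avoid overflow
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

# Hypothetical logits for a 3-class problem (illustrative values only).
logits = np.array([3.0, 1.0, 0.2])
probs = softmax(logits)

# Hypothetical ground-truth class; the hard target keeps only this label.
true_class = 0
hard_target = np.eye(len(logits))[true_class]

print(probs)        # probabilities that sum to 1, with mass on every class
print(hard_target)  # one-hot vector, all mass on the labelled class
```

The softmax output still assigns some probability to the wrong classes, whereas the one-hot label discards that information entirely.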
3.2 soft target
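Now look at the effect of a soft target (softmax + T)
The horizontal axis is the temperature T and the vertical axis is the soft-target output $q_i$.
3.3 Difference between learning logits and learning softmax + T
Learning the logits directly means minimizing the squared difference between the small network's logits $z_i$ and the large network's logits $v_i$, whose gradient is
$$\frac{\partial}{\partial z_i}\,\frac{1}{2}(z_i - v_i)^2 = z_i - v_i$$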
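The relation between matching logits and distillation at temperature T can be written out as follows. This is a sketch in the paper's notation, where $z_i$ are the small network's logits, $v_i$ the large network's logits, $q_i$ and $p_i$ the corresponding softened probabilities, $C$ the cross-entropy between them, and $N$ the number of classes. It assumes that T is large compared with the logits and that the logits are zero-mean for each example.

```latex
\begin{aligned}
q_i &= \frac{e^{z_i/T}}{\sum_j e^{z_j/T}}, \qquad
p_i = \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}, \qquad
C = -\sum_i p_i \log q_i \\
\frac{\partial C}{\partial z_i} &= \frac{1}{T}\,(q_i - p_i) \\
&\approx \frac{1}{T}\left(\frac{1 + z_i/T}{N + \sum_j z_j/T}
      - \frac{1 + v_i/T}{N + \sum_j v_j/T}\right)
  \qquad (T \gg |z_i|,\ |v_i|) \\
&\approx \frac{1}{N T^2}\,(z_i - v_i)
  \qquad \Big(\textstyle\sum_j z_j = \sum_j v_j = 0\Big)
\end{aligned}
```

Up to the factor $1/(NT^2)$ this is the gradient of $\tfrac{1}{2}(z_i - v_i)^2$, so in the high-temperature limit distillation reduces to matching logits.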
3.4 softmax+T 相比 logits 的优势
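Since learning the logits is a special case of learning softmax + T, what advantages does learning softmax + T have by comparison?
The authors give the following summary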
- logits are almost completely unconstrained by the cost function used for training the cumbersome model so they could be very noisy
- very negative logits may convey useful information about the knowledge acquired by the cumbersome model
3.5 Cost function
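The small network's loss function is as follows
When learning generalization from the large network, train with a relatively high T (the larger T is, the less confident the outputs; if the network can still tell the classes apart under this lack of confidence, it performs better at test time with T = 1, much like training with extra weight). When learning from the true labels, use T = 1.
Combining the true labels with the soft targets, using a weighted sum of the two as the target, gives better results. The objective then becomes the expression below, where values of λ below 1 work well.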
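A minimal NumPy sketch of such a combined objective is given below. The original formula is not reproduced in the text, so the exact weighting is an assumption: here the hypothetical parameter `lam` (standing in for λ) weights the hard-label term and `1 - lam` weights the soft-target term, and the soft-target term is multiplied by T² as the paper recommends. All logits are made-up values.

```python
import numpy as np

def softmax_t(logits, temp=1.0):
    """Softmax with temperature over the last axis."""
    scaled = logits / temp
    scaled = scaled - np.max(scaled, axis=-1, keepdims=True)  # avoid overflow
    exp = np.exp(scaled)
    return exp / np.sum(exp, axis=-1, keepdims=True)

def distillation_loss(student_logits, teacher_logits, hard_labels, temp=2.0, lam=0.1):
    """Weighted sum of a soft-target and a hard-target cross-entropy.

    Assumption: `lam` weights the hard-label term, (1 - lam) the soft term.
    """
    soft_targets = softmax_t(teacher_logits, temp)   # teacher at temperature T
    student_soft = softmax_t(student_logits, temp)   # student at temperature T
    student_hard = softmax_t(student_logits, 1.0)    # student at T = 1
    eps = 1e-12                                      # guards against log(0)
    soft_ce = -np.sum(soft_targets * np.log(student_soft + eps), axis=-1)
    hard_ce = -np.sum(hard_labels * np.log(student_hard + eps), axis=-1)
    return np.mean((1.0 - lam) * temp ** 2 * soft_ce + lam * hard_ce)

# Hypothetical logits for a batch of two examples and three classes.
teacher = np.array([[5.0, 1.0, 0.5], [0.2, 4.0, 1.0]])
student = np.array([[2.0, 0.5, 0.1], [0.1, 1.5, 0.8]])
labels = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

loss = distillation_loss(student, teacher, labels, temp=2.0, lam=0.1)
print(loss)
```

A student whose logits equal the teacher's gets the smallest possible soft-target term, so the loss falls as the student's softened outputs approach the teacher's.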
4 Dataset
MNIST
5 Experiments
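Network architectures:
- Large network: 2 hidden layers of 1200 units each, 55,000 training examples, trained with dropout.
- Small network 1 (conventional): 2 hidden layers of 800 units each, no regularization, trained directly in the usual way.
- Small network 1 (soft): 2 hidden layers of 800 units each, no regularization, trained by distillation from the large network with T = 20.
Number of test errors:
- Large network: 67
- Small network 1 (conventional): 146
- Small network 1 (soft): 74
Generalization experiment
To study how well the small network generalizes, the authors removed all images of the digit 3 from the transfer set (the data used to train the small network; it can be smaller than the data used to train the large network, or even empty), so the small network never saw a 3 during training. Even so, once the bias for class 3 is adjusted, the small network classifies 98.6% of the test 3s correctly. Moreover, even when the transfer set contains only images of 7s and 8s, the small network's test error is only 13.2%, again after adjusting the biases. This indicates that the small network inherits generalization ability from the large network!
Q1: In Section 3 of the paper, how should the change of bias made when tuning the experiment be understood?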
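One way to read the bias adjustment, as a rough sketch: a student that never saw a 3 during training can end up with an output bias for class 3 that is too low, and adding a constant offset to that class's logit at test time moves borderline predictions toward 3 without retraining. The paper reports choosing the offset to optimize test performance. The logits and the offset below are made-up values, and `predict_with_bias_offset` is a name introduced here for illustration.

```python
import numpy as np

# Hypothetical student logits for two test images over the digits 0-9.
logits = np.array([
    [0.1, 0.0, 0.3, 2.0, 0.2, 2.4, 0.1, 0.0, 1.0, 0.2],  # a 3 narrowly scored as 5
    [0.0, 0.2, 0.1, 0.5, 0.3, 4.0, 0.1, 0.0, 0.6, 0.2],  # a clear 5
])

def predict_with_bias_offset(logits, class_index, offset):
    """Add a constant offset to one class's logit, then take the argmax."""
    adjusted = logits.copy()
    adjusted[:, class_index] += offset
    return np.argmax(adjusted, axis=1)

before = predict_with_bias_offset(logits, class_index=3, offset=0.0)
after = predict_with_bias_offset(logits, class_index=3, offset=1.0)
print(before)  # predictions without the offset
print(after)   # predictions with the class-3 bias raised
```

Only the borderline example changes its label; the confidently classified one is unaffected, which is why a single scalar per class can recover much of the accuracy.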
6 References
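[1] [Paper guide] Hinton - Distilling the Knowledge in a Neural Network
[2] A hand-worked example that walks step by step through the softmax function and its derivatives
[3] Knowledge distillation paper reading (1): Distilling the Knowledge in a Neural Network (with code reproduction)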
7 Appendix
A. How softmax changes when a temperature is added
import math
import numpy as np
import matplotlib.pyplot as plt
T = np.arange(1,20,1)
y1 = (math.e**(0.9/T)/(math.e**(0.07/T)+math.e**(0.9/T)+math.e**(0.03/T)))
y2 = (math.e**(0.07/T)/(math.e**(0.07/T)+math.e**(0.9/T)+math.e**(0.03/T)))
y3 = (math.e**(0.03/T)/(math.e**(0.07/T)+math.e**(0.9/T)+math.e**(0.03/T)))
plt.plot(T,y1)
plt.plot(T,y2)
plt.plot(T,y3)
plt.legend(["0.9", "0.07", "0.03"])  # legend
plt.grid()  # grid lines
#plt.savefig('1.png')
plt.show()
B. Knowledge distilling (MNIST) code
Code source (TensorFlow version):
akimach/tensorflow-distillation-examples
It can also be downloaded here: link: https://pan.baidu.com/s/1vDud4Iws_xnDxRqRnpyR-g extraction code: cemy
Background reading:
"Tensorflow | 莫烦" learning notes
[Keras-MLP] MNIST
[TensorFlow-MLP] MNIST
The MNIST training set has 60,000 examples; only 55,000 are used here because the other 5,000 are held out as validation data
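The arithmetic can be sketched as follows. The index array merely stands in for the real images, and holding out the first 5,000 examples is an assumption about how the split is made.

```python
import numpy as np

n_total = 60000          # size of the MNIST training set
validation_size = 5000   # examples held out for validation

indices = np.arange(n_total)                 # stand-in for the real examples
validation_idx = indices[:validation_size]   # assumed: first 5,000 held out
train_idx = indices[validation_size:]

batch_size = 50
n_batches = len(train_idx) // batch_size     # mini-batches per epoch

print(len(train_idx), len(validation_idx), n_batches)
```

With a batch size of 50 this gives the 1100 mini-batches per epoch used in the training loops of this appendix.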
B.1 teacher network
2 hidden layers of 1200 units each, 55,000 training examples, trained with dropout = 0.5
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
%matplotlib inline
random.seed(123)
np.random.seed(123)
tf.set_random_seed(123)
# load the dataset
mnist = input_data.read_data_sets("/root/userfolder/Experiment/MNIST_data/", one_hot=True)
Layer definitions
# hyper parameters
n_epochs = 50
batch_size = 50
num_nodes_h1 = 1200
num_nodes_h2 = 1200
learning_rate = 0.001
# number of batches
n_batches = len(mnist.train.images) // batch_size  # 55000 // 50 = 1100
# define W
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)
# define b
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)
# define softmax with temperature T
def softmax_with_temperature(logits, temp=1.0, axis=1, name=None):
    logits_with_temp = logits / temp
    _softmax = tf.exp(logits_with_temp) / tf.reduce_sum(tf.exp(logits_with_temp), axis=axis, keep_dims=True)
    return _softmax
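The function above exponentiates the scaled logits directly, which can overflow for large logits or small temperatures. A numerically safer variant, shown here in NumPy as a sketch rather than as part of the original code, subtracts the row-wise maximum first; the result is mathematically unchanged. Both function names below are introduced for illustration.

```python
import numpy as np

def naive_softmax_with_temperature(logits, temp=1.0, axis=1):
    """NumPy counterpart of the direct formula: exp(z/T) / sum(exp(z/T))."""
    scaled = logits / temp
    return np.exp(scaled) / np.sum(np.exp(scaled), axis=axis, keepdims=True)

def stable_softmax_with_temperature(logits, temp=1.0, axis=1):
    """Same quantity, computed after subtracting the max along `axis`."""
    scaled = logits / temp
    scaled = scaled - np.max(scaled, axis=axis, keepdims=True)
    exp = np.exp(scaled)
    return exp / np.sum(exp, axis=axis, keepdims=True)

# Illustrative logits; higher temperatures flatten the distribution.
logits = np.array([[2.0, 1.0, 0.1], [0.5, 0.5, 3.0]])
for temp in (1.0, 2.0, 20.0):
    a = naive_softmax_with_temperature(logits, temp)
    b = stable_softmax_with_temperature(logits, temp)
    print(temp, np.allclose(a, b), b[0])

# Very large logits stay finite with the stable version.
big = np.array([[1000.0, 999.0, 0.0]])
print(stable_softmax_with_temperature(big, temp=1.0))
```

For the small logits of an MNIST MLP the two versions agree, so the choice only matters when logits grow large.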
Network architecture
# data
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)# drop out
# input to hidden layer 1
W_h1 = weight_variable([784, num_nodes_h1])# 784,1200
b_h1 = bias_variable([num_nodes_h1])# 1200
h1 = tf.nn.relu(tf.matmul(x, W_h1) + b_h1) # relu(wx+b)
h1_drop = tf.nn.dropout(h1, keep_prob) # drop out
# hidden layer 1 to hidden layer 2
W_h2 = weight_variable([num_nodes_h1, num_nodes_h2])# 1200,1200
b_h2 = bias_variable([num_nodes_h2])# 1200
h2 = tf.nn.relu(tf.matmul(h1_drop, W_h2) + b_h2)# relu(wx+b)
h2_drop = tf.nn.dropout(h2, keep_prob) # drop out
# hidden layer 2 to output layer
W_output = tf.Variable(tf.zeros([num_nodes_h2, 10]))
b_output = tf.Variable(tf.zeros([10]))
logits = tf.matmul(h2_drop, W_output) + b_output
y = tf.nn.softmax(logits)  # prediction at T = 1
y_soft_target = softmax_with_temperature(logits, temp=2.0)  # soft target at T = 2
loss = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
Train with mini-batches, saving the model checkpoints and recording the training loss and the training and test accuracy
saver = tf.train.Saver()
losses = []
accs = []
test_accs = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):  # epochs
        x_shuffle, y_shuffle = shuffle(mnist.train.images, mnist.train.labels)
        for i in range(n_batches):  # batches
            start = i * batch_size
            end = start + batch_size
            batch_x, batch_y = x_shuffle[start:end], y_shuffle[start:end]
            sess.run(train_step, feed_dict={
                x: batch_x, y_: batch_y, keep_prob: 0.5})
        # training loss and accuracy are measured on the last mini-batch of the epoch
        train_loss = sess.run(loss, feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.5})
        train_accuracy = sess.run(accuracy, feed_dict={
            x: batch_x, y_: batch_y, keep_prob: 1.0})
        test_accuracy = sess.run(accuracy, feed_dict={
            x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
        print("Epoch : %i, train loss : %f, Accuracy: %f, Test accuracy: %f" % (
            epoch+1, train_loss, train_accuracy, test_accuracy))
        saver.save(sess, "/root/userfolder/Experiment/tensorflow-distillation-examples/model_teacher/",
                   global_step=epoch+1)  # only the most recent few epochs are kept
        losses.append(train_loss)
        accs.append(train_accuracy)
        test_accs.append(test_accuracy)
    print("... completed!")
output
Epoch : 1, train loss : 0.737658, Accuracy: 0.880000, Test accuracy: 0.870400
Epoch : 2, train loss : 0.761208, Accuracy: 0.900000, Test accuracy: 0.877700
Epoch : 3, train loss : 0.589437, Accuracy: 0.920000, Test accuracy: 0.890600
Epoch : 4, train loss : 0.643363, Accuracy: 0.900000, Test accuracy: 0.899900
Epoch : 5, train loss : 0.616038, Accuracy: 0.900000, Test accuracy: 0.900900
Epoch : 6, train loss : 0.611822, Accuracy: 0.860000, Test accuracy: 0.907100
Epoch : 7, train loss : 0.644078, Accuracy: 0.860000, Test accuracy: 0.909100
Epoch : 8, train loss : 0.402896, Accuracy: 0.960000, Test accuracy: 0.911100
Epoch : 9, train loss : 0.572901, Accuracy: 0.960000, Test accuracy: 0.907900
Epoch : 10, train loss : 0.517088, Accuracy: 0.900000, Test accuracy: 0.914600
Epoch : 11, train loss : 0.410240, Accuracy: 0.960000, Test accuracy: 0.914300
Epoch : 12, train loss : 0.945823, Accuracy: 0.800000, Test accuracy: 0.916200
Epoch : 13, train loss : 0.579927, Accuracy: 0.900000, Test accuracy: 0.917000
Epoch : 14, train loss : 0.503660, Accuracy: 0.860000, Test accuracy: 0.918300
Epoch : 15, train loss : 0.532867, Accuracy: 0.940000, Test accuracy: 0.918600
Epoch : 16, train loss : 0.430909, Accuracy: 0.940000, Test accuracy: 0.920300
Epoch : 17, train loss : 0.507866, Accuracy: 0.920000, Test accuracy: 0.920600
Epoch : 18, train loss : 0.453426, Accuracy: 0.920000, Test accuracy: 0.925200
Epoch : 19, train loss : 0.689311, Accuracy: 0.920000, Test accuracy: 0.926600
Epoch : 20, train loss : 0.379545, Accuracy: 0.940000, Test accuracy: 0.926100
Epoch : 21, train loss : 0.431786, Accuracy: 0.920000, Test accuracy: 0.926800
Epoch : 22, train loss : 0.401257, Accuracy: 0.960000, Test accuracy: 0.927300
Epoch : 23, train loss : 0.587902, Accuracy: 0.960000, Test accuracy: 0.928600
Epoch : 24, train loss : 0.620417, Accuracy: 0.880000, Test accuracy: 0.927400
Epoch : 25, train loss : 0.365211, Accuracy: 0.940000, Test accuracy: 0.929500
Epoch : 26, train loss : 0.427130, Accuracy: 0.960000, Test accuracy: 0.930300
Epoch : 27, train loss : 0.253452, Accuracy: 0.900000, Test accuracy: 0.930800
Epoch : 28, train loss : 0.427312, Accuracy: 0.920000, Test accuracy: 0.930900
Epoch : 29, train loss : 0.419188, Accuracy: 0.900000, Test accuracy: 0.933100
Epoch : 30, train loss : 0.268312, Accuracy: 0.940000, Test accuracy: 0.933800
Epoch : 31, train loss : 0.346375, Accuracy: 0.920000, Test accuracy: 0.933500
Epoch : 32, train loss : 0.292108, Accuracy: 0.960000, Test accuracy: 0.933000
Epoch : 33, train loss : 0.436444, Accuracy: 0.960000, Test accuracy: 0.935100
Epoch : 34, train loss : 0.278850, Accuracy: 0.940000, Test accuracy: 0.934900
Epoch : 35, train loss : 0.277737, Accuracy: 0.940000, Test accuracy: 0.937300
Epoch : 36, train loss : 0.425431, Accuracy: 0.940000, Test accuracy: 0.937300
Epoch : 37, train loss : 0.359413, Accuracy: 0.940000, Test accuracy: 0.937800
Epoch : 38, train loss : 0.338502, Accuracy: 0.960000, Test accuracy: 0.937600
Epoch : 39, train loss : 0.433313, Accuracy: 0.880000, Test accuracy: 0.937100
Epoch : 40, train loss : 0.529199, Accuracy: 0.860000, Test accuracy: 0.938700
Epoch : 41, train loss : 0.657401, Accuracy: 0.920000, Test accuracy: 0.938500
Epoch : 42, train loss : 0.491150, Accuracy: 0.920000, Test accuracy: 0.938600
Epoch : 43, train loss : 0.334091, Accuracy: 0.940000, Test accuracy: 0.940200
Epoch : 44, train loss : 0.298908, Accuracy: 0.940000, Test accuracy: 0.941000
Epoch : 45, train loss : 0.303939, Accuracy: 0.940000, Test accuracy: 0.939800
Epoch : 46, train loss : 0.378838, Accuracy: 0.940000, Test accuracy: 0.939500
Epoch : 47, train loss : 0.323622, Accuracy: 0.920000, Test accuracy: 0.941700
Epoch : 48, train loss : 0.280403, Accuracy: 0.940000, Test accuracy: 0.943000
Epoch : 49, train loss : 0.390651, Accuracy: 0.920000, Test accuracy: 0.942800
Epoch : 50, train loss : 0.614632, Accuracy: 0.900000, Test accuracy: 0.941700
... completed!
Visualize the training loss
# plot how the training loss changes
plt.title("Loss of teacher")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(range(1, len(losses)+1), losses, label='train_loss')
plt.legend()
plt.show()
Visualize the training and test accuracy
# plot how the training and test accuracy change
plt.title("Accuracy of teacher")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.plot(range(1, len(accs)+1), accs, label='Training')
plt.plot(range(1, len(test_accs)+1), test_accs, label='Test')
plt.legend()
plt.show()
Save the training loss and the training and test accuracy
# save the training loss, training accuracy and test accuracy
np.save("loss_teacher.npy", np.array(losses))
np.save("acc_train_teacher.npy", np.array(accs))
np.save("acc_test_teacher.npy", np.array(test_accs))
Save the teacher network's soft targets. We pick an epoch that performed well; below, the checkpoint from epoch 48 is used
# save the soft targets from epoch 48
_soft_targets = []
with tf.Session() as sess:
    saver.restore(sess, "./model_teacher/-48")
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        batch_x = mnist.train.images[start:end]
        soft_target = sess.run(y_soft_target, feed_dict={x: batch_x, keep_prob: 1.0})
        _soft_targets.append(soft_target)
Check the shape of _soft_targets and reshape it
np.shape(_soft_targets)  # (1100, 50, 10) = (n_batches, batch_size, classes)
soft_targets = np.c_[_soft_targets].reshape(55000, 10)  # reshape to (55000, 10)
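A small NumPy check, with toy sizes standing in for the real (1100, 50, 10) array, suggests why this reshape is safe: stacking equally sized batches and reshaping to two dimensions keeps the rows in their original order, which gives the same result as concatenating the batches. The sizes and values below are illustrative only.

```python
import numpy as np

n_batches, batch_size, n_classes = 4, 3, 2  # toy sizes for illustration

# Build a list of per-batch arrays, mimicking a loop that appends one
# (batch_size, n_classes) array per mini-batch.
batches = [
    np.arange(i * batch_size * n_classes, (i + 1) * batch_size * n_classes)
      .reshape(batch_size, n_classes)
    for i in range(n_batches)
]

stacked = np.array(batches)                      # (n_batches, batch_size, n_classes)
flat = stacked.reshape(n_batches * batch_size, n_classes)
concatenated = np.concatenate(batches, axis=0)   # same rows, no 3-D intermediate

print(stacked.shape, flat.shape)
print(np.array_equal(flat, concatenated))
```

Because the rows stay in order, row k of the flattened array still corresponds to training example k, which is what lets the soft targets be shuffled together with the images and labels later.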
Compare the soft targets with the hard targets
print(soft_targets[:2])
print(mnist.train.labels[:2])  # labels, for comparison with the softened predictions above
output
[[5.2621812e-03 6.1693429e-03 1.5207376e-01 6.1155759e-02 1.4845385e-02
  4.8464271e-03 3.6828788e-03 6.0641229e-01 2.9818511e-02 1.1573344e-01]
 [2.4089564e-03 2.6752956e-03 1.8253580e-02 8.5861373e-01 3.0618338e-04
  1.7423177e-02 9.3506598e-05 3.6187540e-03 8.3464541e-02 1.3142269e-02]]
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]
Save the teacher network's soft targets so that the student network can learn from them
np.save('soft-targets.npy', soft_targets)
Check the shape
np.load(file="soft-targets.npy").shape
output
(55000, 10)
B.2 student network
The differences from the teacher network are the hidden layer size (1200 to 600; the paper uses 800) and the loss; everything else is the same
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
%matplotlib inline
random.seed(123)
np.random.seed(123)
tf.set_random_seed(123)
mnist = input_data.read_data_sets("/root/userfolder/Experiment/MNIST_data/", one_hot=True)
Load the teacher network's soft targets
soft_targets = np.load(file="soft-targets.npy")
print(np.shape(soft_targets))
output
(55000, 10)
Hyperparameter settings and the definitions of W, b and the soft target
n_epochs = 50
batch_size = 50
num_nodes_h1 = 600 # Before 800
num_nodes_h2 = 600 # Before 800
learning_rate = 0.001
n_batches = len(mnist.train.images) // batch_size
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)
def softmax_with_temperature(logits, temp=1.0, axis=1, name=None):
    logits_with_temp = logits / temp
    _softmax = tf.exp(logits_with_temp) / tf.reduce_sum(tf.exp(logits_with_temp),
                                                        axis=axis, keep_dims=True)
    return _softmax
Network design
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
soft_target_ = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
T = tf.placeholder(tf.float32)
W_h1 = weight_variable([784, num_nodes_h1])
b_h1 = bias_variable([num_nodes_h1])
h1 = tf.nn.relu(tf.matmul(x, W_h1) + b_h1)
h1_drop = tf.nn.dropout(h1, keep_prob)
W_h2 = weight_variable([num_nodes_h1, num_nodes_h2])
b_h2 = bias_variable([num_nodes_h2])
h2 = tf.nn.relu(tf.matmul(h1_drop, W_h2) + b_h2)
h2_drop = tf.nn.dropout(h2, keep_prob)  # dropout is still used
W_output = tf.Variable(tf.zeros([num_nodes_h2, 10]))
b_output = tf.Variable(tf.zeros([10]))
logits = tf.matmul(h2_drop, W_output) + b_output
y = tf.nn.softmax(logits)
y_soft_target = softmax_with_temperature(logits, temp=T)
loss_hard_target = -tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])
loss_soft_target = -tf.reduce_sum(soft_target_ * tf.log(y_soft_target),
                                  reduction_indices=[1])
# both terms are scaled by T^2 here; the paper scales only the soft-target term
loss = tf.reduce_mean(tf.square(T) * loss_hard_target + tf.square(T) * loss_soft_target)
train_step = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
Start training: train at high temperature, test at low temperature
saver = tf.train.Saver()
losses = []
accs = []
test_accs = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):
        x_shuffle, y_shuffle, soft_targets_shuffle \
            = shuffle(mnist.train.images, mnist.train.labels, soft_targets)
        for i in range(n_batches):
            start = i * batch_size
            end = start + batch_size
            batch_x, batch_y, batch_soft_targets \
                = x_shuffle[start:end], y_shuffle[start:end], soft_targets_shuffle[start:end]
            sess.run(train_step, feed_dict={
                x: batch_x, y_: batch_y, soft_target_: batch_soft_targets, keep_prob: 0.5, T: 2.0})
        train_loss = sess.run(loss, feed_dict={
            x: batch_x, y_: batch_y, soft_target_: batch_soft_targets, keep_prob: 0.5, T: 2.0})  # high-temperature training
        train_accuracy = sess.run(accuracy, feed_dict={
            x: batch_x, y_: batch_y, keep_prob: 1.0, T: 1.0})
        test_accuracy = sess.run(accuracy, feed_dict={
            x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0, T: 1.0})  # low-temperature testing
        print("Epoch : %i, Loss : %f, Accuracy: %f, Test accuracy: %f" % (
            epoch+1, train_loss, train_accuracy, test_accuracy))
        saver.save(sess, "/root/userfolder/Experiment/tensorflow-distillation-examples/model_student/",
                   global_step=epoch+1)
        losses.append(train_loss)
        accs.append(train_accuracy)
        test_accs.append(test_accuracy)
    print("... completed!")
output: as the numbers show, the student ends up surpassing the teacher
Epoch : 1, Loss : 7.137307, Accuracy: 0.860000, Test accuracy: 0.868200
Epoch : 2, Loss : 5.926404, Accuracy: 0.940000, Test accuracy: 0.892200
Epoch : 3, Loss : 5.597841, Accuracy: 0.920000, Test accuracy: 0.901400
Epoch : 4, Loss : 5.938632, Accuracy: 0.920000, Test accuracy: 0.913000
Epoch : 5, Loss : 5.872798, Accuracy: 0.920000, Test accuracy: 0.915800
Epoch : 6, Loss : 5.436497, Accuracy: 0.920000, Test accuracy: 0.919300
Epoch : 7, Loss : 5.455486, Accuracy: 0.880000, Test accuracy: 0.924100
Epoch : 8, Loss : 4.402141, Accuracy: 0.980000, Test accuracy: 0.927100
Epoch : 9, Loss : 5.413333, Accuracy: 0.960000, Test accuracy: 0.929700
Epoch : 10, Loss : 4.503023, Accuracy: 0.960000, Test accuracy: 0.931900
Epoch : 11, Loss : 4.971416, Accuracy: 0.960000, Test accuracy: 0.934800
Epoch : 12, Loss : 6.448879, Accuracy: 0.880000, Test accuracy: 0.937300
Epoch : 13, Loss : 6.164934, Accuracy: 0.920000, Test accuracy: 0.939000
Epoch : 14, Loss : 5.904130, Accuracy: 0.880000, Test accuracy: 0.940200
Epoch : 15, Loss : 5.206109, Accuracy: 0.940000, Test accuracy: 0.941200
Epoch : 16, Loss : 4.704682, Accuracy: 0.960000, Test accuracy: 0.942000
Epoch : 17, Loss : 4.707399, Accuracy: 0.940000, Test accuracy: 0.943000
Epoch : 18, Loss : 4.608377, Accuracy: 0.940000, Test accuracy: 0.944000
Epoch : 19, Loss : 6.394137, Accuracy: 0.900000, Test accuracy: 0.944600
Epoch : 20, Loss : 4.419221, Accuracy: 0.980000, Test accuracy: 0.944900
Epoch : 21, Loss : 4.322970, Accuracy: 0.960000, Test accuracy: 0.946800
Epoch : 22, Loss : 3.958002, Accuracy: 0.960000, Test accuracy: 0.946400
Epoch : 23, Loss : 4.949951, Accuracy: 0.960000, Test accuracy: 0.947600
Epoch : 24, Loss : 5.640293, Accuracy: 0.900000, Test accuracy: 0.947100
Epoch : 25, Loss : 4.615621, Accuracy: 0.940000, Test accuracy: 0.948300
Epoch : 26, Loss : 4.853579, Accuracy: 0.940000, Test accuracy: 0.948600
Epoch : 27, Loss : 4.839081, Accuracy: 0.960000, Test accuracy: 0.949700
Epoch : 28, Loss : 4.525964, Accuracy: 0.940000, Test accuracy: 0.950600
Epoch : 29, Loss : 5.636992, Accuracy: 0.940000, Test accuracy: 0.950700
Epoch : 30, Loss : 4.566214, Accuracy: 0.980000, Test accuracy: 0.951200
Epoch : 31, Loss : 4.846083, Accuracy: 0.960000, Test accuracy: 0.951300
Epoch : 32, Loss : 4.274162, Accuracy: 0.980000, Test accuracy: 0.951700
Epoch : 33, Loss : 4.423202, Accuracy: 0.960000, Test accuracy: 0.951800
Epoch : 34, Loss : 4.516046, Accuracy: 0.940000, Test accuracy: 0.952200
Epoch : 35, Loss : 3.987510, Accuracy: 0.940000, Test accuracy: 0.952900
Epoch : 36, Loss : 4.587525, Accuracy: 0.940000, Test accuracy: 0.953200
Epoch : 37, Loss : 4.149089, Accuracy: 0.960000, Test accuracy: 0.953300
Epoch : 38, Loss : 4.955534, Accuracy: 0.940000, Test accuracy: 0.953900
Epoch : 39, Loss : 5.080862, Accuracy: 0.960000, Test accuracy: 0.954700
Epoch : 40, Loss : 5.033619, Accuracy: 0.900000, Test accuracy: 0.954500
Epoch : 41, Loss : 5.110637, Accuracy: 0.940000, Test accuracy: 0.954100
Epoch : 42, Loss : 5.486012, Accuracy: 0.940000, Test accuracy: 0.954300
Epoch : 43, Loss : 4.117889, Accuracy: 0.980000, Test accuracy: 0.955800
Epoch : 44, Loss : 3.833005, Accuracy: 0.940000, Test accuracy: 0.955900
Epoch : 45, Loss : 4.636988, Accuracy: 0.960000, Test accuracy: 0.954500
Epoch : 46, Loss : 5.074997, Accuracy: 0.940000, Test accuracy: 0.955700
Epoch : 47, Loss : 4.291631, Accuracy: 0.960000, Test accuracy: 0.954800
Epoch : 48, Loss : 4.045475, Accuracy: 0.960000, Test accuracy: 0.956500
Epoch : 49, Loss : 4.960283, Accuracy: 0.920000, Test accuracy: 0.957400
Epoch : 50, Loss : 5.411842, Accuracy: 0.940000, Test accuracy: 0.956300
... completed!
Visualize the training loss
plt.title("Loss of student")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(range(1, len(losses)+1), losses, label='train_loss')
plt.legend()
plt.show()
Visualize the training and test accuracy
plt.title("Accuracy of student")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.plot(range(1, len(accs)+1), accs, label='Training')
plt.plot(range(1, len(test_accs)+1), test_accs, label='Test')
plt.legend()
plt.show()
Check the accuracy of one saved model
with tf.Session() as sess:
    saver.restore(sess, "./model_student/-49")
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
output
INFO:tensorflow:Restoring parameters from ./model_student/-49
0.9574
Save the accuracy and the loss
np.save("loss_student.npy", np.array(losses))
np.save("acc_student.npy", np.array(accs))
np.save("acc_test_student.npy", np.array(test_accs))
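With both sets of curves saved, the teacher and the student can be compared directly. The helper below is a sketch: `best_epoch` is a name introduced here, and the short arrays are placeholder values, not the real results. In practice one would pass in `np.load("acc_test_teacher.npy")` and `np.load("acc_test_student.npy")`.

```python
import numpy as np

def best_epoch(test_accs):
    """Return (1-based epoch, accuracy) of the best test accuracy in a run."""
    test_accs = np.asarray(test_accs)
    idx = int(np.argmax(test_accs))
    return idx + 1, float(test_accs[idx])

# Placeholder curves for illustration only.
teacher_accs = [0.87, 0.90, 0.93, 0.92]
student_accs = [0.86, 0.91, 0.94, 0.95]

t_epoch, t_acc = best_epoch(teacher_accs)
s_epoch, s_acc = best_epoch(student_accs)
print("teacher best: epoch %d, acc %.4f" % (t_epoch, t_acc))
print("student best: epoch %d, acc %.4f" % (s_epoch, s_acc))
print("student - teacher: %.4f" % (s_acc - t_acc))
```

Picking the best epoch this way is also how one would choose which checkpoint to restore, keeping in mind that only the most recent checkpoints remain on disk.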
</div><link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-b6c3c6d139.css" rel="stylesheet"></div>