Network Compression ——Knowledge Distillation

  • 前言
  • 一、knowledge distillation是什么?
    • 1.原理
    • 2.KL散度
    • 3.Readme
  • 二、网络模型代码
    • 1.加载数据集,定义环境
    • 2.定义KL散度
    • 3.数据处理
    • 4.预处理
    • 5.开始训练

前言

知识蒸馏,实质上就是用训练好的网络告诉没训练的网络如何学习。

一、knowledge distillation是什么?

1.原理

知识蒸馏(暗知识提取)的概念,即通过引入与教师网络(teacher network:复杂、但推理性能优越)相关的软目标(soft-target)作为total loss的一部分,以诱导学生网络(student network:精简、低复杂度)的训练,实现知识迁移(knowledge transfer)。

源自
[https://blog.csdn.net/nature553863/article/details/80568658]

本博客不注重理论,以代码为主。
白话,利用二者网络参数的损失帮助学生网络的学习。其目的还是简化模型,简单的网络模型通过学习,达到复杂的网络模型的能力,自然在这一具体问题应用上可以取代复杂的模型。青出于蓝而胜于蓝吧。最终目的还是节省时间和节省算力。

2.KL散度


用教师网络的参数当作训练学生网络的标准。
為甚麼這會work?

  • 例如當data不是很乾淨的時候,對一般的model來說他是個noise,只會干擾學習。透過去學習其他大model預測的logits會比較好。
  • label和label之間可能有關連,這可以引導小model去學習。例如數字8可能就和6,9,0有關係。
  • 弱化已經學習不錯的target(?),避免讓其gradient干擾其他還沒學好的task。

KL散度公式
主要是二者网络参数的损失用到了KL散度公式
相对熵(relative entropy)又称为KL散度(Kullback-Leibler divergence),KL距离,是两个随机分布间距离的度量。记为DKL(p||q)。它度量当真实分布为p时,假设分布q的无效性。

  • Loss=αT2×KL(Teacher’s LogitsT∣∣Student’s LogitsT)+(1−α)(Original Loss)Loss = \alpha T^2 \times KL(\frac{\text{Teacher's Logits}}{T} || \frac{\text{Student's Logits}}{T}) + (1-\alpha)(\text{Original Loss})Loss=αT2×KL(TTeacher’s Logits​∣∣TStudent’s Logits​)+(1−α)(Original Loss)

PyTorch的KL散度损失(KLDivLos)的定义/文件要求输入是概率分布和对数概率分布,这就是为什么后面我们在老师/学生输出(原始分数)上使用softmax和log-softmax。

3.Readme

在這個notebook中我們會介紹Knowledge Distillation,
而我們有提供已經學習好的大model方便大家做Knowledge Distillation。
而我們使用的小model是"Architecture Design"過的model。

  • Architecute Design在同目錄中的hw7_Architecture_Design.ipynb。
  • 下載pretrained大model(47.2M):
  • [https://drive.google.com/file/d/1B8ljdrxYXJsZv2vmTequdPOofp3VF3NN/view?usp=sharing]
    • 請使用torchvision提供的ResNet18,把num_classes改成11後load進去即可。(後面有範例。)

二、网络模型代码

1.加载数据集,定义环境

# Download dataset
!gdown --id '19CzXudqN58R3D-1G8KeFWk8UDQwlb8is' --output food-11.zip
# Unzip the files
!unzip food-11.zip

结果

import torch
import os
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
# Load進我們的Model架構(在hw7_Architecture_Design.ipynb內)
!gdown --id '1lJS0ApIyi7eZ2b3GMyGxjPShI8jXM2UC' --output "hw7_Architecture_Design.ipynb"
%run "hw7_Architecture_Design.ipynb"

结果
Downloading…
From: https://drive.google.com/uc?id=1lJS0ApIyi7eZ2b3GMyGxjPShI8jXM2UC
To: /content/hw7_Architecture_Design.ipynb
100% 8.78k/8.78k [00:00<00:00, 8.13MB/s]

2.定义KL散度

这部分对着公式敲出来就行。

def loss_fn_kd(outputs, labels, teacher_outputs, T=20, alpha=0.5):# 一般的Cross Entropyhard_loss = F.cross_entropy(outputs, labels) * (1. - alpha)# 讓logits的log_softmax對目標機率(teacher的logits/T後softmax)做KL Divergence。soft_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs/T, dim=1),F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T)return hard_loss + soft_loss

3.数据处理

这部分几乎所有代码在原有基础上我都注释过了,可以说很清楚每一步在干什么。

import re   #  使用python的re模块,尽管不能满足所有复杂的匹配情况,
#但足够在绝大多数情况下能够有效地实现对复杂字符串的分析并提取出相关信息。
#python 会将正则表达式转化为字节码,利用 C 语言的匹配引擎进行深度优先的匹配。
import torch
from glob import glob
#glob模块用来查找文件目录和文件
#glob支持 * ? [ ] 三种通配符。
#1) * 代表0个或多个字符
#2) ? 代表一个字符
#3) [ ]匹配指定范围内的字符,如[0-9]匹配数字
#1、import glob #导入整个glob模块
#2、from glob import golb #从glob模块导入glob函数
#glob.glob()可同时获取所有的匹配路径,而glob.iglob()一次只能获取一个匹配路径。from PIL import Image
import torchvision.transforms as transformsclass MyDataset(torch.utils.data.Dataset):def __init__(self, folderName, transform=None):self.transform = transformself.data = []self.label = []for img_path in sorted(glob(folderName + '/*.jpg')):#寻找照片路径try:# Get classIdx by parsing image path 通过解析图像路径获取classIdxclass_idx = int(re.findall(re.compile(r'\d+'), img_path)[1])#compile(pattern):创建模式对象#findall(pattern,string):列表形式返回匹配项except:# if inference mode (there's no answer), class_idx default 0# 如果推理模式(没有找到相应的照片),class_idx默认为0class_idx = 0image = Image.open(img_path)# Get File Descriptor获取文件描述image_fp = image.fpimage.load()# Close File Descriptor (or it'll reach OPEN_MAX)#如果推理模式(没有答案),class_idx默认为0image_fp.close()self.data.append(image)#数据扩增图片和索引(标签)self.label.append(class_idx)def __len__(self):#定义数据长度return len(self.data)def __getitem__(self, idx):if torch.is_tensor(idx):idx = idx.tolist()#tolist()作用:根据条件获取元素所在的位置(索引)image = self.data[idx]if self.transform:image = self.transform(image)return image, self.label[idx]#torchvision.transforms是pytorch中的图像预处理包。一般用Compose把多个步骤整合到一起
trainTransform = transforms.Compose([#transforms.RandomCrop (size, padding= None , pad_if_needed= False , fill= 0 , padding_mode= 'constant')#功能:从图片中随机裁剪出尺寸为size的图片 #size:所需裁剪图片尺寸#padding:设置填充大小 ;当为a时,上下左右均填充a个像素 ;当为(a, b)时,上下填充b个像素,左右填充a个像素 ;当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d#pad_if_need:若图像小于设定size,则填充#padding_mode:填充模式,有4种模式 #——1、constant:像素值由fill设定#——2、edge:像素值由图像边缘像素决定#——3、reflect:镜像填充,最后一个像素不镜像,eg:[1,2,3,4] → [3,2,1,2,3,4,3,2] #——4、symmetric:镜像填充,最后一个像素镜像,eg:[1,2,3,4] → [2,1,1,2,3,4,4,3]#fill:constant时,设置填充的像素值transforms.RandomCrop(256, pad_if_needed=True, padding_mode='symmetric'),transforms.RandomHorizontalFlip(),#水平翻转transforms.RandomRotation(15),#15°旋转transforms.ToTensor(),#转成张量
])
testTransform = transforms.Compose([#transforms.CenterCrop(size),#在图片的中间区域进行裁剪,size:所需裁剪图片尺寸transforms.CenterCrop(256),transforms.ToTensor(),
])def get_dataloader(mode='training', batch_size=32):assert mode in ['training', 'testing', 'validation']dataset = MyDataset(f'./food-11/{mode}',transform=trainTransform if mode == 'training' else testTransform)#封装数据集,包括打乱,批大小dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=(mode == 'training'))return dataloader

4.预处理

架构和自动处理方法
模型架构是torchvision提供的ResNet18。
我們已經提供TeacherNet的state_dict,
至於StudentNet的架構則在hw7_Architecture_Design.ipynb中。
這裡我們使用的Optimizer為AdamW,沒有為甚麼,就純粹我想用。

# get dataloader
train_dataloader = get_dataloader('training', batch_size=32)
valid_dataloader = get_dataloader('validation', batch_size=32)
!gdown --id '1B8ljdrxYXJsZv2vmTequdPOofp3VF3NN' --output teacher_resnet18.bin#使用网络时,不预处理,设置输出类数目
teacher_net = models.resnet18(pretrained=False, num_classes=11).cuda()
student_net = StudentNet(base=16).cuda()teacher_net.load_state_dict(torch.load(f'./teacher_resnet18.bin'))
optimizer = optim.AdamW(student_net.parameters(), lr=1e-3)

5.开始训练

model.eval()與model.train()差在於Batchnorm要不要紀錄,以及要不要做Dropout。节省时间和内存。
因为 train需要更新梯度,因还在学习过程中。eval是在验证,学习结束,梯度已经固定。当然教师网络肯定用eval模式。

def run_epoch(dataloader, update=True, alpha=0.5):total_num, total_hit, total_loss = 0, 0, 0for now_step, batch_data in enumerate(dataloader):# 清空 optimizeroptimizer.zero_grad()# 處理 inputinputs, hard_labels = batch_datainputs = inputs.cuda()hard_labels = torch.LongTensor(hard_labels).cuda()# 因為Teacher沒有要backprop,所以我們使用torch.no_grad# 告訴torch不要暫存中間值(去做backprop)以浪費記憶體空間。with torch.no_grad():soft_labels = teacher_net(inputs)if update:logits = student_net(inputs)# 使用我們之前所寫的融合soft label&hard label的loss。# T=20是原始論文的參數設定。loss = loss_fn_kd(logits, hard_labels, soft_labels, 20, alpha)loss.backward()optimizer.step()    else:# 只是算validation acc的話,就開no_grad節省空間。with torch.no_grad():logits = student_net(inputs)loss = loss_fn_kd(logits, hard_labels, soft_labels, 20, alpha)#找到logits中的最大值看是否与困难数据集的标签是否相同,然后求和赋给hit    total_hit += torch.sum(torch.argmax(logits, dim=1) == hard_labels).item()total_num += len(inputs)total_loss += loss.item() * len(inputs)return total_loss / total_num, total_hit / total_num# TeacherNet永遠都是Eval mode.毕竟不改变教师网络的参数
teacher_net.eval()
now_best_acc = 0
for epoch in range(200):student_net.train()train_loss, train_acc = run_epoch(train_dataloader, update=True)student_net.eval()valid_loss, valid_acc = run_epoch(valid_dataloader, update=False)# 存下最好的model。if valid_acc > now_best_acc:now_best_acc = valid_acctorch.save(student_net.state_dict(), 'student_model.bin')print('epoch {:>3d}: train loss: {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_loss, train_acc, valid_loss, valid_acc))

李宏毅作业七其二 Network Compression (Knowledge Distillation)相关推荐

  1. 李宏毅作业七其三 Network Compression (Network Pruning)

    Network Compression --Network Pruning 前言 一.Network Pruning是什么? Weight & Neuron Pruning 对于修剪网络并不简 ...

  2. 【李宏毅2020 ML/DL】P51 Network Compression - Knowledge Distillation | 知识蒸馏两大流派

    我已经有两年 ML 经历,这系列课主要用来查缺补漏,会记录一些细节的.自己不知道的东西. 已经有人记了笔记(很用心,强烈推荐):https://github.com/Sakura-gh/ML-note ...

  3. 李宏毅作业十二 Transfer Learning(迁移学习)

    系列文章目录 李宏毅作业十 Generative Adversarial Network生成对抗网络(代码) 李宏毅作业九 Anomaly Detection异常检测 李宏毅作业八unsupervis ...

  4. Pytorch实战_神经网络的压缩(Network Compression)

    1. 神经网络的压缩 对于一些大型的神经网络,它的网络结构是十分复杂的(听说华为的一些神经网络有上亿的神经元组成),我们很难在很小的设备中(比如我们的apple watch)上面将这个这个神经网络放上 ...

  5. Knowledge Distillation | 知识蒸馏经典解读

    作者 | 小小 整理 | NewBeeNLP 写在前面 知识蒸馏是一种模型压缩方法,是一种基于"教师-学生网络思想"的训练方法,由于其简单,有效,在工业界被广泛应用.这一技术的理论 ...

  6. 【论文翻译】Few Sample Knowledge Distillation for Efficient Network Compression

    Few Sample Knowledge Distillation for Efficient Network Compression 用于高效网络压缩的少样本知识提取 论文地址:https://ar ...

  7. 【李宏毅2020 ML/DL】P45-50 Network Compression

    我已经有两年 ML 经历,这系列课主要用来查缺补漏,会记录一些细节的.自己不知道的东西. 已经有人记了笔记(很用心,强烈推荐):https://github.com/Sakura-gh/ML-note ...

  8. 机器学习笔记 network compression

    来自于李宏毅教授的ML课件,作业七部分 Hung-yi Lee (ntu.edu.tw) 0 前言 我们为什么要进行network compression呢? 因为在某些环境中(比如在手机,手环等设备 ...

  9. 【论文翻译】Highlight Every Step: Knowledge Distillation via Collaborative Teaching

    Highlight Every Step: Knowledge Distillation via Collaborative Teaching 强调每一步:通过协作教学提炼知识 摘要 High sto ...

最新文章

  1. 如何用最强模型BERT做NLP迁移学习?
  2. 云炬随笔20211016(2)
  3. *【POJ - 3659】Cell Phone Network (树形dp,最小支配集)
  4. 双继承_在Python中使用双下划线防止类属性被覆盖!
  5. EJB3.0学习笔记---MDBbean--区分P2P模式和Pub/Sub模式的示例
  6. Spring-MVC的配置文件及路径问题
  7. 前端架构 IMVC 实时热更新模式
  8. 【Tensorflow、Keras】关于Reshape层小结(部分问题未解决)
  9. 文件和参数一起上传_基于netty的文件上传下载组件
  10. 176.第二高的薪水
  11. matlab安装软件 Matlab视频教程李大勇 MATLAB程序开发入门课程 MATLAB神经网络30个案例分析及源程序
  12. 首发丨极课大数据完成1亿元B轮融资,用AI提升学生作业、考试管理效率
  13. 和画意思相近的字_有没有类似“鸢语慕君年青筏画卿颜”这种古风情侣网名啊...
  14. 将pyecharts生成的html转为图片的一些问题
  15. 在线生鲜订购配送系统,生鲜订购系统 生鲜配送系统 前端+后台 Android源码+SSH后台管理系统+MySQL数据库
  16. 什么是机器人的外部轴?
  17. python实现百万英雄答题神器
  18. 【数量称谓】祖宗十八代
  19. $refs 模拟点击
  20. 网络知识点之-关于web

热门文章

  1. Number Sequence/数字序列
  2. 监控摄像头清晰度(分辨率)介绍
  3. Product-based Neural Networks (PNN) - 改进特征交叉的方式
  4. Exception的处理
  5. php获取蓝凑云文件列表,蓝奏云网盘登录获取文件例程
  6. ansible register 之用法
  7. Linux服务.NO7——samba
  8. As Error:Failed to find configured root that contains /storage/emulated/0/xxx/xxx/xxx.png
  9. 三分钟入门大数据之用户画像标签的分类
  10. linux 中 查看防火墙开放端口号 命令