Learning Without Forgetting

  • LWF Overview
  • Method Comparison
  • LWF Algorithm
  • Summary
  • Implementation

LWF Overview

LWF is a classic continual-learning method that uses knowledge distillation (KD) to avoid catastrophic forgetting. In essence, the outputs recorded from the old network guide the parameters of the network being trained on the new task, balancing them so that the resulting network performs well on both the old and the new tasks.

Method Comparison


a. Training from scratch
b. Fine-tuning: learn the new task on top of the old-task network with a small learning rate (in a sense, another kind of initialization?)
c. Joint training: train on the data of all tasks together
d. Feature extraction: freeze the old-task parameters as a feature extractor and train newly added layers for the new task (strategies b and d are sketched in code right after this list)
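
A minimal PyTorch sketch of strategies (b) and (d); the toy network, layer sizes, and learning rates are illustrative assumptions, not taken from the paper:

import torch.nn as nn
import torch.optim as optim

# a toy "old-task" network: a small backbone plus a classifier head
def make_net(num_classes):
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 256), nn.ReLU(),   # backbone
        nn.Linear(256, num_classes),          # classifier head
    )

# (b) fine-tuning: start from the old-task weights and train everything with a small learning rate
finetune_net = make_net(num_classes=2)        # stands in for the pretrained old-task network
finetune_opt = optim.SGD(finetune_net.parameters(), lr=1e-3, momentum=0.9)

# (d) feature extraction: freeze the old-task parameters and train only a freshly added head
feat_net = make_net(num_classes=2)
for p in feat_net[:-1].parameters():          # freeze everything except the last layer
    p.requires_grad = False
feat_net[-1] = nn.Linear(256, 2)              # new head for a 2-class new task
feat_opt = optim.SGD(feat_net[-1].parameters(), lr=1e-2, momentum=0.9)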

LWF Algorithm


$\theta_s$: the shared parameters of the CNN pretrained on the old tasks
$\theta_o$: the task-specific parameters of each old task (the final classifier head)
$(X_n, Y_n)$: the data of the new task

Initialization:
1. Feed the new data $(X_n, Y_n)$ through the network pretrained on the old tasks to record a set of responses $Y_o$.
2. Randomly initialize the classifier-head parameters of the new task (a common trick to speed up training). A small sketch of step 1 follows.
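
A minimal sketch of step 1, where old_net and new_loader are placeholder names for the pretrained network and a loader over the new-task data:

import torch

def record_responses(old_net, new_loader):
    """Run the old-task network over the new-task inputs and keep its outputs Y_o."""
    old_net.eval()
    responses = []
    with torch.no_grad():
        for x, *_ in new_loader:        # new-task inputs X_n; the labels are not needed here
            responses.append(old_net(x))
    return torch.cat(responses)         # Y_o, later used as the distillation targets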

Training:
$\hat{Y}_o$ is the output of the network being trained for the old tasks; initially $\hat{\theta}_o = \theta_o$ and $\hat{\theta}_s = \theta_s$.
$\hat{Y}_n$ is the output of the network being trained for the new task; initially $\hat{\theta}_n = \theta_n$ and $\hat{\theta}_s = \theta_s$.
The optimization objective is

$$\theta_{s}^{*}, \theta_{o}^{*}, \theta_{n}^{*} \leftarrow \underset{\hat{\theta}_{s}, \hat{\theta}_{o}, \hat{\theta}_{n}}{\operatorname{argmin}}\left(\lambda_{o} \mathcal{L}_{old}\left(Y_{o}, \hat{Y}_{o}\right)+\mathcal{L}_{new}\left(Y_{n}, \hat{Y}_{n}\right)+\mathcal{R}\left(\hat{\theta}_{s}, \hat{\theta}_{o}, \hat{\theta}_{n}\right)\right)$$
The first term is the sub-objective for the old tasks, the second is the objective for the new task, and the third is a regularization term.
The whole training procedure is very similar to joint training, but the biggest difference is that LWF never uses the old-task data; instead it cleverly uses a KD loss to preserve old-task performance. The KD part appears in the following formula:
$$\mathcal{L}_{old}\left(\mathbf{y}_{o}, \hat{\mathbf{y}}_{o}\right) = -H\left(\mathbf{y}_{o}^{\prime}, \hat{\mathbf{y}}_{o}^{\prime}\right) = -\sum_{i=1}^{l} y_{o}^{\prime(i)} \log \hat{y}_{o}^{\prime(i)}$$

Here $l$ is the number of labels, and $\hat{y}_{o}^{\prime(i)}$ and $y_{o}^{\prime(i)}$ are the temperature-modified versions of $\hat{y}_{o}^{(i)}$ and $y_{o}^{(i)}$; this is exactly where the KD idea comes in: $y_{o}^{\prime(i)}$ is the soft target, while $\hat{y}_{o}^{\prime(i)}$ is the probability predicted by the network being trained.

$$y_{o}^{\prime(i)}=\frac{\left(y_{o}^{(i)}\right)^{1 / T}}{\sum_{j}\left(y_{o}^{(j)}\right)^{1 / T}}, \quad \hat{y}_{o}^{\prime(i)}=\frac{\left(\hat{y}_{o}^{(i)}\right)^{1 / T}}{\sum_{j}\left(\hat{y}_{o}^{(j)}\right)^{1 / T}}$$

So during training, the first loss term keeps the network's output probabilities close to those of the old tasks.
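
A minimal sketch of this distillation loss in PyTorch, assuming the raw logits of both networks are available (softmax of logits divided by T is equivalent to raising the probabilities to the power 1/T and renormalizing); the function name and arguments are illustrative:

import torch
import torch.nn.functional as F

def kd_loss(old_logits, new_logits, T=2.0):
    """Temperature-scaled cross-entropy between the recorded old-task outputs
    (soft targets) and the current network's outputs for the old-task heads."""
    y_o = F.softmax(old_logits / T, dim=1)             # y'_o: soft targets from the old network
    log_y_hat = F.log_softmax(new_logits / T, dim=1)   # log of the current network's scaled probabilities
    return -(y_o * log_y_hat).sum(dim=1).mean()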

Summary

In essence, LWF combines KD with fine-tuning. Its advantages are that training is faster than joint training and that it requires no access to previous data. However, catastrophic forgetting still cannot be completely avoided when learning many tasks in sequence.

Implementation

Below is a simple reproduction based on PyTorch (using the Avalanche benchmarks); without further ado, here is the code.
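
The snippets below rely on a few imports that the original post omits. A reasonable set of assumptions is that SplitMNIST and SimpleMLP come from the Avalanche library, and that progress_bar is a small console-logging helper (as in the popular pytorch-cifar utilities) defined elsewhere:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import SimpleMLP
# progress_bar: assumed to be a simple console progress helper; a plain print works as well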

# prepare the dataset
n_classes = 10  # SplitMNIST has 10 classes in total
n_tasks = 5
per_classes_task = int(n_classes / n_tasks)
split_mnist = SplitMNIST(n_experiences=n_tasks, seed=0, return_task_id=True, shuffle=False)
train_dataset = split_mnist.train_stream[0].dataset
test_dataset = split_mnist.test_stream[0].dataset
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=True)

Training the old task

def kaiming_normal_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')

def train(epoch, model, optimizer, criterion):
    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for batch_id, (x, y, t) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, y_pred = y_pred.max(1)
        total += len(y)
        correct += y_pred.eq(y).sum().item()
        progress_bar(batch_id, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_id + 1), 100. * correct / total, correct, total))
    return train_loss / (batch_id + 1)

def test(epoch, model, criterion):
    global best_acc
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_id, (x, y, t) in enumerate(test_loader):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            test_loss += loss.item()
            _, y_pred = y_pred.max(1)
            total += len(y)
            correct += y_pred.eq(y).sum().item()
            progress_bar(batch_id, len(test_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_id + 1), 100. * correct / total, correct, total))
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt_mnist.pth')
        best_acc = acc
    return acc

# train and test
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 10
best_acc = 0.0
lr = 0.01
pre_model = SimpleMLP(num_classes=per_classes_task, hidden_size=256).to(device)
print(pre_model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(pre_model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
for epoch in range(epochs):
    train(epoch, pre_model, optimizer, criterion)
    test(epoch, pre_model, criterion)

LWF

split_mnist = SplitMNIST(n_experiences=n_tasks, seed=0, return_task_id=True, shuffle=False)
# take the second 2-class task
train_dataset = split_mnist.train_stream[1].dataset
test_dataset = split_mnist.test_stream[1].dataset
# use the first 2-class task to check how LWF performs on the old task
val_dataset = split_mnist.test_stream[0].dataset
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1000, shuffle=True)
net_new = SimpleMLP(num_classes=per_classes_task, hidden_size=256).to(device)
net_old = SimpleMLP(num_classes=per_classes_task, hidden_size=256).to(device)
oor = torch.load('checkpoint/ckpt_mnist.pth')
net_new.load_state_dict(oor['model'])
net_old.load_state_dict(oor['model'])
incremental_class = per_classes_task
# size of the previous task model's classifier head
in_features = net_old.classifier.in_features
out_features = net_old.classifier.out_features
# copy the parameters of the old classifier head
weight = net_old.classifier.weight.data
bias = net_old.classifier.bias.data
# number of outputs of the new head
new_out_features = incremental_class + out_features
# build the new classifier
new_fc = nn.Linear(in_features, new_out_features)
kaiming_normal_init(new_fc)  # randomly initialize the new head before copying the old weights back in
# the first two outputs of the new model keep the old head's weights; the remaining outputs learn the new classes
new_fc.weight.data[:out_features] = weight
new_fc.bias.data[:out_features] = bias
net_new.classifier = new_fc
net_new = net_new.to(device)
print('new head numbers:', net_new.classifier.out_features)
# make sure the previous task model does not take part in backpropagation
for param in net_old.parameters():
    param.requires_grad = False
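
As a quick sanity check (not part of the original post), before any LWF training the first out_features outputs of net_new should reproduce net_old on the same input, since those rows of the head were copied over:

# hypothetical sanity check on a random probe batch
net_new.eval(); net_old.eval()
x_probe = torch.randn(4, 1, 28, 28, device=device)
with torch.no_grad():
    assert torch.allclose(net_new(x_probe)[:, :out_features], net_old(x_probe), atol=1e-6)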

Modified training and test functions

def train(alpha, T, epoch):
    print('\nEpoch: %d' % epoch)
    net_new.eval()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (x, y, t) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = net_new(x)
        soft_y = net_old(x)
        # loss on the new classes
        loss1 = criterion(outputs, y)
        # distillation loss on the old heads (temperature-scaled soft targets)
        outputs_S = F.softmax(outputs[:, :out_features] / T, dim=1)
        outputs_T = F.softmax(soft_y[:, :out_features] / T, dim=1)
        loss2 = outputs_T.mul(-1 * torch.log(outputs_S))
        loss2 = loss2.sum(1)
        loss2 = loss2.mean() * T * T
        # loss = loss1 * alpha + loss2 * (1 - alpha)
        loss = loss1 + alpha * loss2
        loss.backward(retain_graph=True)
        # loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, y_pred = outputs.max(1)
        total += len(y)
        correct += y_pred.eq(y).sum().item()
        progress_bar(batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    return train_loss / (batch_idx + 1)

def test(alpha, T, epoch):
    global best_acc
    net_new.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets, t) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_new(inputs)
            soft_target = net_old(inputs)
            loss1 = criterion(outputs, targets)
            outputs_S = F.softmax(outputs[:, :out_features] / T, dim=1)
            outputs_T = F.softmax(soft_target[:, :out_features] / T, dim=1)
            loss2 = outputs_T.mul(-1 * torch.log(outputs_S))
            loss2 = loss2.sum(1)
            loss2 = loss2.mean() * T * T
            loss = loss1 * alpha + loss2 * (1 - alpha)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += len(targets)
            correct += predicted.eq(targets).sum().item()
            progress_bar(batch_idx, len(test_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': net_new.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/LWF_ckpt.pth')
        best_acc = acc
    return acc

def val(epoch):  # evaluate performance on the old task
    net_new.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets, t) in enumerate(val_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_new(inputs)
            _, predicted_old = outputs.max(1)
            total += len(targets)
            correct += predicted_old.eq(targets).sum().item()
            progress_bar(batch_idx, len(val_loader), 'Acc: %.3f%% (%d/%d)'
                         % (100. * correct / total, correct, total))
    return 100. * correct / total

Training and testing

# simple implementation; the hyperparameters are not tuned
T = 2
alpha = 0.5
criterion = nn.CrossEntropyLoss()
best_acc = 0.0
optimizer = optim.SGD(filter(lambda p: p.requires_grad, net_new.parameters()), lr=0.01,
                      momentum=0.9, weight_decay=5e-4)
for epoch in range(epochs):
    train_loss = train(alpha, T, epoch)
    acc_new = test(alpha, T, epoch)
    acc_old = val(epoch)
torch.save(net_new.state_dict(), 'model.pth')
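
If the saved weights need to be reloaded later, the model has to be rebuilt with the expanded head first; a minimal sketch under that assumption:

# rebuild a SimpleMLP with the expanded classifier head, then load the LWF weights
infer_model = SimpleMLP(num_classes=per_classes_task, hidden_size=256)
infer_model.classifier = nn.Linear(in_features, new_out_features)
infer_model.load_state_dict(torch.load('model.pth', map_location=device))
infer_model = infer_model.to(device).eval()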

Paper

Learning without Forgetting, Z. Li and D. Hoiem: https://arxiv.org/abs/1606.09282
