李宏毅2022机器学习HW4解析
准备工作
作业四是speaker Identification(语者识别),需要将助教代码+数据集放置于同一目录下,记得解压数据集。关注本公众号,可获得代码和数据集(文末有方法)。
Kaggle提交地址
https://www.kaggle.com/competitions/ml2022spring-hw4,提交结果可能需要科学上网,想讨论的可进QQ群:156013866。
Simple Baseline (acc>0.60824)
方法:使用TransformerEncoder层,其中TransformerEncoderLayer的层数是2(num_layers=2),,运行代码出现output.csv文件,将其提交到kaggle上得到分数:0.65775。
# __init__中启用encoder
self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=2)
# forward中使用encoder
out = self.encoder(out)
Medium Baseline (acc>0.70375)
方法:维度更改+使用TransformerEncoder层+Dropout+全连接层修改+Train longer。助教代码中的d_model维度是40,而我们需要预测的n_spks维度是600,维度相差过大,需要将d_model调整为224,经测试d_model过大过小都不好。TransformerEncoder使用3层TransformerEncoderLayer,dropout=0.2。全连接层从2层改为1层,并加入BatchNorm。训练step由70000改为100000。运行代码,提交得到kaggle分数:0.74025。
def __init__(self, d_model=224, n_spks=600, dropout=0.2):
super().__init__()
self.prenet = nn.Linear(40, d_model)
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model,
dim_feedforward=d_model*2, nhead=2, dropout=dropout)
self.encoder = nn.TransformerEncoder(self.encoder_layer,
num_layers=3)
self.pred_layer = nn.Sequential(
nn.BatchNorm1d(d_model),
nn.Linear(d_model, n_spks),
)
Strong Baseline (acc>0.7750)
方法:维度更改+使用ConformerBlock+Dropout+全连接层修改+Train longer。与medium baseline相比,将transoformerEncoder改为ConformerBlock,可以通过下载现成的Conformer代码实现。运行代码,提交后得到分数:0.77850。
!pip install conformer
from conformer import ConformerBlock
# 模型中的主要改动
def __init__(self, d_model=224, n_spks=600, dropout=0.25):
super().__init__()
self.prenet = nn.Linear(40, d_model)
self.encoder = ConformerBlock(
dim = d_model,
dim_head = 4,
heads = 4,
ff_mult = 4,
conv_expansion_factor = 2,
conv_kernel_size = 20,
attn_dropout = dropout,
ff_dropout = dropout,
conv_dropout = dropout,
)
self.pred_layer = nn.Sequential(
nn.BatchNorm1d(d_model),
nn.Linear(d_model, n_spks),
)
Boss Baseline (acc>0.86500)
方法:维度更改+ConformerBlock+Self-attention pooling+Additive margin softmax+Train longer。与strong baseline相比,将mean pooling 换成了self-attention pooling,另外使用了简单版的additive margin softmax,batch size从32到64,step改为200000步。运行代码,提交后得到分数:0.78725。提升效果有限,另外试了正常版的additive margin softmax,仍然效果有限,最后用ensemble+TTA,能到达boss baseline,不过这跟作业课件中的要求有违背,课件中的说明是单个模型到达boss baseline,以后有更多时间我会再做更多尝试,完成后再补充。
class SelfAttentionPooling(nn.Module):
def __init__(self, input_dim):
super().__init__()
self.W = nn.Linear(input_dim, 1)
def forward(self, batch_rep):
att_w = F.softmax(self.W(batch_rep).squeeze(-1), dim=-1).unsqueeze(-1)
utter_rep = torch.sum(batch_rep * att_w, dim=1)
return utter_rep
from torch.autograd import Variable
class AMSoftmax(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, target, scale=5.0, margin=0.35):
cos_theta = input
target = target.view(-1, 1) # size=(B,1)
index = cos_theta.data * 0.0 # size=(B,Classnum)
index.scatter_(1, target.data.view(-1, 1), 1)
index = index.byte()
index = Variable(index).bool()
output = cos_theta * 1.0 # size=(B,Classnum)
output[index] -= margin
output = output * scale
logpt = F.log_softmax(output, dim=-1)
logpt = logpt.gather(1, target)
logpt = logpt.view(-1)
loss = -1 * logpt
loss = loss.mean()
return loss
作业四答案获得方式:
关注微信公众号 “机器学习手艺人”
后台回复关键词:202204
李宏毅2022机器学习HW4解析相关推荐
- 李宏毅2022机器学习HW2解析
准备工作:去课程github下载原始代码,kaggle下载数据集.或者关注本公众号,下载代码和数据集(文末有方法).解压数据集,出现libriphone文件夹,将文件和代码放到同一目录下. kaggl ...
- 李宏毅2022机器学习HW10解析
准备工作 作业十是黑箱攻击(Blackbox Attack),完成作业需要助教代码和数据集,运行代码过程中保持联网可以自动下载数据集,已经有数据集的情况可关闭助教代码中的下载数据部分.关注本公众号,可 ...
- 李宏毅2022机器学习HW5解析
准备工作 作业五是机器翻译,需要助教代码,运行代码过程中保持联网可以自动下载数据集,已经有数据集的情况可关闭助教代码中的下载数据部分.关注本公众号,可获得代码和数据集(文末有方法). 提交地址 这次作 ...
- 李宏毅老师机器学习选择题解析
机器学习选择题解析加整理 项目说明,本项目是李宏毅老师在飞桨授权课程的配套问题 课程 传送门 该项目AiStudio项目 传送门 仅供学习参考! 三岁出品必是精品! 整理内容源于李宏毅老师机器学习课程 ...
- 李宏毅2022机器学习hw6
目录 Machine Learning HW6 一.任务 二.数据集 Crypko: 三.结果 四.改进方法 4.
- 李宏毅2022机器学习HW1收获
colab的使用 把训练集取出一部分作为验证集 选择特征 tqdm,tensorboard使用 在训练和验证时要关闭梯度计算 要把模型和数据放在同一个device上 保证模型可复现性 pytorch和 ...
- 【李宏毅《机器学习》2022】作业1:COVID 19 Cases Prediction (Regression)
文章目录 [李宏毅<机器学习>2022]作业1:COVID 19 Cases Prediction (Regression) 作业内容 1.目标 2.任务描述 3.数据 4.评价指标 代码 ...
- 李宏毅《机器学习》国语课程(2022)来了
提起李宏毅老师,熟悉机器学习的读者朋友一定不会陌生.很多人选择的机器学习入门学习材料都是李宏毅老师的台大公开课视频.今年李宏毅老师开设一门新的机器学习机器学习课程,涵盖最新热门主题,非常值得关注! 李 ...
- 【千呼万唤】李宏毅《机器学习》国语课程(2022)终于来了
提起李宏毅老师,熟悉机器学习的读者朋友一定不会陌生.很多人选择的机器学习入门学习材料都是李宏毅老师的台大公开课视频.今年李宏毅老师开设一门新的机器学习机器学习课程,涵盖最新热门主题,非常值得关注! 李 ...
- 李宏毅2020机器学习作业2-Classification:年收入二分类
更多作业,请查看 李宏毅2020机器学习资料汇总 文章目录 0 作业链接 1 作业说明 环境 任务说明 数据说明 作业概述 2 原始代码 2.0 数据准备 导入数据 标准化(Normalization ...
最新文章
- 网页图表Highcharts实践教程之图表区
- app:compileDebugJavaWithJavac
- Java并发编程实践读书笔记(3)任务执行
- 在阿里干了 5 年招聘,这 10 条建议我必须分享给你!
- cuDNN兼容性问题造成的caffe/mnist,py-faster-rcnn/demo运行结果错误
- 1Python全栈之路系列Web框架介绍
- 学以致用十三-----Centos7.2+python3+YouCompleteMe成功历程
- 使用IPV6 ACL对telnet登陆进行限定
- flowable工作流_使用Bash Shell实现flowable配置文件修改定制
- 《王道》数据结构笔记整理2022
- AP6212认证_自适应测试
- 迅捷文字转语音软件v2.0.0官方免费版
- macOS上显示隐藏文件
- [转]stm32 sdio写入速度 SD卡【好文章】[F1开发板通用] 战舰STM32F103开发板 SDIO写入速度测试(使用FATFS)
- 数据中台架构与技术选型
- 工业互联网新发展:基于 HTML5 WebGL 的高炉炼铁厂可视化系统
- ERP与MRP、MRPⅡ的主要区别
- DateDiff 数据库时间差函数
- 请不要在Java项目中乱打印日志了,这才是正确姿势,非常实用!
- 甲板智慧-北京林业大学“林之心”项目
热门文章
- python读取海康视频流(rtsp格式)
- 惠普服务器硬盘指示灯不亮或显示蓝色
- Android 应用进程启动流程
- 【oracle11g,13】表空间管理2:undo表空间管理(调优) ,闪回原理
- NL驱动表错误导致的性能问题
- Unity 代码帧动画
- vsode 编译报错:main.c:4:10: fatal error: iostream: 没有那个文件或目录
- 北京三大春天赏花圣地
- 时间管理 android app推荐,干货星球 篇十三:【强烈安利】分享10个时间管理APP,每一个都堪称精品!...
- php 九宫格验证码,用php数字九宫格.