Evaluating Student Writing
最近参加了一个ner序列标注比赛Evaluating Student Writing,感觉自己已经很难上分了,也学到了一些知识,在这里做个总结
比赛的大意是对于一波学生的文章进行序列标注,标注的标签有Claim,Evidence,Lead几种标签,对于一篇文章中的预测示例如下:
id,class,predictionstring
1,Claim,1 2
1,Claim,6 7 8
这里predictionstring的结果是以空格切分出来的单词,比如I love you中I就是1,love就是2,这里值得学习的内容如下:
pytorch中的softmax和crossentropy不能混用
这是在做这一道题目之中遇到的比较重大的bug,因为在pytorch之中的crossentropy自带softmax激活函数,所以网络层的定义只能为
class ClassificationModel(nn.Module):def __init__(self,model,config,n_labels):super(ClassificationModel,self).__init__()self.model = modelself.fc = nn.Linear(config.embedding_size,15)def forward(self,input_ids,attention_mask):#mask_ids = torch.not_equal(input_ids,1)#英文roberta padding=1output = self.model(input_ids,attention_mask=attention_mask)output = self.fc(output)return output
损失函数定义为
def compute_multilabel_loss(model,batch_token_ids,batch_attention_mask,batch_label):#print('compute_multilabel_loss')logit = model(input_ids=batch_token_ids,attention_mask=batch_attention_mask)#loss_fn = torch.nn.CrossEntropyLoss(reduce=True, size_average=True)#使用CrossEntropyLoss()时网络层不需要加上激活函数softmaxloss_fn = torch.nn.CrossEntropyLoss(reduce=True,size_average=True)logit = logit.view(logit.size()[0]*logit.size()[1],-1)batch_label = batch_label.view(batch_label.size()[0]*batch_label.size()[1],-1)batch_label = batch_label.squeeze()mseloss = loss_fn(logit,batch_label)if 12 in batch_label:mseloss = mseloss*1.2#增加一波标签为12和13的内容return mseloss
模型训练的epoch
对于模型训练的结束位置,可以等到效果反超的时候退出,而不是等到训练完若干个epoch再退出。
if bestpoint < point:bestpoint = pointprint('bestpoint = %s'%(str(bestpoint)))torch.save(model,'best_point='+str(bestpoint)+'_current_split='+str(current_split)+'.pth')best_point_result[current_split] = bestpoint
else:break
对于单词的划分
这里对于单词的划分使用的是字符划分的方式
for id_num in tqdm(range(len(IDS))):#if id_num%100==0: print(id_num,', ',end='')# READ TRAIN TEXT, TOKENIZE, AND SAVE IN TOKEN ARRAYSn = IDS[id_num]name = f'/home/xiaoguzai/数据/Evaluate student writing/train/{n}.txt'txt = open(name, 'r').read()train_lens.append(len(txt.split()))tokens = tokenizer.encode_plus(txt, max_length=MAX_LEN, padding='max_length',truncation=True, return_offsets_mapping=True)train_tokens[id_num,] = tokens['input_ids']train_attention[id_num,] = tokens['attention_mask']# FIND TARGETS IN TEXT AND SAVE IN TARGET ARRAYSoffsets = tokens['offset_mapping']#读取文件名称并切分,获取tokens['input_ids'],#tokens['attention_mask']以及tokens['offset_mapping']内容train_offset.append(offsets)train_text.append(txt)offset_index = 0df_current = df.loc[df.id==n]for index,row in df_current.iterrows():#!!!这里有bug!!!应该为df_current.iterrows,而不是df.iterrowsa = row.discourse_startb = row.discourse_end#discourse_start为开始的字母,从第8个字母开始,discourse_end为结束的字母,从第229个字母开始#这个if要放在循环外面,节约运行时间#这里的a记录的是每一项开始的字符,b记录的是每一项结束的字符c = offsets[offset_index][0]d = offsets[offset_index][1]#切分出来带空格的第几个字母,!!!利用了offsets也是一个列表,列表之中记录了单词的开始的字母和结束的字母beginning = Truewhile b>c:if (a<=c)&(d<=b):k = target_map[row.discourse_type] r"""这里a,b,c,d实际上的大小关系(a , b)(c<---------c,d------->d)(c,d)包含在(a,b)之中,证明这个字母所在的单词为这个标签的内容c,d为字母编号,id_num为划分的id编号(单词编号)!!!如果当前字母的起始和终止位置包含在总起始和终止位置之中"""#“Lead","Position","Evidence","Claim","Concluding Statement",if beginning:targets_b[k][id_num][offset_index] = 1beginning = Falseelse:targets_i[k][id_num][offset_index] = 1#如果对应字母在这个(a,d)的范围之中,标记当前分类类别#这里不需要offset_index+1,因为起始为(0,0)代表打头的标记了#标记当前的单词即可#offset_index代表着当前的单词offset_index += 1if offset_index>len(offsets)-1:breakc = offsets[offset_index][0]d = offsets[offset_index][1]#k是分类,id_num为对应的id,offset_index为对应的predictstring
可以看出,在训练的时候,利用当前字符的内容去标记当前单词的内容(longformer与roberta切词方法一样,都是按照空格一个一个切单词),如果当前单词在字符范围之内,则标记该单词
容易出现bug的地方:序列标注需要去除掉开头和结尾的位置
在序列标注的过程中,训练的时候需要加上开头的标志,预测完之后要去除掉开头的标志,否则容易发生错位现象
for index1 in range(batch_size):current_pred = [number_map_target[int(data)] for data in output_idx[index1]]pred_df.append({'id':batch_ids[index1],\'input_ids':batch_token[index1][1:],\'text':batch_text[index1],\'offset_mapping':batch_offset[index1][1:],\'preds':current_pred[1:],\'pred_scores':output_pred[index1][1:]})
设定界限以及连续的次数
for batch_ids,batch_token,batch_text,batch_offset,batch_attention_mask,batch_label in tqdm(valid_loader):batch_token = batch_token.to(device)batch_attention_mask = batch_attention_mask.to(device)batch_label = batch_label.to(device)with torch.no_grad():output = model(batch_token,attention_mask=batch_attention_mask)#output_label = np.argmax(output,axis=-1)output = torch.softmax(output,axis=-1)output_pred, output_idx = output.max(-1)batch_size = output.size()[0]r"""number_map_target = {0:'B-Lead',1:'I-Lead',2:'B-Position',3:'I-Position',4:'B-Evidence',5:'I-Evidence',\6:'B-Claim',7:'I-Claim',8:'B-Concluding Statement',9:'I-Concluding Statement',\10:'B-Counterclaim',11:'I-Counterclaim',12:'B-Rebuttal',13:'I-Rebuttal',14:'O'}"""for index1 in range(batch_size):current_pred = [number_map_target[int(data)] for data in output_idx[index1]]pred_df.append({'id':batch_ids[index1],\'input_ids':batch_token[index1][1:],\'text':batch_text[index1],\'offset_mapping':batch_offset[index1][1:],\'preds':current_pred[1:],\'pred_scores':output_pred[index1][1:]})#预测出来的结果去掉开头的[cls]标志,容易出现bug的地方:text不需要去除第一位的字母#其他的像offset_mapping,preds,pred_scores等可以#batch_ids不能带[1:],为文章的id
pred_df = change_id_to_predictstring(pred_df)
这里的设定界限非常的关键,预测出结果的转化方式与之前将输入的内容转为单词的转化方式有所不同,这里我们查看change_id_to_predictstring函数
def change_id_to_predictstring(test_samples):temp_df = []for sample_idx, sample in enumerate(test_samples):preds = sample["preds"]offset_mapping = sample["offset_mapping"]sample_id = sample["id"]sample_text = sample["text"]sample_input_ids = sample["input_ids"]sample_pred_scores = sample["pred_scores"]sample_preds = []if len(preds) < len(offset_mapping):#没有出现过这种情况preds = preds + ["O"] * (len(offset_mapping) - len(preds))sample_pred_scores = sample_pred_scores + [0] * (len(offset_mapping) - len(sample_pred_scores))#这里没看到下面有用到O来填充的idx = 0phrase_preds = []while idx < len(offset_mapping):start, end = offset_mapping[idx]if preds[idx] != "O":label = preds[idx][2:]else:label = "O"phrase_scores = []phrase_scores.append(sample_pred_scores[idx])#这里phrase_scores压入的是刚开始的得分#phrase_scores保存每一层概率的得分idx += 1while idx < len(offset_mapping):if label == "O":matching_label = "O"else:matching_label = f"I-{label}"if preds[idx] == matching_label:_, end = offset_mapping[idx]phrase_scores.append(sample_pred_scores[idx])idx += 1else:break#取出后面的label与当前label相同的标签内容#比如前面的为Lead,这里依次找寻I-Lead(I-Lead均为匹配直到不匹配为止)if "end" in locals():phrase = sample_text[start:end]phrase_preds.append((phrase, start, end, label, phrase_scores))for phrase_idx, (phrase, start, end, label, phrase_scores) in enumerate(phrase_preds):word_start = len(sample_text[:start].split())word_end = word_start + len(sample_text[start:end].split())word_end = min(word_end, len(sample_text.split()))#这个情况非常的严谨ps = " ".join([str(x) for x in range(word_start, word_end)])if label != "O":if sum(phrase_scores) / len(phrase_scores) >= proba_thresh[label]:#sum phrase_scores = 140.2255,len phrase_scores = 147,proba_thresh[label] = 0.7#求出每一个phrase的平均得分,判断是否大于proba_thresh[label]的得分,if len(ps.split()) >= min_thresh[label]:#平均得分大于proba_thres[label]并且连续单词数量大于min_thresh[label]时放入最终结果中temp_df.append((sample_id, label, ps))#if label != "O" and ps != "" and len(phrase_scores) > 2:# temp_df.append((sample_id, label, ps))#这里调用对于phrase_preds进行进一步处理#提交中需要sample_id,label,ps(用空格切分的内容)temp_df = pd.DataFrame(temp_df, columns=["id", "class", "predictionstring"])temp_df = temp_df.reset_index(drop=True)return temp_df
自己的一点想法
这里融合模型的时候可以使用之前kaggle使用的,循环找寻最优的参数进行融合的方法,比如模型1的权重设置为0~0.4,模型2的权重设置为0~0.4,模型3的权重为1-模型1的权重-模型2的权重,然后找寻在验证集上的最高分,这样来设定权重内容
Evaluating Student Writing相关推荐
- 2022-kaggle-nlp赛事:Feedback Prize - English Language Learning
文章目录 零.比赛介绍 0.1 比赛目标 0.2 数据集 0.3 注意事项 一.设置 1.1 导入相关库 1.2 设置超参数和随机种子 1.3 启动wandb 二. 数据预处理 2.1 定义前处理函数 ...
- 比赛推送 图像/表格/CV/NLP,多线程开启
H&M Personalized Fashion Recommendations 比赛任务:根据之前的购买行为提供产品推荐 比赛链接:https://www.kaggle.com/c/h-an ...
- 用faster-rcnn训练自己的数据集(VOC2007格式,python版)
用faster-rcnn训练自己的数据集(VOC2007格式,python版) 一. 配置caffe环境 ubunt16.04下caffe环境安装 二. 下载,编译及测试py-faster-rcnn源 ...
- tf-faster-rcnn代码学习.目标检测(Tensorflow版Faster R-CNN)
TF-Faster R-CNN 电脑配置 代码来源 环境配置 demo测试 参考博客 训练自己的数据集 测试阶段 Tensorboard查看收敛情况 电脑配置 系统:Ubuntu 16.04 GPU型 ...
- 【多目标跟踪】Tracktor++代码及调试过程
论文<Tracking without bells and whistles>(Philipp Bergmann, Tim Meinhardt, Laura Leal-Taixe) arx ...
- Applying Rhetorical Structure Theory to Student Essays for Providing Automated Writing Feedback
原论文 动机 作文结构方面的反馈可以帮助写作者建立一个清晰的结构,从而组织好作文中的句子和段落. 现有的作文评分的系统有的仅仅得到一个分数.有的只给出单个句子结构反馈,或者反馈不具有改进指导意义 论文 ...
- Angel Borja博士教你如何撰写科学论文一:Six things to do before writing your manuscript
Six things to do before writing your manuscript In this new series - "How to Prepare a Manuscri ...
- A Collection of 100+ Writing Task 2 Essays for IELTS
EDITION 2019 A Collection of 100+ Writing Task 2 Essays IELTS ESSAYS FROM EXAMINERS VERSION 3.0 OREM ...
- english writing sample for professional
Thank you for contacting our office. Students are responsible to bring their own extra long twin bed ...
- Writing a good grant proposal
1. 把问题想成熟,能够清晰地表达问题是什么. 2. 清晰地表达为什么这个问题是重要的. 3. 实现idea所用的技术,即不能全是已知成熟的技术,也不能是完全未知的技术.全已知的技术就像做一个appl ...
最新文章
- android 长按赋值功能,android实现WebView中长按选中复制文本操作
- 合并两个对象 java_在Java中合并两个对象列表8
- mysql or的效率_Mysql比较exists与in以及or的效率分析
- TensorFlow 最小二乘法拟合
- database disk image is malformed 问题解决
- 基于Spring Boot和Spring Cloud实现微服务架构学习
- IOS学习笔记十九NSArray和NSMutableArray
- v8引擎和v12引擎_为什么V8和V12发动机至今还存在,而V10发动机却早早被淘汰了?...
- 存储ic载板_延伸IC领域 崇达技术拟将持有普诺威55%股权
- 通过QEMU-GuestAgent实现从外部注入写文件到KVM虚拟机内部
- Acad::ErrorStatus
- GridView固定表头
- response.sendRedirect()和request.getRequestDispatcher().forward(request,reponse)的区别
- springboot2.4+nettyWebServerApplicationContext@15f51c50 has been closed already问题解决
- [灯哥开源—四足机器人]程序算法讲解与STM32移植——运行框架(两个主线程)
- 如何用教科书式的方法,着手分析一个行业?
- Python3网络爬虫实战-38、动态渲染页面抓取:Splash的使用
- 1024程序员节最新福利之2018最全java资料集合
- 相似图片搜索、算法、识别的原理解析(上)
- token的使用方法
热门文章
- 中国最大的IDC世纪互联是如何成为云计算时代的看客的
- Python 数据相关性分析
- [Android]按阶段编译Android kernel中的代码
- csv是什么意思中文_CSV文件是什么意思?
- javax.persistence.EntityNotFoundException: Unable to find 类 with id ?
- 云上架构和传统IT架构有什么区别及优势?
- excel表格横向纵向变换_Excel新手最容易给自己挖的几个坑,手把手教你完美避雷!...
- 【java1234】java学习路线图2018
- 技术天才米勒 oracle,奇迹中的奇迹 WW之功能炫技篇
- 又一大的技术站点域名被ClientHold了