【DKN】(四)train.py
内容
try: #不用多言, 获得该模块下的model_name函数Model = getattr(importlib.import_module(f"model.{model_name}"), model_name)config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:print(f"{model_name}not included!")exit()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class EarlyStopping
class EarlyStopping:def __init__(self, patience=5):self.patience = patience self.counter = 0self.best_loss = np.Infdef __call__(self, val_loss):"""if you use other metrics where a higher value is better, e.g. accuracy,call this with its corresponding negative value"""# 如果你使用的其他指标值越高越好,例如准确性,用它对应的负数来调用它if val_loss < self.best_loss: #如果评测的损失小于最好的损失,那么就是最好的损失early_stop = Falseget_better = Trueself.counter = 0self.best_loss = val_loss # 最好的损失else:get_better = False #self.counter += 1if self.counter >= self.patience:early_stop = Trueelse:early_stop = Falsereturn early_stop, get_better
def latest_checkpoint(directory):
看一看存储的模型路径名称:
def latest_checkpoint(directory): #最新的检查点!if not os.path.exists(directory): #该路径在不在return Noneall_checkpoints = { #{10000 : ckpt-10000.pth, 11000: ckpt-11000.pth} 这就是最终的结果int(x.split('.')[-2].split('-')[-1]): xfor x in os.listdir(directory)}if not all_checkpoints: #如果没有checkpoint,就返回空return Nonereturn os.path.join(directory, #我们选择keys最大的选择all_checkpoints[max(all_checkpoints.keys())])
def train()
log_dir:
def train():writer = SummaryWriter( #这里的路径! runs/DKN/.....log_dir=f"./runs/{model_name}/{datetime.datetime.now().replace(microsecond=0).isoformat()}{'-' + os.environ['REMARK'] if 'REMARK' in os.environ else ''}")if not os.path.exists('checkpoint'): #如果没有checkpoint,那么就需要在当前目录下创建checkpointos.makedirs('checkpoint')try:pretrained_word_embedding = torch.from_numpy( #读入预训练单词嵌入np.load('./data/train/pretrained_word_embedding.npy')).float()except FileNotFoundError:pretrained_word_embedding = Noneif model_name == 'DKN': #如果是DKN模型try:pretrained_entity_embedding = torch.from_numpy( #如果是DKN,嵌入实体np.load('./data/train/pretrained_entity_embedding.npy')).float()except FileNotFoundError:pretrained_entity_embedding = Nonetry:pretrained_context_embedding = torch.from_numpy( #预训练上下文嵌入 但是numpy是在CPU上的!np.load('./data/train/pretrained_context_embedding.npy')).float()except FileNotFoundError:pretrained_context_embedding = Nonemodel = Model(config, pretrained_word_embedding, #创建模型pretrained_entity_embedding,pretrained_context_embedding)print(torch.cuda.device_count()) #这里是自己加的,想要实现并行操作!if torch.cuda.device_count() > 1: #如果设备数目大于1,那么就并行操作# model.to(device)device_ids = [0, 1]model = torch.nn.DataParallel(model, device_ids=device_ids)model.to(device)# for param in next(model.parameters()):# print(param, param.device)# print(next(model.parameters()).device)if model_name != 'Exp1':print(model)else:print(models[0])dataset = BaseDataset('data/train/behaviors_parsed.tsv','data/train/news_parsed.tsv', 'data/train/roberta')#获得原数据集print(f"Load training dataset with size{len(dataset)}.")dataloader = iter( #改成dataloader,并被迭代器包装,使得每次访问只需要next()即可DataLoader(dataset, #由于自己原来接触过dataloader所以这里是懂点的,不再解释batch_size=config.batch_size,shuffle=True,num_workers=config.num_workers,drop_last=True,pin_memory=True))if model_name != 'Exp1':criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(),lr=config.learning_rate)else:criterion = nn.NLLLoss() #最大似然函数optimizers = [ #定义优化器torch.optim.Adam(model.parameters(), lr=config.learning_rate)for model in models]start_time = time.time() #定义开始的时间loss_full = [] #全部损失exhaustion_count = 0 #竭尽全力_count???step = 0 early_stopping = EarlyStopping() #早点结束,看上面的函数定义checkpoint_dir = os.path.join('./checkpoint', model_name) #检查点/model_namePath(checkpoint_dir).mkdir(parents=True, exist_ok=True) #创建checkpoint目录checkpoint_path = latest_checkpoint(checkpoint_dir) #获得最新的检查点if checkpoint_path is not None: #开始带入checkpointprint(f"Load saved parameters in{checkpoint_path}")checkpoint = torch.load(checkpoint_path) #加载检查点,里面的格式是字典类型的early_stopping(checkpoint['early_stop_value']) #step = checkpoint['step'] #if model_name != 'Exp1':model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])model.train()else:for model in models: model.load_state_dict(checkpoint['model_state_dict']) #直接加载模型参数model.train() for optimizer in optimizers: #直接加载优化器参数optimizer.load_state_dict(checkpoint['optimizer_state_dict'])for i in tqdm(range( # epochs * (len(dataset) // config.batch_size + 1)这么多次迭代1,config.num_epochs * len(dataset) // config.batch_size + 1),desc="Training"):try: #获取小dataloader中的batchminibatch = next(dataloader)# if torch.cuda.device_count() > 1:# minibatch = torch.nn.DataParallel(minibatch)# minibatch.to(device)# minibatch.to(device)except StopIteration: #如果迭代出问题了exhaustion_count += 1tqdm.write(f"Training data exhausted for{exhaustion_count}times after{i}batches, reuse the dataset.")dataloader = iter(DataLoader(dataset,batch_size=config.batch_size,shuffle=True,num_workers=config.num_workers,drop_last=True,pin_memory=True))minibatch = next(dataloader)step += 1y_pred = model(minibatch["candidate_news"], #结算损失, 候选新闻是预测得到的!minibatch["clicked_news"])y = torch.zeros(len(y_pred)).long().to(device)loss = criterion(y_pred, y)loss_full.append(loss.item()) #要保存损失的if model_name != 'Exp1':optimizer.zero_grad()else:for optimizer in optimizers: #优化器更新权重optimizer.zero_grad()loss.backward()if model_name != 'Exp1':optimizer.step()else:for optimizer in optimizers:optimizer.step()if i % 10 == 0: #如果10次计算了,那么就写入我们的损失writer.add_scalar('Train/Loss', loss.item(), step)if i % config.num_batches_show_loss == 0: #写出结果tqdm.write(f"Time{time_since(start_time)}, batches{i}, current loss{loss.item():.4f}, average loss:{np.mean(loss_full):.4f}, latest average loss:{np.mean(loss_full[-256:]):.4f}")if i % config.num_batches_validate == 0: #(model if model_name != 'Exp1' else models[0]).eval()val_auc, val_mrr, val_ndcg5, val_ndcg10 = evaluate(model if model_name != 'Exp1' else models[0], './data/val',200000)(model if model_name != 'Exp1' else models[0]).train()writer.add_scalar('Validation/AUC', val_auc, step)writer.add_scalar('Validation/MRR', val_mrr, step)writer.add_scalar('Validation/nDCG@5', val_ndcg5, step)writer.add_scalar('Validation/nDCG@10', val_ndcg10, step)tqdm.write(f"Time{time_since(start_time)}, batches{i}, validation AUC:{val_auc:.4f}, validation MRR:{val_mrr:.4f}, validation nDCG@5:{val_ndcg5:.4f}, validation nDCG@10:{val_ndcg10:.4f}, ")#后面的都是如果是最好的效果,就保存模型参数early_stop, get_better = early_stopping(-val_auc)if early_stop:tqdm.write('Early stop.')breakelif get_better:try:torch.save({'model_state_dict': (model if model_name != 'Exp1'else models[0]).state_dict(),'optimizer_state_dict':(optimizer if model_name != 'Exp1' elseoptimizers[0]).state_dict(),'step':step,'early_stop_value':-val_auc}, f"./checkpoint/{model_name}/ckpt-{step}.pth")except OSError as error:print(f"OS error:{error}")
def time_since(since)
def time_since(since): #运行了多长时间"""Format elapsed time string."""now = time.time()elapsed_time = now - since #return time.strftime("%H:%M:%S", time.gmtime(elapsed_time))if __name__ == '__main__':# print('Using device:', device)print(f'Training model{model_name}')train()
补充
1. os.listdir() 方法
概述
os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。(是该文件夹下所有的文件名)
它不包括 . 和 … 即使它在文件夹中。
只支持在 Unix, Windows 下使用。
语法
listdir()方法语法格式如下:
os.listdir(path)
参数
path – 需要列出的目录路径
返回值
返回指定路径下的文件和文件夹列表。
实例
#!/usr/bin/python
# -*- coding: UTF-8 -*-import os, sys# 打开文件
path = "/var/www/html/"
dirs = os.listdir( path )# 输出所有文件和文件夹
for file in dirs:print (file)
2. Python replace()方法
描述
Python replace() 方法把字符串中的 old(旧字符串) 替换成 new(新字符串),如果指定第三个参数max,则替换不超过 max 次。
语法
replace()方法语法:
str.replace(old, new[, max])
参数
- old – 将被替换的子字符串。
- new – 新字符串,用于替换old子字符串。
- max – 可选字符串, 替换不超过 max 次
返回值
返回字符串中的 old(旧字符串) 替换成 new(新字符串)后生成的新字符串,如果指定第三个参数max,则替换不超过 max 次。
实例
str = "this is string example....wow!!! this is really string";
print str.replace("is", "was");
print str.replace("is", "was", 3);thwas was string example....wow!!! thwas was really string
thwas was string example....wow!!! thwas is really string
3. datetime测试
print(datetime.datetime.now()) #2021-08-27 09:47:48.748545
print(datetime.datetime.now().replace(microsecond=0)) #2021-08-27 09:48:26
print(datetime.datetime.now().replace(microsecond=0).isoformat()) #2021-08-27T09:49:18
4. NLLLoss 和 CrossEntropyLoss
https://blog.csdn.net/qq_22210253/article/details/85229988
NLLLoss的全称是Negative Log Likelihood Loss,也就是最大似然函数。
在图片进行单标签分类时,【注意NLLLoss和CrossEntropyLoss都是用于单标签分类,而BCELoss和BECWithLogitsLoss都是使用与多标签分类。这里的多标签是指一个样本对应多个label.】
【DKN】(四)train.py相关推荐
- pytorch YoLOV3 源码解析 train.py
train.py 总体分为三部分(不算import 库) 初始的一些设定 + train函数 + main函数 源码地址: https://github.com/ultralytics/yolov3 ...
- 图像分割套件PaddleSeg全面解析(一)train.py代码解读
首先祝贺百度团队百度斩获NeurIPS2020挑战赛冠军,https://www.jiqizhixin.com/articles/2020-12-09-2. 在此次比赛中使用的是基于飞桨深度学习框架开 ...
- YOLOv5源码逐行超详细注释与解读(3)——训练部分train.py
前言 本篇文章主要是对YOLOv5项目的训练部分train.py.通常这个文件主要是用来读取用户自己的数据集,加载模型并训练. 文章代码逐行手打注释,每个模块都有对应讲解,一文帮你梳理整个代码逻辑! ...
- 【YOLOV5-5.x 源码解读】train.py
目录 前言 0.导入需要的包和基本配置 1.设置opt参数 2.main函数 2.1.logging和wandb初始化 2.2.判断是否使用断点续训resume, 读取参数 2.3.DDP mode设 ...
- YOLOV5训练代码train.py注释与解析
YOLOv5代码注释版更新啦,注释的是最近的2021.07.14的版本,且注释更全 github: https://github.com/Laughing-q/yolov5_annotations Y ...
- Using TensorFlow backend. Traceback (most recent call last): File train.py, line 9, in module
yolo程序里报错 Using TensorFlow backend. Traceback (most recent call last): File "train.py", li ...
- YOLOv3源码阅读之六:train.py
一.YOLO简介 YOLO(You Only Look Once)是一个高效的目标检测算法,属于One-Stage大家族,针对于Two-Stage目标检测算法普遍存在的运算速度慢的缺点,YOLO创 ...
- Pointnet(part_seg)train.py,test.py代码随记
train.py 我将代码全部简化,将关键步骤全部列出 hdf5_data_dir = 数据集路径 #读取数据集的路径创建os.mkdir(train_result) #创建train_result文 ...
- File “./tools/train.py“, line 124 log_file = osp.join(cfg.work_dir, f‘{timestamp}.log‘)
问题: 在使用mmdetection做训练时出现这样的问题. File "./tools/train.py", line 124log_file = osp.join(cfg.wo ...
- YOLOV7跑通demo与train.py
1. 下载yolov7(github)与testing里面的yolov7.pt(预训练权重文件)文件 2.新建项目,选中解压文件yolov7 3.新建虚拟环境: file- setting - Pro ...
最新文章
- 零起点学算法07——复杂一点的表达式计算
- PHP如何添加变量 $_SERVER
- lr模型和dnn模型_建立ML或DNN模型的技巧
- boost线程(二)
- mui框架下监听返回按钮
- requests session
- webrtc java api_java – 使用WebSockets实现WebRTC信令
- 云信api_服务端API文档-音视频通话-网易云信开发文档
- eXosip注册函数与使用说明
- 论文笔记之RL优化——高斯平滑的Q函数
- 四十六、Stata离散选择模型,时间序列和面板数据
- An动画基础之元件的影片剪辑动画与传统补间
- K8s问题【flannel一直重启问题,CrashLoopBackOff】
- 计算机如何磁盘整理,磁盘碎片整理,教您磁盘碎片怎么整理
- itunes下载的软件所在目录
- Cloudflare 远程浏览器隔离
- GB/T 10707 橡胶燃烧性能
- Continuous Integration 对 ABAP 技术栈来说意味着什么
- CentOS7 安装cellranger-4.0.0
- pymysql 向MySQL 插入数据无故报错