内容

try:  #不用多言, 获得该模块下的model_name函数Model = getattr(importlib.import_module(f"model.{model_name}"), model_name)config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:print(f"{model_name}not included!")exit()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EarlyStopping

class EarlyStopping:def __init__(self, patience=5):self.patience = patience   self.counter = 0self.best_loss = np.Infdef __call__(self, val_loss):"""if you use other metrics where a higher value is better, e.g. accuracy,call this with its corresponding negative value"""# 如果你使用的其他指标值越高越好,例如准确性,用它对应的负数来调用它if val_loss < self.best_loss:   #如果评测的损失小于最好的损失,那么就是最好的损失early_stop = Falseget_better = Trueself.counter = 0self.best_loss = val_loss  # 最好的损失else:get_better = False         #self.counter += 1if self.counter >= self.patience:early_stop = Trueelse:early_stop = Falsereturn early_stop, get_better

def latest_checkpoint(directory):

看一看存储的模型路径名称:

def latest_checkpoint(directory):   #最新的检查点!if not os.path.exists(directory):  #该路径在不在return Noneall_checkpoints = {   #{10000 : ckpt-10000.pth, 11000: ckpt-11000.pth}  这就是最终的结果int(x.split('.')[-2].split('-')[-1]): xfor x in os.listdir(directory)}if not all_checkpoints:   #如果没有checkpoint,就返回空return Nonereturn os.path.join(directory,   #我们选择keys最大的选择all_checkpoints[max(all_checkpoints.keys())])

def train()

log_dir:

def train():writer = SummaryWriter(  #这里的路径!  runs/DKN/.....log_dir=f"./runs/{model_name}/{datetime.datetime.now().replace(microsecond=0).isoformat()}{'-' + os.environ['REMARK'] if 'REMARK' in os.environ else ''}")if not os.path.exists('checkpoint'):  #如果没有checkpoint,那么就需要在当前目录下创建checkpointos.makedirs('checkpoint')try:pretrained_word_embedding = torch.from_numpy(  #读入预训练单词嵌入np.load('./data/train/pretrained_word_embedding.npy')).float()except FileNotFoundError:pretrained_word_embedding = Noneif model_name == 'DKN':   #如果是DKN模型try:pretrained_entity_embedding = torch.from_numpy(   #如果是DKN,嵌入实体np.load('./data/train/pretrained_entity_embedding.npy')).float()except FileNotFoundError:pretrained_entity_embedding = Nonetry:pretrained_context_embedding = torch.from_numpy(  #预训练上下文嵌入  但是numpy是在CPU上的!np.load('./data/train/pretrained_context_embedding.npy')).float()except FileNotFoundError:pretrained_context_embedding = Nonemodel = Model(config, pretrained_word_embedding,   #创建模型pretrained_entity_embedding,pretrained_context_embedding)print(torch.cuda.device_count())   #这里是自己加的,想要实现并行操作!if torch.cuda.device_count() > 1:   #如果设备数目大于1,那么就并行操作# model.to(device)device_ids = [0, 1]model = torch.nn.DataParallel(model, device_ids=device_ids)model.to(device)# for param in next(model.parameters()):#     print(param, param.device)# print(next(model.parameters()).device)if model_name != 'Exp1':print(model)else:print(models[0])dataset = BaseDataset('data/train/behaviors_parsed.tsv','data/train/news_parsed.tsv', 'data/train/roberta')#获得原数据集print(f"Load training dataset with size{len(dataset)}.")dataloader = iter(   #改成dataloader,并被迭代器包装,使得每次访问只需要next()即可DataLoader(dataset,   #由于自己原来接触过dataloader所以这里是懂点的,不再解释batch_size=config.batch_size,shuffle=True,num_workers=config.num_workers,drop_last=True,pin_memory=True))if model_name != 'Exp1':criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(),lr=config.learning_rate)else:criterion = nn.NLLLoss()  #最大似然函数optimizers = [             #定义优化器torch.optim.Adam(model.parameters(), lr=config.learning_rate)for model in models]start_time = time.time()    #定义开始的时间loss_full = []       #全部损失exhaustion_count = 0  #竭尽全力_count???step = 0   early_stopping = EarlyStopping()  #早点结束,看上面的函数定义checkpoint_dir = os.path.join('./checkpoint', model_name)  #检查点/model_namePath(checkpoint_dir).mkdir(parents=True, exist_ok=True)    #创建checkpoint目录checkpoint_path = latest_checkpoint(checkpoint_dir)  #获得最新的检查点if checkpoint_path is not None:          #开始带入checkpointprint(f"Load saved parameters in{checkpoint_path}")checkpoint = torch.load(checkpoint_path)   #加载检查点,里面的格式是字典类型的early_stopping(checkpoint['early_stop_value'])   #step = checkpoint['step']     #if model_name != 'Exp1':model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])model.train()else:for model in models:   model.load_state_dict(checkpoint['model_state_dict'])  #直接加载模型参数model.train()  for optimizer in optimizers:   #直接加载优化器参数optimizer.load_state_dict(checkpoint['optimizer_state_dict'])for i in tqdm(range(    # epochs * (len(dataset) // config.batch_size + 1)这么多次迭代1,config.num_epochs * len(dataset) // config.batch_size + 1),desc="Training"):try:   #获取小dataloader中的batchminibatch = next(dataloader)# if torch.cuda.device_count() > 1:#     minibatch = torch.nn.DataParallel(minibatch)#     minibatch.to(device)# minibatch.to(device)except StopIteration:  #如果迭代出问题了exhaustion_count += 1tqdm.write(f"Training data exhausted for{exhaustion_count}times after{i}batches, reuse the dataset.")dataloader = iter(DataLoader(dataset,batch_size=config.batch_size,shuffle=True,num_workers=config.num_workers,drop_last=True,pin_memory=True))minibatch = next(dataloader)step += 1y_pred = model(minibatch["candidate_news"],  #结算损失, 候选新闻是预测得到的!minibatch["clicked_news"])y = torch.zeros(len(y_pred)).long().to(device)loss = criterion(y_pred, y)loss_full.append(loss.item())  #要保存损失的if model_name != 'Exp1':optimizer.zero_grad()else:for optimizer in optimizers:  #优化器更新权重optimizer.zero_grad()loss.backward()if model_name != 'Exp1':optimizer.step()else:for optimizer in optimizers:optimizer.step()if i % 10 == 0:   #如果10次计算了,那么就写入我们的损失writer.add_scalar('Train/Loss', loss.item(), step)if i % config.num_batches_show_loss == 0:  #写出结果tqdm.write(f"Time{time_since(start_time)}, batches{i}, current loss{loss.item():.4f}, average loss:{np.mean(loss_full):.4f}, latest average loss:{np.mean(loss_full[-256:]):.4f}")if i % config.num_batches_validate == 0:   #(model if model_name != 'Exp1' else models[0]).eval()val_auc, val_mrr, val_ndcg5, val_ndcg10 = evaluate(model if model_name != 'Exp1' else models[0], './data/val',200000)(model if model_name != 'Exp1' else models[0]).train()writer.add_scalar('Validation/AUC', val_auc, step)writer.add_scalar('Validation/MRR', val_mrr, step)writer.add_scalar('Validation/nDCG@5', val_ndcg5, step)writer.add_scalar('Validation/nDCG@10', val_ndcg10, step)tqdm.write(f"Time{time_since(start_time)}, batches{i}, validation AUC:{val_auc:.4f}, validation MRR:{val_mrr:.4f}, validation nDCG@5:{val_ndcg5:.4f}, validation nDCG@10:{val_ndcg10:.4f}, ")#后面的都是如果是最好的效果,就保存模型参数early_stop, get_better = early_stopping(-val_auc)if early_stop:tqdm.write('Early stop.')breakelif get_better:try:torch.save({'model_state_dict': (model if model_name != 'Exp1'else models[0]).state_dict(),'optimizer_state_dict':(optimizer if model_name != 'Exp1' elseoptimizers[0]).state_dict(),'step':step,'early_stop_value':-val_auc}, f"./checkpoint/{model_name}/ckpt-{step}.pth")except OSError as error:print(f"OS error:{error}")

def time_since(since)

def time_since(since):   #运行了多长时间"""Format elapsed time string."""now = time.time()elapsed_time = now - since  #return time.strftime("%H:%M:%S", time.gmtime(elapsed_time))if __name__ == '__main__':# print('Using device:', device)print(f'Training model{model_name}')train()

补充

1. os.listdir() 方法

概述

os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。(是该文件夹下所有的文件名)

它不包括 . 和 … 即使它在文件夹中。

只支持在 Unix, Windows 下使用。

语法

listdir()方法语法格式如下:

os.listdir(path)

参数

path – 需要列出的目录路径

返回值

返回指定路径下的文件和文件夹列表。

实例

#!/usr/bin/python
# -*- coding: UTF-8 -*-import os, sys# 打开文件
path = "/var/www/html/"
dirs = os.listdir( path )# 输出所有文件和文件夹
for file in dirs:print (file)

2. Python replace()方法

描述

Python replace() 方法把字符串中的 old(旧字符串) 替换成 new(新字符串),如果指定第三个参数max,则替换不超过 max 次。

语法

replace()方法语法:

str.replace(old, new[, max])

参数

  • old – 将被替换的子字符串。
  • new – 新字符串,用于替换old子字符串。
  • max – 可选字符串, 替换不超过 max 次

返回值

返回字符串中的 old(旧字符串) 替换成 new(新字符串)后生成的新字符串,如果指定第三个参数max,则替换不超过 max 次。

实例

str = "this is string example....wow!!! this is really string";
print str.replace("is", "was");
print str.replace("is", "was", 3);thwas was string example....wow!!! thwas was really string
thwas was string example....wow!!! thwas is really string

3. datetime测试

print(datetime.datetime.now())   #2021-08-27 09:47:48.748545
print(datetime.datetime.now().replace(microsecond=0))  #2021-08-27 09:48:26
print(datetime.datetime.now().replace(microsecond=0).isoformat())  #2021-08-27T09:49:18

4. NLLLoss 和 CrossEntropyLoss

https://blog.csdn.net/qq_22210253/article/details/85229988

NLLLoss的全称是Negative Log Likelihood Loss,也就是最大似然函数

在图片进行单标签分类时,【注意NLLLoss和CrossEntropyLoss都是用于单标签分类,而BCELoss和BECWithLogitsLoss都是使用与多标签分类。这里的多标签是指一个样本对应多个label.】

【DKN】(四)train.py相关推荐

  1. pytorch YoLOV3 源码解析 train.py

    train.py 总体分为三部分(不算import 库) 初始的一些设定 + train函数 + main函数 源码地址: https://github.com/ultralytics/yolov3 ...

  2. 图像分割套件PaddleSeg全面解析(一)train.py代码解读

    首先祝贺百度团队百度斩获NeurIPS2020挑战赛冠军,https://www.jiqizhixin.com/articles/2020-12-09-2. 在此次比赛中使用的是基于飞桨深度学习框架开 ...

  3. YOLOv5源码逐行超详细注释与解读(3)——训练部分train.py

    前言 本篇文章主要是对YOLOv5项目的训练部分train.py.通常这个文件主要是用来读取用户自己的数据集,加载模型并训练. 文章代码逐行手打注释,每个模块都有对应讲解,一文帮你梳理整个代码逻辑! ...

  4. 【YOLOV5-5.x 源码解读】train.py

    目录 前言 0.导入需要的包和基本配置 1.设置opt参数 2.main函数 2.1.logging和wandb初始化 2.2.判断是否使用断点续训resume, 读取参数 2.3.DDP mode设 ...

  5. YOLOV5训练代码train.py注释与解析

    YOLOv5代码注释版更新啦,注释的是最近的2021.07.14的版本,且注释更全 github: https://github.com/Laughing-q/yolov5_annotations Y ...

  6. Using TensorFlow backend. Traceback (most recent call last): File train.py, line 9, in module

    yolo程序里报错 Using TensorFlow backend. Traceback (most recent call last): File "train.py", li ...

  7. YOLOv3源码阅读之六:train.py

    一.YOLO简介   YOLO(You Only Look Once)是一个高效的目标检测算法,属于One-Stage大家族,针对于Two-Stage目标检测算法普遍存在的运算速度慢的缺点,YOLO创 ...

  8. Pointnet(part_seg)train.py,test.py代码随记

    train.py 我将代码全部简化,将关键步骤全部列出 hdf5_data_dir = 数据集路径 #读取数据集的路径创建os.mkdir(train_result) #创建train_result文 ...

  9. File “./tools/train.py“, line 124 log_file = osp.join(cfg.work_dir, f‘{timestamp}.log‘)

    问题: 在使用mmdetection做训练时出现这样的问题. File "./tools/train.py", line 124log_file = osp.join(cfg.wo ...

  10. YOLOV7跑通demo与train.py

    1. 下载yolov7(github)与testing里面的yolov7.pt(预训练权重文件)文件 2.新建项目,选中解压文件yolov7 3.新建虚拟环境: file- setting - Pro ...

最新文章

  1. 零起点学算法07——复杂一点的表达式计算
  2. PHP如何添加变量 $_SERVER
  3. lr模型和dnn模型_建立ML或DNN模型的技巧
  4. boost线程(二)
  5. mui框架下监听返回按钮
  6. requests session
  7. webrtc java api_java – 使用WebSockets实现WebRTC信令
  8. 云信api_服务端API文档-音视频通话-网易云信开发文档
  9. eXosip注册函数与使用说明
  10. 论文笔记之RL优化——高斯平滑的Q函数
  11. 四十六、Stata离散选择模型,时间序列和面板数据
  12. An动画基础之元件的影片剪辑动画与传统补间
  13. K8s问题【flannel一直重启问题,CrashLoopBackOff】
  14. 计算机如何磁盘整理,磁盘碎片整理,教您磁盘碎片怎么整理
  15. itunes下载的软件所在目录
  16. Cloudflare 远程浏览器隔离
  17. GB/T 10707 橡胶燃烧性能
  18. Continuous Integration 对 ABAP 技术栈来说意味着什么
  19. CentOS7 安装cellranger-4.0.0
  20. pymysql 向MySQL 插入数据无故报错

热门文章

  1. java jit_Java的JIT
  2. 29-高级路由:BGP清除
  3. UR机器人编译错误收集
  4. Java入门概念回炉重造
  5. PS导出灰度图到Unity内并生成地形
  6. Python爬虫以及数据可视化分析
  7. 暴雪战网下载各个版本与修改默认登陆地点方法
  8. python造数取值方法 random与faker
  9. gcc for arm 工具链使用(一)
  10. 计算机网络验证性试验