1、代码模块书写规范

1.1、导入包

导入包需要注意分段

import os
os.environ['MKL_NUM_THREADS'] = '1'from functools import partial
import random
import wandb
import sys
import collections# Local imports
from data_loaders.data_manager import DataManager
from utils.utils import *

注:os.environ[‘MKL_NUM_THREADS’] = ‘1’
pytorch以及tensorflow的多线程输入设定过大,一般推荐较大数据流4线程,较小2线程。具体问题具体分析,要看数据输入是否是训练速度优化的瓶颈。
numpy或者opencv等的多线程操作或者tensorflow以及pytorch在cpu运行上的op。这些模块使用OMP或者MKL进行多线程加速,一般默认为cpu线程总数的一半,十分浪费计算力,推荐使用4线程,详见下表。

1.2、简化config

通过合并多个config,使用startswith识别可以减少config个数。

config['MODEL_NAME'].lower().startswith('stare'):

1.3、@staticmethod

常用在数据读取部分,可以省略调用前初始化类的过程。

class DataManager(object):""" Give me your args I'll give you a path to load the dataset with my superawesome AI """@staticmethoddef load(config: Union[dict, FancyDict]) -> Callable:

1.4、函数声明规范

函数声明部分一般需要包含Decisions函数简介、:param *:参数介绍、:return:返回数据介绍。其中参数介绍与返回数据介绍部分均需说明数据类型。

    @staticmethoddef get_alternative_graph_repr(raw: Union[List[List[int]], np.ndarray], config: dict) \-> Dict[str, np.ndarray]:"""Decisions:Quals are represented differently here, i.e., more as a coo matrixs1 p1 o1 qr1 qe1 qr2 qe2    [edge index column 0]s2 p2 o2 qr3 qe3            [edge index column 1]edge index:[ [s1, s2],[o1, o2] ]edge type:[ p1, p2 ]quals will looks like[ [qr1, qr2, qr3],[qe1, qr2, qe3],[0  , 0  , 1  ]       <- obtained from the edge index columns:param raw: [[s, p, o, qr1, qe1, qr2, qe3...], ..., [...]](already have a max qualifier length padded data):param config: the config dict:return: output dict"""

1.5、代码模块声明

在代码中常需要在下一个新的模块前声明新模块功能。

    """Make the model."""

1.6、传入参数设置

import argparse
def parse_config():parser = argparse.ArgumentParser()parser.add_argument('--max_len', type=int, default=128)parser.add_argument('--ckpt_path', type=str)parser.add_argument('--test_data',type=str)parser.add_argument('--out_path',type=str)parser.add_argument('--gpu_id',type=int, default=0)return parser.parse_args()if __name__ == '__main__':args = parse_config()ckpt_path = args.ckpt_pathtest_data = args.test_dataout_path = args.out_pathgpu_id = args.gpu_id

2、pytorch常用函数熟悉规范

2.1、声明优化器

    if config['OPTIMIZER'] == 'sgd':optimizer = torch.optim.SGD(model.parameters(), lr=config['LEARNING_RATE'])elif config['OPTIMIZER'] == 'adam':optimizer = torch.optim.Adam(model.parameters(), lr=config['LEARNING_RATE'])

1.6、train规范

训练部分规范,使用tqdm显示训练进度,使用torch.nn.utils.clip_grad_norm_预防梯度爆照

from tqdm.autonotebook import tqdm
def training_loop_gcn(epochs: int):train_loss = []train_acc = []for e in range(epochs):per_epoch_loss = []per_epoch_tr_acc = []with Timer() as timer:trn_dl = data_fn(data['train'])train_fn.train()for batch in tqdm(trn_dl, desc='Training'):opt.zero_grad()triples, labels = batchsub, rel = triples[:, 0], triples[:, 1]if qualifier_aware:quals = triples[:, 2:]_quals = torch.tensor(quals, dtype=torch.long, device=device)_sub = torch.tensor(sub, dtype=torch.long, device=device)_rel = torch.tensor(rel, dtype=torch.long, device=device)_labels = torch.tensor(labels, dtype=torch.float, device=device)pred = train_fn(_sub, _rel, _quals)loss = train_fn.loss(pred, _labels)per_epoch_loss.append(loss.item())loss.backward()if grad_clipping:torch.nn.utils.clip_grad_norm_(train_fn.parameters(), 1.0)opt.step()print(f"[Epoch: {e} ] Loss: {np.mean(per_epoch_loss)}")train_loss.append(np.mean(per_epoch_loss))train_loss.append(np.mean(per_epoch_loss))if e % eval_every == 0 and e >= 1:with torch.no_grad():summary_val = val_testbench()per_epoch_vl_acc = summary_val['metrics']['hits_at 1']

2.2、model书写规范

2.2.1、model包装规范

在model部分需要注意做好包装,做好类的继承。

class StarE_Transformer(StarEEncoder):def __init__(self, kg_graph_repr: Dict[str, np.ndarray], config: dict, id2e: tuple = None):if id2e is not None:super(self.__class__, self).__init__(kg_graph_repr, config, id2e[1])else:super(self.__class__, self).__init__(kg_graph_repr, config)self.model_name = 'StarE_Transformer_Statement'self.hid_drop2 = config['STAREARGS']['HID_DROP2']

2.2.2、model常用函数

2.2.2.1、torch.view()

对torch进行重新整型。

rel_embed = rel_embed.view(-1, 1, self.emb_dim)

2.2.2.2、torch.transpose()

交换一个tensor的两个维度。
注:transpose()一次只能在两个维度间进行转置。

#先转置0维和1维,之后在第2,3维间转置,之后在第1,3间转置
y=x.transpose(0,1).transpose(3,2).transpose(1,3)

2.2.2.3、torch.reshape()

变换张量tensor的形状,注意两个数据类型都是张量。
注:reshape是按照行来进行reshape(变形)的。

c=torch.randn((2,5))
# tensor([[ 1.0559, -0.3533,  0.5194,  0.9526, -0.2483],
#         [-0.1293,  0.4809, -0.5268, -0.3673,  0.0666]])
d=torch.reshape(c,(5,2))
# tensor([[ 1.0559, -0.3533],
#         [ 0.5194,  0.9526],
#         [-0.2483, -0.1293],
#         [ 0.4809, -0.5268],
#         [-0.3673,  0.0666]])

2.2.2.4、torch.mean(x,dim=0,keepdim=True)

在指定纬度进行求均值。

x=torch.arange(15).view(5,3)
#   0   1   2
#   3   4   5
#   6   7   8
#   9  10  11
#  12  13  14
x_mean=torch.mean(x,dim=0,keepdim=True)
#   6  7  8
x_mean0=torch.mean(x,dim=1,keepdim=True)
#   1
#   4
#   7
#  10
#  13

2.2.2.5、torch.min(x,dim=0,keepdim=True)

在指定纬度进行取最小值。

2.2.2.6、torch.mm与torch.mul

torch.mul(a, b)是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵;
torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵。

2.2.2.7、torch.sigmoid(x)与torch.nn.Sigmoid(x)

torch.sigmoid(x)是一个函数,torch.nn.Sigmoid(x)是一个类,需要在init进行声明。

2.2.2.8、torch.nn.Linear(self.emb_dim, self.emb_dim)

声明一个线性变换层

self.fc = torch.nn.Linear(self.emb_dim, self.emb_dim)

2.2.2.9、torch.nn.Dropout(self.hid_drop)

声明一个dropout

self.hidden_drop = torch.nn.Dropout(self.hid_drop)

2.2.2.10、nn.BatchNorm1d(num_features)

对小批量(mini-batch)的2d或3d输入进行批标准化(Batch Normalization)操作
来自期望输入的特征数,该期望输入的大小为’batch_size x num_features [x width]’

nn.BatchNorm1d(num_features)

对小批量(mini-batch)3d数据组成的4d输入进行批标准化(Batch Normalization)操作
来自期望输入的特征数,该期望输入的大小为’batch_size x num_features x height x width’

nn.BatchNorm2d(num_features)

对小批量(mini-batch)4d数据组成的5d输入进行批标准化(Batch Normalization)操作
来自期望输入的特征数,该期望输入的大小为’batch_size x num_features depth x height x width’

nn.BatchNorm3d(num_features)

2.2.2.11、激活层

import torch.nn.functional as F
x = F.relu(x)

3、bert(torch)

from google_bert import BasicTokenizerdef extract_parameters(ckpt_path):model_ckpt = torch.load(ckpt_path)bert_args = model_ckpt['bert_args']model_args = model_ckpt['args']bert_vocab = model_ckpt['bert_vocab']model_parameters = model_ckpt['model']tree_args = model_ckpt['tree_args']tree_vocab = model_ckpt['tree_vocab']return bert_args, model_args, bert_vocab, model_parameters, tree_args, tree_vocabdef init_empty_bert_model(bert_args, bert_vocab, gpu_id, approx = 'none'):bert_model = BERTLM(gpu_id, bert_vocab, bert_args.embed_dim, bert_args.ff_embed_dim, bert_args.num_heads, \bert_args.dropout, bert_args.layers, approx)return bert_modeldef init_empty_tree_model(t_args, tree_vocab, gpuid):tree_model = TreeLSTM(tree_vocab.size(), t_args.input_dim, t_args.mem_dim, t_args.hidden_dim, t_args.num_classes, t_args.freeze_embed)tree_model = tree_model.cuda(gpuid)return tree_modeldef init_sequence_classification_model(empty_bert_model, args, bert_args, gpu_id, bert_vocab, model_parameters, empty_tree_model, tree_args):number_class = args.number_classnumber_category = 3embedding_size = bert_args.embed_dimbatch_size = args.batch_sizedropout = args.dropouttree_hidden_dim = tree_args.hidden_dimdevice = gpu_idvocab = bert_vocabseq_tagging_model = myModel(empty_bert_model, number_class, number_category, embedding_size, batch_size, dropout, device, vocab, empty_tree_model, tree_hidden_dim)seq_tagging_model.load_state_dict(model_parameters)return seq_tagging_modelbert_args, model_args, bert_vocab, model_parameters, tree_args, tree_vocab = extract_parameters(ckpt_path)
empty_bert_model = init_empty_bert_model(bert_args, bert_vocab, gpu_id, approx='none')
empty_tree_model = init_empty_tree_model(tree_args, tree_vocab, gpu_id)
seq_classification_model = init_sequence_classification_model(empty_bert_model, model_args, bert_args, gpu_id, bert_vocab, model_parameters, empty_tree_model, tree_args)
seq_classification_model.cuda(gpu_id)tokenizer = BasicTokenizer()

【DeepLearning笔记】python规范书写相关推荐

  1. PEP8 python规范神器和jupyter notebook主题更改--Jupyter Notebook 快速入门

    PEP8 python规范神器和jupyter notebook主题更改--Jupyter Notebook 快速入门 原文: https://www.cnblogs.com/xxtalhr/p/10 ...

  2. LEETCODE-刷题个人笔记 Python(1-400)

    按tag分类,250/400的重点题目 LEETCODE-刷题个人笔记 Python(1-400)-TAG标签版本 1.Two Sum(easy) 给定一个整型数组,找出能相加起来等于一个特定目标数字 ...

  3. LEETCODE-刷题个人笔记 Python(1-400)-TAG标签版本(二)

    前面一篇由于文字太多,不给编辑,遂此篇出炉 LEETCODE-刷题个人笔记 Python(1-400)-TAG标签版本(一) DFS&BFS (262)200. Number of Islan ...

  4. scrapy笔记——python的时间转换

    1 import datetime 2 GMT_FORMAT = '%M %H %d %m %w' 3 datetime.datetime.utcnow().strftime(GMT_FORMAT) ...

  5. 机器学习实战笔记(Python实现)-04-Logistic回归

    转自:机器学习实战笔记(Python实现)-04-Logistic回归 转自:简单多元线性回归(梯度下降算法与矩阵法) 转自:人工神经网络(从原理到代码) Step 01 感知器 梯度下降

  6. Google代码规范书写格式,告别丑陋代码

    今天在做pat的题时,遇到了困难,让学长给我看看代码,学长看了我的代码,哭笑着说,你的代码风格一直是这样吗,哈哈哈.他当时推荐我去学习一下Google代码的书写方式,看了后,感觉自己代码写的真是太丑了 ...

  7. Python笔记 - Python切片

    Python笔记 - Python切片 Python切片是对一个列表取其部分元素获得一个子序列的常见操作,切片操作的返回结果类型与被切片的对象一致.要创建一个已有列表的切片,通过指定切片的第一个列表元 ...

  8. python语言的33个保留字的基本含义_Python学习笔记——Python的33个保留字及其意义,python,pythone33,含义...

    Python学习笔记--Python的33个保留字及其意义,python,pythone33,含义 发表时间:2020-03-27 笔记走起 正文 序号 保留字 含义 1 and 用于表达式运算,逻辑 ...

  9. 小学生如何用计算机写字,巧用信息技术培养小学生规范书写汉字的能力

    一.借助信息技术,增强学生写好汉字的意识 作为小学语文教师,我注重通过思想教育增强学生写好汉字的意识.我利用互联网,精心选择了一些图片和资料,通过多媒体平台展示给学生看,并对学生说:"汉字是 ...

最新文章

  1. 张杰和机器人_《80后脱口秀》吐槽高考 张杰化身“机器人”
  2. C++ leetcode 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用 O(1) 额外
  3. linux配置redis服务,记一次linux下安装redis, 设置redis服务, 及添加环境变量
  4. 2015,工作几年的心得
  5. webpack4.0各个击破(5)—— Module篇
  6. OpenFile基于浏览器的免费网络存储管理
  7. 使用ef查询有缓存的问题
  8. GO语言学习之路23
  9. python训练mask rcnn模型C++调用训练好的模型--基于opencv4.0(干货满满)
  10. BZOJ4072[Wf2014] baggage
  11. TakeColor鼠标位置不对/取色不准
  12. mysql数据库编程第六章试题_2016年计算机二级MySQL数据库试题及答案
  13. 巴厘岛游记:风吹又日晒,自由又自在
  14. 全球及中国26二氟苯磺酰氯行业发展状况与前景趋势分析报告2022-2028年
  15. Win10安装fliqlo时钟屏保教程
  16. 怎样设置一个函数C语言,C语言中怎样编写一个函数 如何在C语言中定义一个函数?...
  17. 最新ChatGPT商业运营版网站源码+支持AI绘画+支持用户会员套餐+邀请分佣功能+支持后台一键更新+网站后台管理+永久更新!
  18. C++switch语句详解
  19. Linux-如何查看进程和关闭进程
  20. java 处理word,excel,pdf -javacode

热门文章

  1. 深度学习mindspore --- win10系统cpu下安装mindspore
  2. 毁掉一个孩子的几个方法 有多少家长正在这么做?
  3. ppt如何查看加载宏
  4. Unity实现瞄准镜效果
  5. Unity 瞄准镜实现
  6. Unity 汉诺塔Hannota笔记
  7. 《数学建模》知识点总结
  8. 神经网络自动布局技术,神经网络自动布局原理
  9. CSDN日报20170727——《想提高团队技术,来试试这个套路!》
  10. 把uTorrent做成绿色版