【项目实战】WaveNet 代码解析 —— train.py 【更新中】
WaveNet 代码解析 —— train.py
文章目录
- WaveNet 代码解析 —— train.py
- 简介
- 代码解析
- 全局变量解析
- 函数解析
- main()
- get_arguments()
- validate_directories(args)
- get_default_logdir(logdir_root)
- save(saver, sess, logdir, step)
- load(saver, sess, logdir)
简介
本项目是一个基于 WaveNet 生成神经网络体系结构的语音合成项目,它是使用 TensorFlow 实现的(项目地址)。
WaveNet神经网络体系结构能直接生成原始音频波形,在文本到语音和一般音频生成方面显示了出色的结果(详情请参阅 WaveNet 的详细介绍)。
由于 WaveNet 项目较大,代码较多。为了方便学习与整理,将按照工程文件的结构依次介绍。
本文将介绍项目中的 train.py 文件:基于VCTK语料库的小波网络训练脚本。
本脚本使用来自VCTK语料库的数据,用WaveNet训练网络(下载地址)
代码解析
全局变量解析
以下变量主要作为各功能参数的默认值,辅助开发人员对训练过程进行配置。
BATCH_SIZE = 1 # 一批训练集中,样本音频的数量DATA_DIRECTORY = './VCTK-Corpus' # 下载的VCTK数据集的路径LOGDIR_ROOT = './logdir' # 训练日志的路径CHECKPOINT_EVERY = 50 # 保存训练模型的检查点数量NUM_STEPS = int(1e5) # 训练的总次数LEARNING_RATE = 1e-3 # 学习率WAVENET_PARAMS = './wavenet_params.json' # WaveNet 模型的相关参数路径STARTED_DATESTRING = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now()) # 当前日期格式化SAMPLE_SIZE = 100000 # 样本数量大小L2_REGULARIZATION_STRENGTH = 0 # L2正则化中的系数SILENCE_THRESHOLD = 0.3 # 音量阈值大小EPSILON = 0.001 # 精度设置MOMENTUM = 0.9 # 优化器动量MAX_TO_KEEP = 5 # 保存的最大检查点数量METADATA = False # 高级调试信息存储标志
函数解析
main()
下面这段代码是 train.py 的主函数,主要作用是提取样本进行预处理、创建网络、训练模型、存取模型以及记录日志。
def main():# 解析命令行功能参数args = get_arguments()try:# 验证并整理与目录有关的参数directories = validate_directories(args)except ValueError as e:print("Some arguments are wrong:")print(str(e))return# 将整理好的文件路径赋给相应变量logdir = directories['logdir']restore_from = directories['restore_from']# 即使我们恢复了模型,如果训练的模型被写入到任意位置,我们也会把它当作新的训练is_overwritten_training = logdir != restore_from# 使用 josn 库的 load 函数读取 WaveNet 模型相关参数,将 json 格式的字符转换为 dictwith open(args.wavenet_params, 'r') as f:wavenet_params = json.load(f)# 创建线程协调器,多线程协调器相关知识可参考文章地址如下:# https://blog.csdn.net/weixin_42721167/article/details/112795491coord = tf.train.Coordinator()# 从VCTK数据集中加载原始波形with tf.name_scope('create_inputs'):# 允许通过指定接近零的阈值跳过静默修剪silence_threshold = args.silence_threshold if args.silence_threshold > \EPSILON else Nonegc_enabled = args.gc_channels is not None# 通用的后台音频读取器,对音频文件进行预处理并将它们排队到TensorFlow队列中reader = AudioReader(args.data_dir,coord,sample_rate=wavenet_params['sample_rate'],gc_enabled=gc_enabled,receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],wavenet_params["dilations"],wavenet_params["scalar_input"],wavenet_params["initial_filter_width"]),sample_size=args.sample_size,silence_threshold=silence_threshold)# 准备好的音频出队列audio_batch = reader.dequeue(args.batch_size)if gc_enabled:gc_id_batch = reader.dequeue_gc(args.batch_size)else:gc_id_batch = None# 创建 WaveNet 网络net = WaveNetModel(batch_size=args.batch_size,dilations=wavenet_params["dilations"],filter_width=wavenet_params["filter_width"],residual_channels=wavenet_params["residual_channels"],dilation_channels=wavenet_params["dilation_channels"],skip_channels=wavenet_params["skip_channels"],quantization_channels=wavenet_params["quantization_channels"],use_biases=wavenet_params["use_biases"],scalar_input=wavenet_params["scalar_input"],initial_filter_width=wavenet_params["initial_filter_width"],histograms=args.histograms,global_condition_channels=args.gc_channels,global_condition_cardinality=reader.gc_category_cardinality)# 验证 l2 正则化系数if args.l2_regularization_strength == 0:args.l2_regularization_strength = None# 创建一个 WaveNet 网络并返回自动编码损耗loss = net.loss(input_batch=audio_batch,global_condition_batch=gc_id_batch,l2_regularization_strength=args.l2_regularization_strength)# 创建对应的优化器optimizer = optimizer_factory[args.optimizer](learning_rate=args.learning_rate,momentum=args.momentum)# 返回使用 trainable=True 创建的所有变量trainable = tf.trainable_variables()optim = optimizer.minimize(loss, var_list=trainable)# 设置TensorBoard的日志记录writer = tf.summary.FileWriter(logdir)writer.add_graph(tf.get_default_graph())# 收集关于训练的元信息run_metadata = tf.RunMetadata()summaries = tf.summary.merge_all()# 建立会话sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))# 初始化变量init = tf.global_variables_initializer()sess.run(init)# 存储模型检查点的保护程序# 在创建这个 Saver 对象的时候, max_to_keep 参数表示要保留的最近检查点文件的最大数量,创建新文件时,将删除旧文件saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=args.max_checkpoints)try:# 恢复训练模型,获取训练步数saved_global_step = load(saver, sess, restore_from)if is_overwritten_training or saved_global_step is None:# 第一个训练步骤将是 saved_global_step + 1,因此我们在这里输入-1表示新的或覆盖的训练saved_global_step = -1except:print("Something went wrong while restoring checkpoint. ""We will terminate training to avoid accidentally overwriting ""the previous model.")raise# 开启入队线程启动器,详细介绍可参考这篇博客:# https://blog.csdn.net/weixin_42721167/article/details/112795491threads = tf.train.start_queue_runners(sess=sess, coord=coord)reader.start_threads(sess)step = Nonelast_saved_step = saved_global_steptry:# 从恢复模型的节点处开始训练for step in range(saved_global_step + 1, args.num_steps):# 获取当前时间start_time = time.time()# 当存储标志为 true 且训练次数为50的倍数时存储调试信息if args.store_metadata and step % 50 == 0:# 缓慢运行,存储额外的调试信息print('Storing metadata')# RunOptions提供配置参数,供SessionRun调用时使用run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)# 计算日志与自动编码的损失summary, loss_value, _ = sess.run([summaries, loss, optim],options=run_options,run_metadata=run_metadata)# 调用train_writer的add_summary方法将训练过程以及训练步数保存 writer.add_summary(summary, step)# 记录CPU/内存使用情况writer.add_run_metadata(run_metadata,'step_{:04d}'.format(step))# Tensorflow的Timeline模块是用于描述张量图一个工具,可以记录在会话中每个操作执行时间和资源分配及消耗的情况tl = timeline.Timeline(run_metadata.step_stats)# 加载文件路径,打开文件,写入日志timeline_path = os.path.join(logdir, 'timeline.trace')with open(timeline_path, 'w') as f:f.write(tl.generate_chrome_trace_format(show_memory=True))else:# 在不保存模型的训练步数里,保存训练日志到 Tensorboardsummary, loss_value, _ = sess.run([summaries, loss, optim])writer.add_summary(summary, step)# 计算并打印训练一次的时间与结果duration = time.time() - start_timeprint('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(step, loss_value, duration))# 每隔输入的检查点间隔存储一次训练模型if step % args.checkpoint_every == 0:save(saver, sess, logdir, step)last_saved_step = stepexcept KeyboardInterrupt:# 在 ctrl+C 显示之后引入一个换行符,这样保存消息就在它自己的行上了print()finally:# 若训练到了更多步if step > last_saved_step:save(saver, sess, logdir, step)coord.request_stop()coord.join(threads)
get_arguments()
下面这段代码主要是获取命令行参数。
运用 python 中的 argparse 模块对我们输入的命令行进行解析。
def get_arguments():def _str_to_bool(s):""" 将string转换为bool """""" 传入的字符串被限制为'true'或'false' """if s.lower() not in ['true', 'false']:raise ValueError('Argument needs to be a ''boolean, got {}'.format(s))return {'true': True, 'false': False}[s.lower()]# 创建解析器,解析的功能参数作为 WaveNet 的实例parser = argparse.ArgumentParser(description='WaveNet example network')# 添加可选功能参数: --batch_size; 该参数含义为: 一次要处理的 wav 文件数量parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,help='How many wav files to process at once. Default: ' + str(BATCH_SIZE) + '.')# 添加可选功能参数: --data_dir; 该参数含义为: VCTK数据集的文件路径parser.add_argument('--data_dir', type=str, default=DATA_DIRECTORY,help='The directory containing the VCTK corpus.')# 添加可选功能参数: --store_metadata; 该参数含义为: 高级调试信息存储标志parser.add_argument('--store_metadata', type=bool, default=METADATA,help='Whether to store advanced debugging information ''(execution time, memory consumption) for use with ''TensorBoard. Default: ' + str(METADATA) + '.')# 添加可选功能参数: --logdir; 该参数含义为: 存储 TensorBoard 日志信息的文件路径;# 需要注意: 该参数不能与'--logdir_root'或'--restore_from'一起使用parser.add_argument('--logdir', type=str, default=None,help='Directory in which to store the logging ''information for TensorBoard. ''If the model already exists, it will restore ''the state and will continue training. ''Cannot use with --logdir_root and --restore_from.')# 添加可选功能参数: --logdir_root; 该参数含义为: 放置日志输出和生成模型的文件路径,存放在带有日期的子目录下# 需要注意: 该参数不能与'--logdir'一起使用parser.add_argument('--logdir_root', type=str, default=None,help='Root directory to place the logging ''output and generated model. These are stored ''under the dated subdirectory of --logdir_root. ''Cannot use with --logdir.')# 添加可选功能参数: --restore_from; 该参数含义为: 恢复模型的目录,能创建带有日期的子目录# 需要注意: 该参数不能与'--logdir'一起使用parser.add_argument('--restore_from', type=str, default=None,help='Directory in which to restore the model from. ''This creates the new model under the dated directory ''in --logdir_root. ''Cannot use with --logdir.')# 添加可选功能参数: --checkpoint_every; 该参数含义为: 存放训练模型的检查点间隔parser.add_argument('--checkpoint_every', type=int,default=CHECKPOINT_EVERY,help='How many steps to save each checkpoint after. Default: ' + str(CHECKPOINT_EVERY) + '.')# 添加可选功能参数: --num_steps; 该参数含义为: 训练的次数parser.add_argument('--num_steps', type=int, default=NUM_STEPS,help='Number of training steps. Default: ' + str(NUM_STEPS) + '.')# 添加可选功能参数: --learning_rate; 该参数含义为: 训练的学习率parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE,help='Learning rate for training. Default: ' + str(LEARNING_RATE) + '.')# 添加可选功能参数: --wavenet_params; 该参数含义为: WaveNet 模型的相关参数parser.add_argument('--wavenet_params', type=str, default=WAVENET_PARAMS,help='JSON file with the network parameters. Default: ' + WAVENET_PARAMS + '.')# 添加可选功能参数: --sample_size; 该参数含义为: 使用的样本数量parser.add_argument('--sample_size', type=int, default=SAMPLE_SIZE,help='Concatenate and cut audio samples to this many ''samples. Default: ' + str(SAMPLE_SIZE) + '.')# 添加可选功能参数: --l2_regularization_strength; 该参数含义为: L2正则化的系数parser.add_argument('--l2_regularization_strength', type=float,default=L2_REGULARIZATION_STRENGTH,help='Coefficient in the L2 regularization. ''Default: False')# 添加可选功能参数: --silence_threshold; 该参数含义为: 音量阈值限制parser.add_argument('--silence_threshold', type=float,default=SILENCE_THRESHOLD,help='Volume threshold below which to trim the start ''and the end from the training set samples. Default: ' + str(SILENCE_THRESHOLD) + '.')# 添加可选功能参数: --optimizer; 该参数含义为: 优化器选择parser.add_argument('--optimizer', type=str, default='adam',choices=optimizer_factory.keys(),help='Select the optimizer specified by this option. Default: adam.')# 添加可选功能参数: --momentum; 该参数含义为: 优化器动量大小parser.add_argument('--momentum', type=float,default=MOMENTUM, help='Specify the momentum to be ''used by sgd or rmsprop optimizer. Ignored by the ''adam optimizer. Default: ' + str(MOMENTUM) + '.')# 添加可选功能参数: --histograms; 该参数含义为: 直方图汇总存储标志parser.add_argument('--histograms', type=_str_to_bool, default=False,help='Whether to store histogram summaries. Default: False')# 添加可选功能参数: --gc_channels; 该参数含义为: 全局条件通道数量parser.add_argument('--gc_channels', type=int, default=None,help='Number of global condition channels. Default: None. Expecting: Int')# 添加可选功能参数: --max_checkpoints; 该参数含义为: 最大训练模型保存检查点数parser.add_argument('--max_checkpoints', type=int, default=MAX_TO_KEEP,help='Maximum amount of checkpoints that will be kept alive. Default: '+ str(MAX_TO_KEEP) + '.')# 把parser中设置的所有"add_argument"给返回到args子类实例中并返回return parser.parse_args()
validate_directories(args)
下面这段代码主要工作是:验证当前的几个目录是否冲突,将输入的目录参数规范化。
def validate_directories(args):""" 验证和整理与目录相关的参数 """# 验证接断# logdir 与 logdir_root 参数不能同时存在if args.logdir and args.logdir_root:raise ValueError("--logdir and --logdir_root cannot be ""specified at the same time.")# logdir 与 restore_from 参数不能同时存在if args.logdir and args.restore_from:raise ValueError("--logdir and --restore_from cannot be specified at the same ""time. This is to keep your previous model from unexpected ""overwrites.\n""Use --logdir_root to specify the root of the directory which ""will be automatically created with current date and time, or use ""only --logdir to just continue the training from the last ""checkpoint.")# 整理阶段# 为 logdir_root 参数赋予给定的值或是默认值logdir_root = args.logdir_rootif logdir_root is None:logdir_root = LOGDIR_ROOT# 为 logdir 参数赋予给定的值或是 logdir_root 参数的默认值logdir = args.logdirif logdir is None:logdir = get_default_logdir(logdir_root)print('Using default logdir: {}'.format(logdir))# 为 restore_from 参数赋予给定的值或是 logdir 参数的值restore_from = args.restore_fromif restore_from is None:# args.logdir and args.restore_from are exclusive,# so it is guaranteed the logdir here is newly created.restore_from = logdir# 将验证并整理好的目录参数打包返回return {'logdir': logdir,'logdir_root': args.logdir_root,'restore_from': restore_from}
get_default_logdir(logdir_root)
下面这段代码主要工作是:在给定的日志目录下,创建训练文件夹,再创建以带有当前日期时间的文件路径,并将该路径返回
def get_default_logdir(logdir_root):# 使用路径拼接函数 os.path.join() 在给定的目录下创建'train'目录# 进而创建以当前日期时间为名的子目录,格式为:{0:%Y-%m-%dT%H-%M-%S}logdir = os.path.join(logdir_root, 'train', STARTED_DATESTRING)return logdir
save(saver, sess, logdir, step)
这段代码主要工作是:将给定的训练结果、会话以及检查点保存到指定的文件路径下
def save(saver, sess, logdir, step):# 设置保存的模型文件名,将文件路径进行拼接model_name = 'model.ckpt'checkpoint_path = os.path.join(logdir, model_name)print('Storing checkpoint to {} ...'.format(logdir), end="")# 刷新缓冲区,保证正常输出sys.stdout.flush()# 若文件不存在则先创造文件if not os.path.exists(logdir):os.makedirs(logdir)# 保存模型saver.save(sess, checkpoint_path, global_step=step)print(' Done.')
load(saver, sess, logdir)
这段代码主要工作是:将指定路径下的模型训练结果恢复到当前会话
def load(saver, sess, logdir):print("Trying to restore saved checkpoints from {} ...".format(logdir),end="")# 从指定路径下返回训练模型以及检查点ckpt = tf.train.get_checkpoint_state(logdir)if ckpt:print(" Checkpoint found: {}".format(ckpt.model_checkpoint_path))# 找到模型,获取检查点global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])print(" Global step was: {}".format(global_step))print(" Restoring...", end="")# 恢复最新检查点训练情况saver.restore(sess, ckpt.model_checkpoint_path)print(" Done.")# 返回检查点return global_stepelse:# 未找到模型,返回空值print(" No checkpoint found.")return None
本文还在持续更新中!
欢迎各位大佬交流讨论!
【项目实战】WaveNet 代码解析 —— train.py 【更新中】相关推荐
- Baidu Apollo代码解析之EM Planner中的QP Speed Optimizer 1
大家好,我已经把CSDN上的博客迁移到了知乎上,欢迎大家在知乎关注我的专栏慢慢悠悠小马车(https://zhuanlan.zhihu.com/duangduangduang).希望大家可以多多交流, ...
- ECharts数据可视化项目-大屏数据可视化【持续更新中】
ECharts数据可视化项目-大屏数据可视化[持续更新中] 文章目录 ECharts数据可视化项目-大屏数据可视化[持续更新中] 一. 数据可视化ECharts使用 二.技术栈 三.数据可视化 四.可 ...
- 【Android 逆向】使用 Python 代码解析 ELF 文件 ( PyCharm 中进行断点调试 | ELFFile 实例对象分析 )
文章目录 一.PyCharm 中进行断点调试 二.ELFFile 实例对象分析 一.PyCharm 中进行断点调试 在上一篇博客 [Android 逆向]使用 Python 代码解析 ELF 文件 ( ...
- pytorch YoLOV3 源码解析 train.py
train.py 总体分为三部分(不算import 库) 初始的一些设定 + train函数 + main函数 源码地址: https://github.com/ultralytics/yolov3 ...
- 连连看项目实战之三(解析配置表)
推荐阅读: 我的CSDN 我的博客园 QQ群:704621321 这是一款连连看,如果只能连连,那估计大家看看就没兴趣了.如果在里面加上一些场景式对话,那么可能就会有意思许多. 今天就带大家学习Uni ...
- 爱咖喱炸鸡店订餐系统--代码超细节解析版(更新中)
爱咖喱炸鸡店订餐系统 (一)需求说明 (二)操作步骤 Ⅰ.项目功能分析 Ⅱ.搭建基本框架 Ⅲ依次完善各个条件 Ⅳ运行,检查,是否符合业务 (三)业务描述 (一)需求说明 现今已进入网络时代,网上购物. ...
- 手游开发神器 cocos2d-x editor 教程聚合和代码下载(持续更新中)
--------------游戏基础教程篇-------------已完成--------- 一 cocos2d-x editor工具下载和基础教程JS篇: 一 手游开发神器 cocos2d-x ed ...
- DeamNet||训练代码学习train.py注释与解析
目录 1. 导入各种库,设置运行环境 2. 训练设置,各种参数 3.测试 1. 导入各种库,设置运行环境 from __future__ import print_function import os ...
- 机器学习完整项目实战附代码(二):探索型数据分析+特征工程+建模+报告
1. 项目背景: 1.1 项目目标: 使用提供的波士顿房屋租赁价格数据开发一个模型,该模型可以预测房屋租赁价格, 然后解释结果以找到最能预测的变量. 这是一个受监督的回归机器学习任务:给定一组包含目标 ...
最新文章
- 基于SSH实现的农家乐管理系统
- 数据结构的堆栈与内存中堆栈的区别
- PHP实现菱形与杨辉三角形【php趣味案例】
- 李沐老师的PyTorch 版《动手学深度学习》PDF 开源了(全中文,支持 Jupyter 运行)
- Acrobat Pro DC 教程,如何将文件合并为 PDF?
- 多种方法让网络共享资源自动映射
- CSDNITeye招贤榜
- 华为智能体发布,智能联接火了
- GetTickCount64的使用
- 公司与公司保密协议范本
- FastStone Capture7.0注册码
- 皖能合肥电厂电能量计量管理系统设计方案
- 正离子计算机扫描检测,扫描电子显微镜
- 颜值评分,图像识别,植物、动物、车型、菜品、logo识别
- Matlab多项式基本运算(1)( polyval和polyvalm的区别)
- MySQL学习第三弹——约束与多表查询详解
- 刀塔自走棋上线不到十分钟就被功击,几十万玩家登录不上
- GUTI,Globally Unique Temporary UE Identity,全球唯一临时UE标识。
- jQuery mouseover与mouseenter,mouseout与mouseleave的区别
- 异地恋的自愈系小故事:企鹅先生和北极熊小姐