WaveNet 代码解析 —— train.py

文章目录

  • WaveNet 代码解析 —— train.py
    • 简介
    • 代码解析
      • 全局变量解析
    • 函数解析
      • main()
      • get_arguments()
      • validate_directories(args)
      • get_default_logdir(logdir_root)
      • save(saver, sess, logdir, step)
      • load(saver, sess, logdir)

简介

本项目是一个基于 WaveNet 生成神经网络体系结构的语音合成项目,它是使用 TensorFlow 实现的(项目地址)。
       
       WaveNet神经网络体系结构能直接生成原始音频波形,在文本到语音和一般音频生成方面显示了出色的结果(详情请参阅 WaveNet 的详细介绍)。
       
       由于 WaveNet 项目较大,代码较多。为了方便学习与整理,将按照工程文件的结构依次介绍。
       
       本文将介绍项目中的 train.py 文件:基于VCTK语料库的小波网络训练脚本。
       
       本脚本使用来自VCTK语料库的数据,用WaveNet训练网络(下载地址)

代码解析

全局变量解析

以下变量主要作为各功能参数的默认值,辅助开发人员对训练过程进行配置。

     BATCH_SIZE = 1                             # 一批训练集中,样本音频的数量DATA_DIRECTORY = './VCTK-Corpus'          # 下载的VCTK数据集的路径LOGDIR_ROOT = './logdir'                  # 训练日志的路径CHECKPOINT_EVERY = 50                     # 保存训练模型的检查点数量NUM_STEPS = int(1e5)                     # 训练的总次数LEARNING_RATE = 1e-3                       # 学习率WAVENET_PARAMS = './wavenet_params.json'    # WaveNet 模型的相关参数路径STARTED_DATESTRING = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())               # 当前日期格式化SAMPLE_SIZE = 100000                      # 样本数量大小L2_REGULARIZATION_STRENGTH = 0             # L2正则化中的系数SILENCE_THRESHOLD = 0.3                     # 音量阈值大小EPSILON = 0.001                                # 精度设置MOMENTUM = 0.9                               # 优化器动量MAX_TO_KEEP = 5                             # 保存的最大检查点数量METADATA = False                           # 高级调试信息存储标志

函数解析

main()

下面这段代码是 train.py 的主函数,主要作用是提取样本进行预处理、创建网络、训练模型、存取模型以及记录日志。

 def main():# 解析命令行功能参数args = get_arguments()try:# 验证并整理与目录有关的参数directories = validate_directories(args)except ValueError as e:print("Some arguments are wrong:")print(str(e))return# 将整理好的文件路径赋给相应变量logdir = directories['logdir']restore_from = directories['restore_from']# 即使我们恢复了模型,如果训练的模型被写入到任意位置,我们也会把它当作新的训练is_overwritten_training = logdir != restore_from# 使用 josn 库的 load 函数读取 WaveNet 模型相关参数,将 json 格式的字符转换为 dictwith open(args.wavenet_params, 'r') as f:wavenet_params = json.load(f)# 创建线程协调器,多线程协调器相关知识可参考文章地址如下:# https://blog.csdn.net/weixin_42721167/article/details/112795491coord = tf.train.Coordinator()# 从VCTK数据集中加载原始波形with tf.name_scope('create_inputs'):# 允许通过指定接近零的阈值跳过静默修剪silence_threshold = args.silence_threshold if args.silence_threshold > \EPSILON else Nonegc_enabled = args.gc_channels is not None# 通用的后台音频读取器,对音频文件进行预处理并将它们排队到TensorFlow队列中reader = AudioReader(args.data_dir,coord,sample_rate=wavenet_params['sample_rate'],gc_enabled=gc_enabled,receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],wavenet_params["dilations"],wavenet_params["scalar_input"],wavenet_params["initial_filter_width"]),sample_size=args.sample_size,silence_threshold=silence_threshold)# 准备好的音频出队列audio_batch = reader.dequeue(args.batch_size)if gc_enabled:gc_id_batch = reader.dequeue_gc(args.batch_size)else:gc_id_batch = None# 创建 WaveNet 网络net = WaveNetModel(batch_size=args.batch_size,dilations=wavenet_params["dilations"],filter_width=wavenet_params["filter_width"],residual_channels=wavenet_params["residual_channels"],dilation_channels=wavenet_params["dilation_channels"],skip_channels=wavenet_params["skip_channels"],quantization_channels=wavenet_params["quantization_channels"],use_biases=wavenet_params["use_biases"],scalar_input=wavenet_params["scalar_input"],initial_filter_width=wavenet_params["initial_filter_width"],histograms=args.histograms,global_condition_channels=args.gc_channels,global_condition_cardinality=reader.gc_category_cardinality)# 验证 l2 正则化系数if args.l2_regularization_strength == 0:args.l2_regularization_strength = None# 创建一个 WaveNet 网络并返回自动编码损耗loss = net.loss(input_batch=audio_batch,global_condition_batch=gc_id_batch,l2_regularization_strength=args.l2_regularization_strength)# 创建对应的优化器optimizer = optimizer_factory[args.optimizer](learning_rate=args.learning_rate,momentum=args.momentum)# 返回使用 trainable=True 创建的所有变量trainable = tf.trainable_variables()optim = optimizer.minimize(loss, var_list=trainable)# 设置TensorBoard的日志记录writer = tf.summary.FileWriter(logdir)writer.add_graph(tf.get_default_graph())# 收集关于训练的元信息run_metadata = tf.RunMetadata()summaries = tf.summary.merge_all()# 建立会话sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))# 初始化变量init = tf.global_variables_initializer()sess.run(init)# 存储模型检查点的保护程序# 在创建这个 Saver 对象的时候, max_to_keep 参数表示要保留的最近检查点文件的最大数量,创建新文件时,将删除旧文件saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=args.max_checkpoints)try:# 恢复训练模型,获取训练步数saved_global_step = load(saver, sess, restore_from)if is_overwritten_training or saved_global_step is None:# 第一个训练步骤将是 saved_global_step + 1,因此我们在这里输入-1表示新的或覆盖的训练saved_global_step = -1except:print("Something went wrong while restoring checkpoint. ""We will terminate training to avoid accidentally overwriting ""the previous model.")raise# 开启入队线程启动器,详细介绍可参考这篇博客:# https://blog.csdn.net/weixin_42721167/article/details/112795491threads = tf.train.start_queue_runners(sess=sess, coord=coord)reader.start_threads(sess)step = Nonelast_saved_step = saved_global_steptry:# 从恢复模型的节点处开始训练for step in range(saved_global_step + 1, args.num_steps):# 获取当前时间start_time = time.time()# 当存储标志为 true 且训练次数为50的倍数时存储调试信息if args.store_metadata and step % 50 == 0:# 缓慢运行,存储额外的调试信息print('Storing metadata')# RunOptions提供配置参数,供SessionRun调用时使用run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)# 计算日志与自动编码的损失summary, loss_value, _ = sess.run([summaries, loss, optim],options=run_options,run_metadata=run_metadata)# 调用train_writer的add_summary方法将训练过程以及训练步数保存 writer.add_summary(summary, step)# 记录CPU/内存使用情况writer.add_run_metadata(run_metadata,'step_{:04d}'.format(step))# Tensorflow的Timeline模块是用于描述张量图一个工具,可以记录在会话中每个操作执行时间和资源分配及消耗的情况tl = timeline.Timeline(run_metadata.step_stats)# 加载文件路径,打开文件,写入日志timeline_path = os.path.join(logdir, 'timeline.trace')with open(timeline_path, 'w') as f:f.write(tl.generate_chrome_trace_format(show_memory=True))else:# 在不保存模型的训练步数里,保存训练日志到 Tensorboardsummary, loss_value, _ = sess.run([summaries, loss, optim])writer.add_summary(summary, step)# 计算并打印训练一次的时间与结果duration = time.time() - start_timeprint('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(step, loss_value, duration))# 每隔输入的检查点间隔存储一次训练模型if step % args.checkpoint_every == 0:save(saver, sess, logdir, step)last_saved_step = stepexcept KeyboardInterrupt:# 在 ctrl+C 显示之后引入一个换行符,这样保存消息就在它自己的行上了print()finally:# 若训练到了更多步if step > last_saved_step:save(saver, sess, logdir, step)coord.request_stop()coord.join(threads)

get_arguments()

下面这段代码主要是获取命令行参数。
        运用 python 中的 argparse 模块对我们输入的命令行进行解析。

 def get_arguments():def _str_to_bool(s):""" 将string转换为bool """""" 传入的字符串被限制为'true'或'false' """if s.lower() not in ['true', 'false']:raise ValueError('Argument needs to be a ''boolean, got {}'.format(s))return {'true': True, 'false': False}[s.lower()]# 创建解析器,解析的功能参数作为 WaveNet 的实例parser = argparse.ArgumentParser(description='WaveNet example network')# 添加可选功能参数: --batch_size; 该参数含义为: 一次要处理的 wav 文件数量parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,help='How many wav files to process at once. Default: ' + str(BATCH_SIZE) + '.')# 添加可选功能参数: --data_dir; 该参数含义为: VCTK数据集的文件路径parser.add_argument('--data_dir', type=str, default=DATA_DIRECTORY,help='The directory containing the VCTK corpus.')# 添加可选功能参数: --store_metadata; 该参数含义为: 高级调试信息存储标志parser.add_argument('--store_metadata', type=bool, default=METADATA,help='Whether to store advanced debugging information ''(execution time, memory consumption) for use with ''TensorBoard. Default: ' + str(METADATA) + '.')# 添加可选功能参数: --logdir; 该参数含义为: 存储 TensorBoard 日志信息的文件路径;# 需要注意: 该参数不能与'--logdir_root'或'--restore_from'一起使用parser.add_argument('--logdir', type=str, default=None,help='Directory in which to store the logging ''information for TensorBoard. ''If the model already exists, it will restore ''the state and will continue training. ''Cannot use with --logdir_root and --restore_from.')# 添加可选功能参数: --logdir_root; 该参数含义为: 放置日志输出和生成模型的文件路径,存放在带有日期的子目录下# 需要注意: 该参数不能与'--logdir'一起使用parser.add_argument('--logdir_root', type=str, default=None,help='Root directory to place the logging ''output and generated model. These are stored ''under the dated subdirectory of --logdir_root. ''Cannot use with --logdir.')# 添加可选功能参数: --restore_from; 该参数含义为: 恢复模型的目录,能创建带有日期的子目录# 需要注意: 该参数不能与'--logdir'一起使用parser.add_argument('--restore_from', type=str, default=None,help='Directory in which to restore the model from. ''This creates the new model under the dated directory ''in --logdir_root. ''Cannot use with --logdir.')# 添加可选功能参数: --checkpoint_every; 该参数含义为: 存放训练模型的检查点间隔parser.add_argument('--checkpoint_every', type=int,default=CHECKPOINT_EVERY,help='How many steps to save each checkpoint after. Default: ' + str(CHECKPOINT_EVERY) + '.')# 添加可选功能参数: --num_steps; 该参数含义为: 训练的次数parser.add_argument('--num_steps', type=int, default=NUM_STEPS,help='Number of training steps. Default: ' + str(NUM_STEPS) + '.')# 添加可选功能参数: --learning_rate; 该参数含义为: 训练的学习率parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE,help='Learning rate for training. Default: ' + str(LEARNING_RATE) + '.')# 添加可选功能参数: --wavenet_params; 该参数含义为: WaveNet 模型的相关参数parser.add_argument('--wavenet_params', type=str, default=WAVENET_PARAMS,help='JSON file with the network parameters. Default: ' + WAVENET_PARAMS + '.')# 添加可选功能参数: --sample_size; 该参数含义为: 使用的样本数量parser.add_argument('--sample_size', type=int, default=SAMPLE_SIZE,help='Concatenate and cut audio samples to this many ''samples. Default: ' + str(SAMPLE_SIZE) + '.')# 添加可选功能参数: --l2_regularization_strength; 该参数含义为: L2正则化的系数parser.add_argument('--l2_regularization_strength', type=float,default=L2_REGULARIZATION_STRENGTH,help='Coefficient in the L2 regularization. ''Default: False')# 添加可选功能参数: --silence_threshold; 该参数含义为: 音量阈值限制parser.add_argument('--silence_threshold', type=float,default=SILENCE_THRESHOLD,help='Volume threshold below which to trim the start ''and the end from the training set samples. Default: ' + str(SILENCE_THRESHOLD) + '.')# 添加可选功能参数: --optimizer; 该参数含义为: 优化器选择parser.add_argument('--optimizer', type=str, default='adam',choices=optimizer_factory.keys(),help='Select the optimizer specified by this option. Default: adam.')# 添加可选功能参数: --momentum; 该参数含义为: 优化器动量大小parser.add_argument('--momentum', type=float,default=MOMENTUM, help='Specify the momentum to be ''used by sgd or rmsprop optimizer. Ignored by the ''adam optimizer. Default: ' + str(MOMENTUM) + '.')# 添加可选功能参数: --histograms; 该参数含义为: 直方图汇总存储标志parser.add_argument('--histograms', type=_str_to_bool, default=False,help='Whether to store histogram summaries. Default: False')# 添加可选功能参数: --gc_channels; 该参数含义为: 全局条件通道数量parser.add_argument('--gc_channels', type=int, default=None,help='Number of global condition channels. Default: None. Expecting: Int')# 添加可选功能参数: --max_checkpoints; 该参数含义为: 最大训练模型保存检查点数parser.add_argument('--max_checkpoints', type=int, default=MAX_TO_KEEP,help='Maximum amount of checkpoints that will be kept alive. Default: '+ str(MAX_TO_KEEP) + '.')# 把parser中设置的所有"add_argument"给返回到args子类实例中并返回return parser.parse_args()

validate_directories(args)

下面这段代码主要工作是:验证当前的几个目录是否冲突,将输入的目录参数规范化。

 def validate_directories(args):""" 验证和整理与目录相关的参数 """# 验证接断# logdir 与 logdir_root 参数不能同时存在if args.logdir and args.logdir_root:raise ValueError("--logdir and --logdir_root cannot be ""specified at the same time.")# logdir 与 restore_from 参数不能同时存在if args.logdir and args.restore_from:raise ValueError("--logdir and --restore_from cannot be specified at the same ""time. This is to keep your previous model from unexpected ""overwrites.\n""Use --logdir_root to specify the root of the directory which ""will be automatically created with current date and time, or use ""only --logdir to just continue the training from the last ""checkpoint.")# 整理阶段# 为 logdir_root 参数赋予给定的值或是默认值logdir_root = args.logdir_rootif logdir_root is None:logdir_root = LOGDIR_ROOT# 为 logdir 参数赋予给定的值或是 logdir_root 参数的默认值logdir = args.logdirif logdir is None:logdir = get_default_logdir(logdir_root)print('Using default logdir: {}'.format(logdir))# 为 restore_from 参数赋予给定的值或是 logdir 参数的值restore_from = args.restore_fromif restore_from is None:# args.logdir and args.restore_from are exclusive,# so it is guaranteed the logdir here is newly created.restore_from = logdir# 将验证并整理好的目录参数打包返回return {'logdir': logdir,'logdir_root': args.logdir_root,'restore_from': restore_from}

get_default_logdir(logdir_root)

下面这段代码主要工作是:在给定的日志目录下,创建训练文件夹,再创建以带有当前日期时间的文件路径,并将该路径返回

 def get_default_logdir(logdir_root):# 使用路径拼接函数 os.path.join() 在给定的目录下创建'train'目录# 进而创建以当前日期时间为名的子目录,格式为:{0:%Y-%m-%dT%H-%M-%S}logdir = os.path.join(logdir_root, 'train', STARTED_DATESTRING)return logdir

save(saver, sess, logdir, step)

这段代码主要工作是:将给定的训练结果、会话以及检查点保存到指定的文件路径下

 def save(saver, sess, logdir, step):# 设置保存的模型文件名,将文件路径进行拼接model_name = 'model.ckpt'checkpoint_path = os.path.join(logdir, model_name)print('Storing checkpoint to {} ...'.format(logdir), end="")# 刷新缓冲区,保证正常输出sys.stdout.flush()# 若文件不存在则先创造文件if not os.path.exists(logdir):os.makedirs(logdir)# 保存模型saver.save(sess, checkpoint_path, global_step=step)print(' Done.')

load(saver, sess, logdir)

这段代码主要工作是:将指定路径下的模型训练结果恢复到当前会话

 def load(saver, sess, logdir):print("Trying to restore saved checkpoints from {} ...".format(logdir),end="")# 从指定路径下返回训练模型以及检查点ckpt = tf.train.get_checkpoint_state(logdir)if ckpt:print("  Checkpoint found: {}".format(ckpt.model_checkpoint_path))# 找到模型,获取检查点global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])print("  Global step was: {}".format(global_step))print("  Restoring...", end="")# 恢复最新检查点训练情况saver.restore(sess, ckpt.model_checkpoint_path)print(" Done.")# 返回检查点return global_stepelse:# 未找到模型,返回空值print(" No checkpoint found.")return None

本文还在持续更新中!
       欢迎各位大佬交流讨论!

【项目实战】WaveNet 代码解析 —— train.py 【更新中】相关推荐

  1. Baidu Apollo代码解析之EM Planner中的QP Speed Optimizer 1

    大家好,我已经把CSDN上的博客迁移到了知乎上,欢迎大家在知乎关注我的专栏慢慢悠悠小马车(https://zhuanlan.zhihu.com/duangduangduang).希望大家可以多多交流, ...

  2. ECharts数据可视化项目-大屏数据可视化【持续更新中】

    ECharts数据可视化项目-大屏数据可视化[持续更新中] 文章目录 ECharts数据可视化项目-大屏数据可视化[持续更新中] 一. 数据可视化ECharts使用 二.技术栈 三.数据可视化 四.可 ...

  3. 【Android 逆向】使用 Python 代码解析 ELF 文件 ( PyCharm 中进行断点调试 | ELFFile 实例对象分析 )

    文章目录 一.PyCharm 中进行断点调试 二.ELFFile 实例对象分析 一.PyCharm 中进行断点调试 在上一篇博客 [Android 逆向]使用 Python 代码解析 ELF 文件 ( ...

  4. pytorch YoLOV3 源码解析 train.py

    train.py 总体分为三部分(不算import 库) 初始的一些设定 + train函数 + main函数 源码地址: https://github.com/ultralytics/yolov3 ...

  5. 连连看项目实战之三(解析配置表)

    推荐阅读: 我的CSDN 我的博客园 QQ群:704621321 这是一款连连看,如果只能连连,那估计大家看看就没兴趣了.如果在里面加上一些场景式对话,那么可能就会有意思许多. 今天就带大家学习Uni ...

  6. 爱咖喱炸鸡店订餐系统--代码超细节解析版(更新中)

    爱咖喱炸鸡店订餐系统 (一)需求说明 (二)操作步骤 Ⅰ.项目功能分析 Ⅱ.搭建基本框架 Ⅲ依次完善各个条件 Ⅳ运行,检查,是否符合业务 (三)业务描述 (一)需求说明 现今已进入网络时代,网上购物. ...

  7. 手游开发神器 cocos2d-x editor 教程聚合和代码下载(持续更新中)

    --------------游戏基础教程篇-------------已完成--------- 一 cocos2d-x editor工具下载和基础教程JS篇: 一 手游开发神器 cocos2d-x ed ...

  8. DeamNet||训练代码学习train.py注释与解析

    目录 1. 导入各种库,设置运行环境 2. 训练设置,各种参数 3.测试 1. 导入各种库,设置运行环境 from __future__ import print_function import os ...

  9. 机器学习完整项目实战附代码(二):探索型数据分析+特征工程+建模+报告

    1. 项目背景: 1.1 项目目标: 使用提供的波士顿房屋租赁价格数据开发一个模型,该模型可以预测房屋租赁价格, 然后解释结果以找到最能预测的变量. 这是一个受监督的回归机器学习任务:给定一组包含目标 ...

最新文章

  1. 基于SSH实现的农家乐管理系统
  2. 数据结构的堆栈与内存中堆栈的区别
  3. PHP实现菱形与杨辉三角形【php趣味案例】
  4. 李沐老师的PyTorch 版《动手学深度学习》PDF 开源了(全中文,支持 Jupyter 运行)
  5. Acrobat Pro DC 教程,如何将文件合并为 PDF?
  6. 多种方法让网络共享资源自动映射
  7. CSDNITeye招贤榜
  8. 华为智能体发布,智能联接火了
  9. GetTickCount64的使用
  10. 公司与公司保密协议范本
  11. FastStone Capture7.0注册码
  12. 皖能合肥电厂电能量计量管理系统设计方案
  13. 正离子计算机扫描检测,扫描电子显微镜
  14. 颜值评分,图像识别,植物、动物、车型、菜品、logo识别
  15. Matlab多项式基本运算(1)( polyval和polyvalm的区别)
  16. MySQL学习第三弹——约束与多表查询详解
  17. 刀塔自走棋上线不到十分钟就被功击,几十万玩家登录不上
  18. GUTI,Globally Unique Temporary UE Identity,全球唯一临时UE标识。
  19. jQuery mouseover与mouseenter,mouseout与mouseleave的区别
  20. 异地恋的自愈系小故事:企鹅先生和北极熊小姐

热门文章

  1. vue cli5降级为4
  2. 大数据 用户画像基础
  3. XnView 批量调整大小 PNG 保持透明度
  4. Linux 块设备与字条设备
  5. dot.js嵌套html文件,doT.js实现混合布局,判断,数组,函数使用,取模,数组嵌套...
  6. 拼多多店铺怎么选择资源位,怎么报名活动,什么活动对店铺利益最大?
  7. Java语言连接MongoDB常用的方法
  8. 蓝牙 舵狗 openmv通信相关
  9. namomo 每日一题 207 拆方块
  10. 天纵智能软件快速开发中国地图统计分析插件