声学模型训练(前向-后向算法)

前文讲述了语音识别声学模型训练算法,主要基于Viterbi-EM算法来估计模型中参数,但是该方法对于计算语料中帧对应状态的弧号存在计算复杂度指数级增加的问题,为解决上述问题,有学者提出用前向后向算法来估计模型中参数,其可以解决复杂度指数级增长的问题,主要理论及工程实现如下。

本文语音识别算法主要参考哥伦比亚大学语音识别课程提供的源码

首先给出整体模型训练流程,其代码如下:

#ifndef NO_MAIN_LOOP
void main_loop(const char** argv) {map<string, string> params;process_cmd_line(argv, params);Lab2FbMain mainObj(params);GmmStats gmmStats(mainObj.get_gmm_set(), params);while (mainObj.init_iter()) {gmmStats.clear();while (mainObj.init_utt()) {double logProb = forward_backward(mainObj.get_graph(), mainObj.get_gmm_probs(), mainObj.get_chart(),mainObj.get_gmm_counts(), mainObj.get_trans_counts());mainObj.finish_utt(logProb);gmmStats.update(mainObj.get_gmm_counts(), mainObj.get_feats());}mainObj.finish_iter();gmmStats.reestimate();}mainObj.finish();
}
#endif

上述代码给出了整体训练流程,现对其进行逐个讲解:

  1.代码初始化

map<string, string> params;
process_cmd_line(argv, params);

上述代码主要读取输入超参数,并对其进行处理,最后结果如下所示:

上述参数以此表示输入语音位置,chart图,解码图,输入gmm模型,迭代次数以及经过更新后gmm模型参数。

2. 前后向算法初始化

Lab2FbMain mainObj(params); 表示实例化出前后向算法,其中前后向算法构造函数输入为:

Lab2FbMain::Lab2FbMain(const map<string, string>& params): m_params(params),m_frontEnd(m_params),m_gmmSet(get_required_string_param(m_params, "in_gmm")),m_outGmmFile(get_required_string_param(m_params, "out_gmm")),m_transCountsFile(get_string_param(params, "trans_counts")),m_iterCnt(get_int_param(m_params, "iters", 1)),m_iterIdx(1),m_totFrmCnt(0),m_totLogProb(0.0) {if (!m_transCountsFile.empty()) {m_graph.read_word_sym_table(get_required_string_param(params, "trans_syms"));}
}

上述构造函数主要对传入参数进行处理并对其进行赋值操作;

m_frontEnd(m_params),主要用于对语音数据进行特征提取;

m_gmmSet(get_required_string_param(m_params, "in_gmm"))用于读取原始GMM参数,其中GMM可以对其进行初始化,也可以不使用初始化操作,通常情况下,我们对GMM中均值与方差进行初始化有以下策略:因为语音识别中常用对角矩阵表示均值与方差矩阵,然而方差矩阵对角线元素均为非负,均值元素可正可负,因此可以尝试使用单位对角阵初始化方差矩阵,使用零对角阵初始化均值矩阵,至于为什么使用不同的元素初始化该参数:个人理解主要是调试过程中便于对其进行区分且符合上述对角矩阵理论部分;

其他参数依次表示经过参数训练后gmm模型参数存储位置、转移概率存储位置、迭代次数、总帧数以及总似然值,至此实例化对象参数初始化完毕。

3.gmm模型状态初始化

GmmStats gmmStats(mainObj.get_gmm_set(), params);类主要对前向后向算法实例化对象与函数参数作为输入,将前后向算法结果输入至GmmStats类实例化对象gmmStats中,因为该类对于状态计算很重要,因此将该类主要函数与参数展示如下:

class GmmStats {
public:GmmStats(GmmSet& gmmSet, const map<string, string>& params = ParamsType());void clear();double update(const vector<GmmCount>& gmmCountList,const matrix<double>& feats);double add_gmm_count(unsigned gmmIdx, double posterior,const vector<double>& feats);void reestimate() const;private:map<string, string> m_params;/** Reference to associated GmmSet. **/GmmSet& m_gmmSet;/** Total counts of each Gaussian. **/vector<double> m_gaussCounts;/** First-order stats for each dim of each Gaussian. **/matrix<double> m_gaussStats1;/** Second-order stats for each dim of each Gaussian. **/matrix<double> m_gaussStats2;
};

前文关于声学模型训练部分对该类中三个主要参数进行了说明,其分别统计语料库中所有状态对应出现次数,均值统计以及方差参数进行统计,其中对gmmStats实例化对象初始化如下:

GmmStats::GmmStats(GmmSet& gmmSet, const map<string, string>& params) : m_params(params),m_gmmSet(gmmSet),m_gaussCounts(m_gmmSet.get_gaussian_count()),m_gaussStats1(m_gmmSet.get_gaussian_count(), m_gmmSet.get_dim_count()),m_gaussStats2(m_gmmSet.get_gaussian_count(), m_gmmSet.get_dim_count()) {clear();
}void GmmStats::clear() {fill(m_gaussCounts.begin(), m_gaussCounts.end(), 0.0);fill(m_gaussStats1.data().begin(), m_gaussStats1.data().end(), 0.0);fill(m_gaussStats2.data().begin(), m_gaussStats2.data().end(), 0.0);
}

上述gmm状态初始化主要用输入gmm模型统计量结果对其进行复制与初始化。

4.开始迭代

用while循环开始迭代更新gmm模型中参数,其初始化迭代代码如下:

bool Lab2FbMain::init_iter() {if (m_iterIdx > m_iterCnt) return false;m_transCounts.clear();m_audioStrm.clear();m_audioStrm.open(get_required_string_param(m_params, "audio_file").c_str());m_graphStrm.clear();m_graphStrm.open(get_required_string_param(m_params, "graph_file").c_str());m_totFrmCnt = 0;m_totLogProb = 0.0;return true;
}

该部分主要是对音频文件、解码图文件以及一些超参数进行初始化操作。

5.开始处理语料

同理用while循环遍历语料,并对其进行处理,其初始化操作代码如下:

bool Lab2FbMain::init_utt() {if (m_audioStrm.peek() == EOF) return false;m_idStr = read_float_matrix(m_audioStrm, m_inAudio);cout << "Processing utterance ID: " << m_idStr << endl;m_frontEnd.get_feats(m_inAudio, m_feats);if (m_feats.size2() != m_gmmSet.get_dim_count())throw runtime_error("Mismatch in GMM and feat dim.");if (m_graphStrm.peek() == EOF)throw runtime_error("Mismatch in number of audio files and FSM's.");m_graph.read(m_graphStrm, m_idStr);if (m_graph.get_gmm_count() > m_gmmSet.get_gmm_count())throw runtime_error("Mismatch in number of GMM's between ""FSM and GmmSet.");m_gmmSet.calc_gmm_probs(m_feats, m_gmmProbs);m_chart.resize(m_feats.size1() + 1, m_graph.get_state_count());m_chart.clear();if (m_graph.get_start_state() < 0)throw runtime_error("Graph has no start state.");m_gmmCountList.clear();return true;
}

上述代码核心部分前文已经有介绍,主要是用于语料库中语句的特征提取、解码图读取、计算当前帧属于各个状态的概率密度函数以及对chart格子图进行初始化操作,不懂的可以看前文博客对计算pdf部分与初始化格子图为什么多一帧的介绍。

6.前向后向算法
  接下来则是前向后向算法核心部分,先将算法代码列出如下所示:

double forward_backward(const Graph& graph, const matrix<double>& gmmProbs,matrix<FbCell>& chart, vector<GmmCount>& gmmCountList,map<int, double>& transCounts) {int frmCnt = chart.size1() - 1;int stateCnt = chart.size2();{for (int frmIdx = 0; frmIdx < (int)chart.size1(); ++frmIdx) {for (int stateIdx = 0; stateIdx < (int)chart.size2(); ++stateIdx) {chart(frmIdx, stateIdx).set_forw_log_prob(g_zeroLogProb);chart(frmIdx, stateIdx).set_back_log_prob(g_zeroLogProb);}}}int startState = graph.get_start_state();chart(0, startState).set_forw_log_prob(0);for (int frmIdx = 1; frmIdx <= frmCnt; ++frmIdx) {for (int stateIdx = 0; stateIdx < stateCnt; ++stateIdx) {int arcCnt = graph.get_arc_count(stateIdx);int arcId = graph.get_first_arc_id(stateIdx);for (int arcIdx = 0; arcIdx < arcCnt; ++arcIdx) {Arc arc;arcId = graph.get_arc(arcId, arc);int dstState = arc.get_dst_state();//arc.get_log_prob(),chart(frmIdx - 1, stateIdx).get_forw_log_prob(),//gmmProbs(frmIdx - 1, arc.get_gmm())三者分别表示为状态转移概率,//子图初始概率以及状态发射概率double logProb = arc.get_log_prob() +chart(frmIdx - 1, stateIdx).get_forw_log_prob() +gmmProbs(frmIdx - 1, arc.get_gmm());logProb = add_log_probs(vector<double>{logProb, chart(frmIdx, dstState).get_forw_log_prob()});chart(frmIdx, dstState).set_forw_log_prob(logProb);}}}//for (int frmidx = 0; frmidx <= frmCnt; ++frmidx) {//    for (int srcidx = 0; srcidx < stateCnt; ++srcidx) {//        cout << format(" %d") % chart(frmidx, srcidx).get_forw_log_prob();//    }//    cout << endl;//}//得到概率最大的终止状态的似然值及其终止状态序号;double uttLogProb = init_backward_pass(graph, chart);if (uttLogProb == g_zeroLogProb) return uttLogProb;for (int frmIdx = frmCnt - 1; frmIdx >= 0; --frmIdx) {for (int stateIdx = 0; stateIdx < stateCnt; ++stateIdx) {int arcCnt = graph.get_arc_count(stateIdx);int arcId = graph.get_first_arc_id(stateIdx);for (int arcIdx = 0; arcIdx < arcCnt; ++arcIdx) {Arc arc;arcId = graph.get_arc(arcId, arc);int dstState = arc.get_dst_state();double logProb = arc.get_log_prob() + gmmProbs(frmIdx, arc.get_gmm()) +chart(frmIdx + 1, dstState).get_back_log_prob();// NOTE!!! They are log prob but not regular prob, so use add_log_probs// but not +.// logProb += chart(frmIdx, stateIdx).get_back_log_prob();logProb = add_log_probs(vector<double>{logProb, chart(frmIdx, stateIdx).get_back_log_prob()});chart(frmIdx, stateIdx).set_back_log_prob(logProb);}}}//for (int frmIdx = 0; frmIdx <= frmCnt; ++frmIdx) {//    for (int srcIdx = 0; srcIdx < stateCnt; ++srcIdx) {//        cout << format(" %d") % chart(frmIdx, srcIdx).get_back_log_prob();//    }//    cout << endl;//}for (int frmIdx = frmCnt; frmIdx > 0; --frmIdx) {for (int stateIdx = 0; stateIdx < stateCnt; ++stateIdx) {int arcCnt = graph.get_arc_count(stateIdx);int arcId = graph.get_first_arc_id(stateIdx);for (int arcIdx = 0; arcIdx < arcCnt; ++arcIdx) {Arc arc;arcId = graph.get_arc(arcId, arc);int dstState = arc.get_dst_state();//logProb表示任意时刻到达某一帧某条弧上的概率,其采用前向后向算法进行计算;double logProb =chart(frmIdx - 1, stateIdx).get_forw_log_prob() +  // alpha_t-1_iarc.get_log_prob() +                               // a_i_jgmmProbs(frmIdx - 1, arc.get_gmm()) +              // b_j_(ot)chart(frmIdx, dstState).get_back_log_prob();       // beta_t_j//exp(logProb - uttLogProb)表示状态转移至终止状态时概率,即为弧上概率;gmmCountList.push_back(GmmCount(arc.get_gmm(), frmIdx - 1, exp(logProb - uttLogProb)));}}}return uttLogProb;
}

笔者对于前向后向算法的理解如下:

(1)该算法可以称之为评估问题,即已知声学模型参数(gmm参数)以及其观测序列(解码图),如何基于此计算该模型产生该序列的产产生的概率,即为对该声学模型结果进行打分;

(2)此时的chart格子图与之前Viterbi-EM不一致,其每个元素对应的数据类型为自定义格式类型,主要用于存储以及读取前向概率与后向概率,该变量具体参数如下:

class FbCell {
public:FbCell() : m_forwLogProb(g_zeroLogProb), m_backLogProb(g_zeroLogProb) {}explicit FbCell(int): m_forwLogProb(g_zeroLogProb), m_backLogProb(g_zeroLogProb) {}void set_forw_log_prob(double logProb) { m_forwLogProb = logProb; }void set_back_log_prob(double logProb) { m_backLogProb = logProb; }double get_forw_log_prob() const { return m_forwLogProb; }double get_back_log_prob() const { return m_backLogProb; }void printLogprobs() {cout << m_forwLogProb << " "<< m_backLogProb << endl;}
private:double m_forwLogProb;double m_backLogProb;
};

(3)add_log_probs()是存储前向参数与向参数的核心所在,虽然笔者在前期讲解过该参数的具体实现方法,其主要用于计算到此状态时最大的似然概率,不管是前向计算还是后向计算,均是如此,计算的结果均为到此状态时最大的概率,其存储的分别为前向概率与后巷概率,其变量类型如上所述;

(4)实际上前向算法计算至最后与前文Viterbi-EM算法是一致的,但是降低模型的复杂度进而引入了后向算法

7.终止语料读取

通过mainObj.finish_utt(logProb);函数终止前后向算法的语料计算,其具体代码实现如下:

void Lab2FbMain::finish_utt(double logProb) {m_totFrmCnt += m_feats.size1();m_totLogProb += logProb;double minPosterior = get_float_param(m_params, "min_posterior", 0.001);//m_gmmCountList存储结果为所有帧对应弧,包括弧序号,弧所属帧以及对应转移概率,转移概率由前后向算法进行计算;if (minPosterior > 0.0) {m_gmmCountListThresh.clear();for (int cntIdx = 0; cntIdx < (int)m_gmmCountList.size(); ++cntIdx) {if (m_gmmCountList[cntIdx].get_count() >= minPosterior) {m_gmmCountListThresh.push_back(m_gmmCountList[cntIdx]);}//m_gmmCountListThresh过滤掉m_gmmCountList中转移概率太小的弧单元,即为对弧进行剪枝;}m_gmmCountList.swap(m_gmmCountListThresh);}//sort(m_gmmCountList.begin(), m_gmmCountList.end());string chartFile = get_string_param(m_params, "chart_file");if (!chartFile.empty()) {ofstream chartStrm(chartFile.c_str());int frmCnt = m_feats.size1();int stateCnt = m_graph.get_state_count();matrix<double> matForwProbs(frmCnt + 1, stateCnt);matrix<double> matBackProbs(frmCnt + 1, stateCnt);for (int frmIdx = 0; frmIdx <= frmCnt; ++frmIdx) {for (int srcIdx = 0; srcIdx < stateCnt; ++srcIdx) {matForwProbs(frmIdx, srcIdx) =m_chart(frmIdx, srcIdx).get_forw_log_prob();matBackProbs(frmIdx, srcIdx) =m_chart(frmIdx, srcIdx).get_back_log_prob();}}write_float_matrix(chartStrm, matForwProbs, m_idStr + "_forw");write_float_matrix(chartStrm, matBackProbs, m_idStr + "_back");matrix<double> matPost(frmCnt, m_gmmSet.get_gmm_count());matPost.clear();int gmmCountCnt = m_gmmCountList.size();for (int cntIdx = 0; cntIdx < gmmCountCnt; ++cntIdx) {const GmmCount& gmmCount = m_gmmCountList[cntIdx];matPost(gmmCount.get_frame_index(), gmmCount.get_gmm_index()) +=gmmCount.get_count();}write_float_matrix(chartStrm, matPost, m_idStr + "_post");chartStrm.close();}//for (int i = 0; i < m_gmmCountList.size(); i++) {//    m_gmmCountList[i].printGmmCount();//}
}

必须说明的是为了剪枝gmm模型中弧上概率较小的部分,设置阈值(0.001)来控制gmm模型的权重,对于权重较小的gmm模型不计算统计量,这样可以大幅度降低模型中的参数数量,笔者曾对此进行测试过(未剪枝是2516条弧,剪枝后仅为168条弧),这样可以最大幅度降低模型中参数数量,而且对最后结果影响很小。

最终将前向概率与后向概率以及gmm模型的权重写入值chart图中,其中gmm权重在代码中表示为后验概率,最后后验概率表示如下图所示:

从上图可知,gmm权重大部分为0,这样大大减少了模型计算量且便于参数计算,非常值得推荐。

8. 声学模型参数赋值

用gmmStats.update(mainObj.get_gmm_counts(), mainObj.get_feats());函数基于特征用gmm统计量对gmm模型进行状态个数、均值以及方差进行复制,前文声学模型Viterbi-EM对此进行详细的说明,不懂的读者可以参考下。

9.终止迭代

mainObj.finish_iter();函数控制迭代次数,其具体实现如下:

void Lab2FbMain::finish_iter() {m_audioStrm.close();m_graphStrm.close();cout << format("Iteration %d: %.6f logprob/frame (%d frames)") % m_iterIdx %(m_totFrmCnt ? m_totLogProb / m_totFrmCnt : 0.0) % m_totFrmCnt<< endl;++m_iterIdx;
}

该函数主要用于打印总体似然概率以及文件流关闭。

10.参数更新

前文对此亦进行介绍,现给出参数更新代码如下:

void GmmStats::reestimate() const {int gaussCnt = m_gmmSet.get_gaussian_count();int dimCnt = m_gmmSet.get_dim_count();double occupancy, mean, var;for (int gaussIdx = 0; gaussIdx < gaussCnt; ++gaussIdx) {occupancy = m_gaussCounts[gaussIdx];for (int dimIdx = 0; dimIdx < dimCnt; ++dimIdx) {//均值与方差重新估计,mean = m_gaussStats1(gaussIdx, dimIdx) / occupancy;var = m_gaussStats2(gaussIdx, dimIdx) / occupancy - mean * mean;m_gmmSet.set_gaussian_mean(gaussIdx, dimIdx, mean);m_gmmSet.set_gaussian_var(gaussIdx, dimIdx, var);}}
}

11.终止参数估计

最终将更新的参数存储至输出的gmm模型中,其代码如下:

void Lab2FbMain::finish() {m_gmmSet.write(m_outGmmFile);if (!m_transCountsFile.empty()) {ofstream countStrm(m_transCountsFile.c_str());for (map<int, double>::const_iterator elemIter = m_transCounts.begin();elemIter != m_transCounts.end(); ++elemIter)countStrm << format("%s %.3f\n") %m_graph.get_word_sym_table().get_str(elemIter->first) %elemIter->second;countStrm.close();}
}

前文对此都介绍过,本文对此进行代码说明,

  至此语音识别基于前向后向算法介绍完毕。

语音识别—声学模型训练(前向-后向算法)相关推荐

  1. 隐马尔科夫模型(前向后向算法、鲍姆-韦尔奇算法、维特比算法)

    隐马尔科夫模型(前向后向算法.鲍姆-韦尔奇算法.维特比算法) 概率图模型是一类用图来表达变量相关关系的概率模型.它以图为表示工具,最常见的是用一个结点表示一个或一组随机变量,结点之间的变表是变量间的概 ...

  2. 机器学习算法 10 —— HMM模型(马尔科夫链、前向后向算法、维特比算法解码、hmmlearn)

    文章目录 系列文章 隐马尔科夫模型 HMM 1 马尔科夫链 1.1 简介 1.2 经典举例 2 HMM简介 2.1 简单案例 2.2 案例进阶 问题二解决 问题一解决 问题三解决 3 HMM模型基础 ...

  3. 机器学习算法总结(七)——隐马尔科夫模型(前向后向算法、鲍姆-韦尔奇算法、维特比算法)...

    概率图模型是一类用图来表达变量相关关系的概率模型.它以图为表示工具,最常见的是用一个结点表示一个或一组随机变量,结点之间的变表是变量间的概率相关关系.根据边的性质不同,可以将概率图模型分为两类:一类是 ...

  4. 隐马尔科夫模型(HMMs)之五:维特比算法及前向后向算法

    维特比算法(Viterbi Algorithm) 找到可能性最大的隐藏序列 通常我们都有一个特定的HMM,然后根据一个可观察序列去找到最可能生成这个可观察序列的隐藏序列. 1.穷举搜索 我们可以在下图 ...

  5. HMM——前向后向算法

    1. 前言 解决HMM的第二个问题:学习问题, 已知观测序列,需要估计模型参数,使得在该模型下观测序列 P(观测序列 | 模型参数)最大,用的是极大似然估计方法估计参数. 根据已知观测序列和对应的状态 ...

  6. HMM前向算法,维比特算法,后向算法,前向后向算法代码

    typedef struct { int N; /* 隐藏状态数目;Q={1,2,-,N} */ int M; /* 观察符号数目; V={1,2,-,M}*/ double **A; /* 状态转移 ...

  7. 机器学习算法拾遗:(七)隐马尔科夫模型(前向后向算法、鲍姆-韦尔奇算法、维特比算法)

    1.隐马尔科夫模型HMM 隐马尔科夫模型的图结构如下 从上图中主要有两个信息:一是观测变量xi 仅仅与与之对应的状态变量yi 有关:二是当前的状态变量yi 仅仅与它的前一个状态变量yi-1 有关. 隐 ...

  8. 图像处理去噪方法的c语言实验,基于一阶前向后向算法的全变分图像去噪方法与流程...

    本发明涉及图像处理技术领域,具体涉及一种基于一阶前向后向算法的全变分图像去噪方法. 背景技术: 图像复原(imagerestoration)即利用退化过程的先验知识,去恢复已被退化图像的本来面目.图像 ...

  9. 机器学习笔记(十四)——HMM估计问题和前向后向算法

    一.隐马尔科夫链的第一个基本问题 估计问题:给定一个观察序列O=O1O2-OTO=O_1O_2\dots O_T和模型u=(A,B,π)u = (\boldsymbol{A,B,\pi}),如何快速地 ...

  10. 隐马尔可夫(HMM)、前/后向算法、Viterbi算法

    HMM的模型  图1 如上图所示,白色那一行描述由一个隐藏的马尔科夫链生成不可观测的状态随机序列,蓝紫色那一行是各个状态生成可观测的随机序列 话说,上面也是个贝叶斯网络,而贝叶斯网络中有这么一种,如下 ...

最新文章

  1. 肠·道 | 刘洋彧:重建肠道菌群生态网络
  2. 如何快速将微信公众号留言嵌入到CSDN博文中?
  3. swift 组件化_京东商城订单模块基于 Swift 的改造方案与实践
  4. mysql连接报错Access denied for user ‘root‘@‘localhost‘
  5. 神秘的subsys_initcall【转】
  6. Transient关键字的使用
  7. 互联网巨头们的「中台战事」
  8. linux7 重新开始udev,Redhat Linux 7 创建UDEV设备(示例代码)
  9. 【数据结构】B树的理解
  10. (2)连续存储数组的方法
  11. mysql为用户部分授权,MYSQL为用户授权
  12. python计算出nan_学习笔记0522:Tensorflow训练模型出现loss是nan的问题排查
  13. STVP烧录HEX文件方法
  14. 微位科技李子阳:哈耶克—未来的价值单位
  15. 辽宁电网容载比问题及合理取值研究
  16. 双系统安装 Ubuntu 18.04 以及删除双系统中的 Ubuntu 的方法
  17. spring-test部分翻译
  18. 不打开Wifi获取Mac地址
  19. 苹果App Store审核指南中文翻译(更新)
  20. 四、青龙面板 Nvjdc(诺兰)安装教程

热门文章

  1. Scala学习笔记(1)-基本类型归纳
  2. http://trans.godict.com/index.php
  3. 设置qgraphicsitem原点_QT QGraphicsScene设置原点左下角
  4. 【Qt】警告Missing reference in range-for with non trivial type
  5. python 表情包爬虫
  6. 开源框架Banner实现图片轮播
  7. 苹果手机突然闪退的7个原因及修复方法
  8. 模电学习笔记(十三)——控制直流偏执电路
  9. Unity中采用二进制存档与读档
  10. 校园网连不上,火绒检测dns错误但修复不了,360直接搞定,nice!