fasttext源码学习(2)–模型压缩

前言

fasttext模型压缩的很明显,精度却降低不多,其网站上提供的语种识别模型,压缩前后的对比就是例证,压缩前126M,压缩后917K。太震惊了,必须学习一下。看文档介绍用到权重量化(weight quantization)和特征选择(feature selection),下面结合代码学习下。

说明:文章中代码皆为简化版,为突出重点,简化了逻辑,原版代码需到官方网页下载。

一 特征选择

一开始以为fasttext会用到比较复杂的特征选择算法,直到看到代码才差点闪了腰。。。fasttext用的就是kbest,剩下的全砍掉,就是这么简单直接。

void FastText::quantize(const Args& qargs, const TrainCallback& callback) {if (qargs.cutoff > 0 && qargs.cutoff < input->size(0)) {auto idx = selectEmbeddings(qargs.cutoff);dict_->prune(idx);  // 剪枝(词典重新计算)if (qargs.retrain) { // 重新训练startThreads(callback);}}
}std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {std::shared_ptr<DenseMatrix> input =std::dynamic_pointer_cast<DenseMatrix>(input_);Vector norms(input->size(0));input->l2NormRow(norms); // [1] 正则化std::vector<int32_t> idx(input->size(0), 0);std::iota(idx.begin(), idx.end(), 0);std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {return (eosid != i2 && norms[i1] > norms[i2]);}); // [2] 按正则化值排序idx.erase(idx.begin() + cutoff, idx.end()); // [3] 保留指定数目的id return idx;
}

从上面代码可以看出,selectEmbeddings主要干的就是按正则化值排序,然后只保留指定数目的行。默认保留是5万,结合上一章的DensMatrix的行数是200万+词语个数,可以看出这一步的特征选择最少能压缩到原来的1/40,压缩比很可观。

值得注意的是dict_->prune(idx),这一步必不可少,因为fasttext是基于hash映射来计算矩阵下标的,特征被筛选后,相应的词典也需要清理压缩,对应关系也需要刷新。

void Dictionary::prune(std::vector<int32_t>& idx) {std::vector<int32_t> words, ngrams;// 按id选取对应的word、ngramfor (auto it = idx.cbegin(); it != idx.cend(); ++it) {if (*it < nwords_) {words.push_back(*it);} else {ngrams.push_back(*it);}}// 按id排序std::sort(words.begin(), words.end());idx = words;// 计算被筛选的ngram的hash与id的对应关系// ngram的对应关系原本不需存储,筛选后,由于对应关系的变化,导致需要存储pruneidx_if (ngrams.size() != 0) {int32_t j = 0;for (const auto ngram : ngrams) {pruneidx_[ngram - nwords_] = j;j++;}idx.insert(idx.end(), ngrams.begin(), ngrams.end());}pruneidx_size_ = pruneidx_.size();int32_t j = 0;// 筛选过的word往前移,为后续的清除做准备// 同时,重新计算hash与id的对应关系for (int32_t i = 0; i < words_.size(); i++) {if (getType(i) == entry_type::label ||(j < words.size() && words[j] == i)) {words_[j] = words_[i];word2int_[find(words_[j].word)] = j;j++;}}nwords_ = words.size();size_ = nwords_ + nlabels_;// 移除多余的wordwords_.erase(words_.begin() + size_, words_.end());// 重新初始化各word的ngrams,重新计算ngram的下标initNgrams();
}

要理解这段代码,必须先理解fasttext的数据存储即Dictionary和DenseMatrix那一部分,否则会非常晕。

二 权重量化

量化一般是将大的数值表示变为小的数值表示,比如从float变为byte。而fasttext采用了另外一种方法,product quantization。简单来说,就是将向量分割为更小的子向量,再使用kmeans算法,将子向量映射到中心点下标。这样, 假设子向量长度为2,则n*8的float类型的矩阵,被映射为n*4的byte矩阵,模型可以减小到原来的1/8.

void ProductQuantizer::train(int32_t n, const real* x) {std::vector<int32_t> perm(n, 0);std::iota(perm.begin(), perm.end(), 0);auto d = dsub_;auto np = std::min(n, max_points_);auto xslice = std::vector<real>(np * dsub_);// 划分为nsubq_个子向量,子向量长度为d. for (auto m = 0; m < nsubq_; m++) {std::shuffle(perm.begin(), perm.end(), rng);// 随机选取np行数据,每行长度为dfor (auto j = 0; j < np; j++) {memcpy(xslice.data() + j * d, x + perm[j] * dim_ + m * dsub_,                          d*sizeof(real));}// kmeans计算该子向量对应的中心点kmeans(xslice.data(), get_centroids(m, 0), np, d);}
}
void ProductQuantizer::kmeans(const real* x, real* c, int32_t n, int32_t d) {std::vector<int32_t> perm(n, 0);std::iota(perm.begin(), perm.end(), 0);std::shuffle(perm.begin(), perm.end(), rng);// 随机初始化中心点for (auto i = 0; i < ksub_; i++) {memcpy(&c[i * d], x + perm[i] * d, d * sizeof(real));}auto codes = std::vector<uint8_t>(n);// kmeans标准算法,具体可参考另一篇介绍kmeans的文章for (auto i = 0; i < niter_; i++) {Estep(x, c, codes.data(), d, n);MStep(x, c, codes.data(), d, n);}
}
// MStep中有一部分与Kmeans算法中不太一样的部分
// 对于中心点计数为0的部分做了修正,对中心点数值进行了调整
std::uniform_real_distribution<> runiform(0, 1);for (auto k = 0; k < ksub_; k++) {if (nelts[k] == 0) {int32_t m = 0;while (runiform(rng) * (n - ksub_) >= nelts[m] - 1) {m = (m + 1) % ksub_;}memcpy(centroids + k * d, centroids + m * d, sizeof(real) * d);for (auto j = 0; j < d; j++) {int32_t sign = (j % 2) * 2 - 1;centroids[k * d + j] += sign * eps_;centroids[m * d + j] -= sign * eps_;}nelts[k] = nelts[m] / 2;nelts[m] -= nelts[k];}}

源码中默认子矩阵长度是2,kmeans簇大小为256(不超过一个byte),默认会压缩至1/8大小。

总结

原本以为会写比较多内容,因为这部分代码确实花了点时间去看,尤其权重量化中kmeans算法前面那部分(子矩阵划分),不太明白,但是看懂代码逻辑之后,又回头看了文档,恍然大悟,一句话就能把逻辑说的很清楚。但是并没办法绕过看代码,因为正是看文档看的不明白才去看代码的。。。-_-||

所以,最终整篇文章变成了代码注释,以防细节部分忘掉。。。

话说特征选择,个人觉得对于大部分场景,压缩比非常可观。但是也需要看到,DenseMatrix矩阵起步就是200万行,所以对于小数据集,fasttext也会训出比较大的模型,这个是不足的一个方面。

附录

  1. fasttext language identification
  2. fast源码
  3. K-means学习总结
  4. fasttext源码学习(1)–dictionary

fasttext源码学习(2)--模型压缩相关推荐

  1. Opencascade源码学习之模型算法_TKO模块文件介绍

    Opencascade源码学习之模型数据_TKO模块文件介绍 1.TKO 1.BOPAlgo 2.BOPDS 3.BOPTools 4.BRepAlgoAPI 5.IntTools 1.TKO 1.B ...

  2. Opencascade源码学习之模型数据

    Opencascade源码学习之模型数据 1.模型数据 2.几何工具 1.插值和拟合 1.分析一组点 2.基本插值和近似 3.2D 插值 4.3D 插值 5.2D 拟合 6.3D 拟合 7.曲面拟合 ...

  3. Opencascade源码学习之模型数据——TKGeomBase模块文件介绍

    Opencascade源码学习之模型数据--TKGeomBase模块文件介绍 1.AdvApp2Var 2.AppCont 3.AppDef 4.AppParCurves 5.Approx 6.Bnd ...

  4. libevent源码学习-----Reactor模型

    libevent内部采用了reactor模型 所谓reactor模型,其实就是一套事件注册机制,用来解决单线程的阻塞问题.reactor核心思想是将事件和相应事件发生时想要调用的函数都记录下来,在事件 ...

  5. ERNIE源码学习与实践:为超越ChatGPT打下技术基础!

    ★★★ 本文源自AlStudio社区精品项目,[点击此处]查看更多精品内容 >>> ERNIE学习与实践:为超越ChatGPT打下技术基础! ERNIE是BERT相爱相杀的好基友,由 ...

  6. 文心ERNIE源码学习与实践:为超越ChatGPT打下技术基础!

    ERNIE学习与实践:为超越ChatGPT打下技术基础! ERNIE是BERT相爱相杀的好基友,由ERNIE发展起来的文心大模型,是GPT3.0的强劲竞争对手,未来还会挑战ChatGPT的江湖地位! ...

  7. DotText源码学习——ASP.NET的工作机制

    --本文是<项目驱动学习--DotText源码学习>系列的第一篇文章,在这之后会持续发表相关的文章. 概论 在阅读DotText源码之前,让我们首先了解一下ASP.NET的工作机制,可以使 ...

  8. 我的angularjs源码学习之旅2——依赖注入

    依赖注入起源于实现控制反转的典型框架Spring框架,用来削减计算机程序的耦合问题.简单来说,在定义方法的时候,方法所依赖的对象就被隐性的注入到该方法中,在方法中可以直接使用,而不需要在执行该函数的时 ...

  9. ASP.NET Core MVC 源码学习:MVC 启动流程详解

    前言 在 上一篇 文章中,我们学习了 ASP.NET Core MVC 的路由模块,那么在本篇文章中,主要是对 ASP.NET Core MVC 启动流程的一个学习. ASP.NET Core 是新一 ...

最新文章

  1. python学不会的表情包-python这么简单 为何这么多人学不会
  2. c++ 拷贝构造函数_禁止拷贝构造,禁止bug
  3. 【完整代码】Scala akka入门示例
  4. Arduino 与 SPI 结合使用 以及SPI 深层理解
  5. .iml文件_jetbrains误删maven 项目.iml文件后的处理方法
  6. 深入理解Magento – 第三章 – 布局,块和模板
  7. IOS开发UI控件UIScrollView和Delegate的使用
  8. 字节跳动副总裁喊话腾讯:停止无理由封杀飞书;Git服务器配置错误导致日产汽车源码泄露;Linux5.10.5 发布
  9. 物权法全文内容有哪些呢-广告外链_SEO优化的站外优化工作有哪些?
  10. 网易开源云原生日志系统!
  11. Android权限大全
  12. 自定义APPLEALC驱动APPLEHDA之整理codec
  13. oracle的权限授予,oracle权限命令
  14. TIC TAC TOE 井字游戏
  15. NOIP2015酱油记
  16. python爬取歌词生成词云图
  17. 海康设备通过SDK获取和设置设备网络参数
  18. 积木开发系列----Blockly初体验
  19. Java 中 Boolean 和 boolean的默认值和修改默认值
  20. 百度地图行政区优化卡顿问题

热门文章

  1. 苹果手机之间怎么传照片_相机与手机之间传送RAW格式照片问题!
  2. vue-pdf预览乱码问题、打印乱码多一页空白问题
  3. 难道是“写时拷贝”?
  4. HTML5CSS3前端入门教程---从0开始通过一个商城实例手把手教你学习PC端和移动端页面开发第10章有路网PC端主页实战整合
  5. Fzu 2202 犯罪嫌疑人【逻辑推理思维题】好题!
  6. 中南c语言作业答案,中南民族大学10套计算机C语言期末考试复习试题及答案.doc...
  7. Camtasia Studio录制屏幕字迹不清晰的原因
  8. 教师和计算机平面设计图谁更好一些,计算机平面设计专业教学浅析
  9. 文本域多行文本回显换行问题
  10. Spring-尚硅谷-学习笔记