fasttext源码学习(2)--模型压缩
fasttext源码学习(2)–模型压缩
前言
fasttext模型压缩的很明显,精度却降低不多,其网站上提供的语种识别模型,压缩前后的对比就是例证,压缩前126M,压缩后917K。太震惊了,必须学习一下。看文档介绍用到权重量化(weight quantization)和特征选择(feature selection),下面结合代码学习下。
说明:文章中代码皆为简化版,为突出重点,简化了逻辑,原版代码需到官方网页下载。
一 特征选择
一开始以为fasttext会用到比较复杂的特征选择算法,直到看到代码才差点闪了腰。。。fasttext用的就是kbest,剩下的全砍掉,就是这么简单直接。
void FastText::quantize(const Args& qargs, const TrainCallback& callback) {if (qargs.cutoff > 0 && qargs.cutoff < input->size(0)) {auto idx = selectEmbeddings(qargs.cutoff);dict_->prune(idx); // 剪枝(词典重新计算)if (qargs.retrain) { // 重新训练startThreads(callback);}}
}std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {std::shared_ptr<DenseMatrix> input =std::dynamic_pointer_cast<DenseMatrix>(input_);Vector norms(input->size(0));input->l2NormRow(norms); // [1] 正则化std::vector<int32_t> idx(input->size(0), 0);std::iota(idx.begin(), idx.end(), 0);std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {return (eosid != i2 && norms[i1] > norms[i2]);}); // [2] 按正则化值排序idx.erase(idx.begin() + cutoff, idx.end()); // [3] 保留指定数目的id return idx;
}
从上面代码可以看出,selectEmbeddings主要干的就是按正则化值排序,然后只保留指定数目的行。默认保留是5万,结合上一章的DensMatrix的行数是200万+词语个数,可以看出这一步的特征选择最少能压缩到原来的1/40,压缩比很可观。
值得注意的是dict_->prune(idx),这一步必不可少,因为fasttext是基于hash映射来计算矩阵下标的,特征被筛选后,相应的词典也需要清理压缩,对应关系也需要刷新。
void Dictionary::prune(std::vector<int32_t>& idx) {std::vector<int32_t> words, ngrams;// 按id选取对应的word、ngramfor (auto it = idx.cbegin(); it != idx.cend(); ++it) {if (*it < nwords_) {words.push_back(*it);} else {ngrams.push_back(*it);}}// 按id排序std::sort(words.begin(), words.end());idx = words;// 计算被筛选的ngram的hash与id的对应关系// ngram的对应关系原本不需存储,筛选后,由于对应关系的变化,导致需要存储pruneidx_if (ngrams.size() != 0) {int32_t j = 0;for (const auto ngram : ngrams) {pruneidx_[ngram - nwords_] = j;j++;}idx.insert(idx.end(), ngrams.begin(), ngrams.end());}pruneidx_size_ = pruneidx_.size();int32_t j = 0;// 筛选过的word往前移,为后续的清除做准备// 同时,重新计算hash与id的对应关系for (int32_t i = 0; i < words_.size(); i++) {if (getType(i) == entry_type::label ||(j < words.size() && words[j] == i)) {words_[j] = words_[i];word2int_[find(words_[j].word)] = j;j++;}}nwords_ = words.size();size_ = nwords_ + nlabels_;// 移除多余的wordwords_.erase(words_.begin() + size_, words_.end());// 重新初始化各word的ngrams,重新计算ngram的下标initNgrams();
}
要理解这段代码,必须先理解fasttext的数据存储即Dictionary和DenseMatrix那一部分,否则会非常晕。
二 权重量化
量化一般是将大的数值表示变为小的数值表示,比如从float变为byte。而fasttext采用了另外一种方法,product quantization。简单来说,就是将向量分割为更小的子向量,再使用kmeans算法,将子向量映射到中心点下标。这样, 假设子向量长度为2,则n*8的float类型的矩阵,被映射为n*4的byte矩阵,模型可以减小到原来的1/8.
void ProductQuantizer::train(int32_t n, const real* x) {std::vector<int32_t> perm(n, 0);std::iota(perm.begin(), perm.end(), 0);auto d = dsub_;auto np = std::min(n, max_points_);auto xslice = std::vector<real>(np * dsub_);// 划分为nsubq_个子向量,子向量长度为d. for (auto m = 0; m < nsubq_; m++) {std::shuffle(perm.begin(), perm.end(), rng);// 随机选取np行数据,每行长度为dfor (auto j = 0; j < np; j++) {memcpy(xslice.data() + j * d, x + perm[j] * dim_ + m * dsub_, d*sizeof(real));}// kmeans计算该子向量对应的中心点kmeans(xslice.data(), get_centroids(m, 0), np, d);}
}
void ProductQuantizer::kmeans(const real* x, real* c, int32_t n, int32_t d) {std::vector<int32_t> perm(n, 0);std::iota(perm.begin(), perm.end(), 0);std::shuffle(perm.begin(), perm.end(), rng);// 随机初始化中心点for (auto i = 0; i < ksub_; i++) {memcpy(&c[i * d], x + perm[i] * d, d * sizeof(real));}auto codes = std::vector<uint8_t>(n);// kmeans标准算法,具体可参考另一篇介绍kmeans的文章for (auto i = 0; i < niter_; i++) {Estep(x, c, codes.data(), d, n);MStep(x, c, codes.data(), d, n);}
}
// MStep中有一部分与Kmeans算法中不太一样的部分
// 对于中心点计数为0的部分做了修正,对中心点数值进行了调整
std::uniform_real_distribution<> runiform(0, 1);for (auto k = 0; k < ksub_; k++) {if (nelts[k] == 0) {int32_t m = 0;while (runiform(rng) * (n - ksub_) >= nelts[m] - 1) {m = (m + 1) % ksub_;}memcpy(centroids + k * d, centroids + m * d, sizeof(real) * d);for (auto j = 0; j < d; j++) {int32_t sign = (j % 2) * 2 - 1;centroids[k * d + j] += sign * eps_;centroids[m * d + j] -= sign * eps_;}nelts[k] = nelts[m] / 2;nelts[m] -= nelts[k];}}
源码中默认子矩阵长度是2,kmeans簇大小为256(不超过一个byte),默认会压缩至1/8大小。
总结
原本以为会写比较多内容,因为这部分代码确实花了点时间去看,尤其权重量化中kmeans算法前面那部分(子矩阵划分),不太明白,但是看懂代码逻辑之后,又回头看了文档,恍然大悟,一句话就能把逻辑说的很清楚。但是并没办法绕过看代码,因为正是看文档看的不明白才去看代码的。。。-_-||
所以,最终整篇文章变成了代码注释,以防细节部分忘掉。。。
话说特征选择,个人觉得对于大部分场景,压缩比非常可观。但是也需要看到,DenseMatrix矩阵起步就是200万行,所以对于小数据集,fasttext也会训出比较大的模型,这个是不足的一个方面。
附录
- fasttext language identification
- fast源码
- K-means学习总结
- fasttext源码学习(1)–dictionary
fasttext源码学习(2)--模型压缩相关推荐
- Opencascade源码学习之模型算法_TKO模块文件介绍
Opencascade源码学习之模型数据_TKO模块文件介绍 1.TKO 1.BOPAlgo 2.BOPDS 3.BOPTools 4.BRepAlgoAPI 5.IntTools 1.TKO 1.B ...
- Opencascade源码学习之模型数据
Opencascade源码学习之模型数据 1.模型数据 2.几何工具 1.插值和拟合 1.分析一组点 2.基本插值和近似 3.2D 插值 4.3D 插值 5.2D 拟合 6.3D 拟合 7.曲面拟合 ...
- Opencascade源码学习之模型数据——TKGeomBase模块文件介绍
Opencascade源码学习之模型数据--TKGeomBase模块文件介绍 1.AdvApp2Var 2.AppCont 3.AppDef 4.AppParCurves 5.Approx 6.Bnd ...
- libevent源码学习-----Reactor模型
libevent内部采用了reactor模型 所谓reactor模型,其实就是一套事件注册机制,用来解决单线程的阻塞问题.reactor核心思想是将事件和相应事件发生时想要调用的函数都记录下来,在事件 ...
- ERNIE源码学习与实践:为超越ChatGPT打下技术基础!
★★★ 本文源自AlStudio社区精品项目,[点击此处]查看更多精品内容 >>> ERNIE学习与实践:为超越ChatGPT打下技术基础! ERNIE是BERT相爱相杀的好基友,由 ...
- 文心ERNIE源码学习与实践:为超越ChatGPT打下技术基础!
ERNIE学习与实践:为超越ChatGPT打下技术基础! ERNIE是BERT相爱相杀的好基友,由ERNIE发展起来的文心大模型,是GPT3.0的强劲竞争对手,未来还会挑战ChatGPT的江湖地位! ...
- DotText源码学习——ASP.NET的工作机制
--本文是<项目驱动学习--DotText源码学习>系列的第一篇文章,在这之后会持续发表相关的文章. 概论 在阅读DotText源码之前,让我们首先了解一下ASP.NET的工作机制,可以使 ...
- 我的angularjs源码学习之旅2——依赖注入
依赖注入起源于实现控制反转的典型框架Spring框架,用来削减计算机程序的耦合问题.简单来说,在定义方法的时候,方法所依赖的对象就被隐性的注入到该方法中,在方法中可以直接使用,而不需要在执行该函数的时 ...
- ASP.NET Core MVC 源码学习:MVC 启动流程详解
前言 在 上一篇 文章中,我们学习了 ASP.NET Core MVC 的路由模块,那么在本篇文章中,主要是对 ASP.NET Core MVC 启动流程的一个学习. ASP.NET Core 是新一 ...
最新文章
- python学不会的表情包-python这么简单 为何这么多人学不会
- c++ 拷贝构造函数_禁止拷贝构造,禁止bug
- 【完整代码】Scala akka入门示例
- Arduino 与 SPI 结合使用 以及SPI 深层理解
- .iml文件_jetbrains误删maven 项目.iml文件后的处理方法
- 深入理解Magento – 第三章 – 布局,块和模板
- IOS开发UI控件UIScrollView和Delegate的使用
- 字节跳动副总裁喊话腾讯:停止无理由封杀飞书;Git服务器配置错误导致日产汽车源码泄露;Linux5.10.5 发布
- 物权法全文内容有哪些呢-广告外链_SEO优化的站外优化工作有哪些?
- 网易开源云原生日志系统!
- Android权限大全
- 自定义APPLEALC驱动APPLEHDA之整理codec
- oracle的权限授予,oracle权限命令
- TIC TAC TOE 井字游戏
- NOIP2015酱油记
- python爬取歌词生成词云图
- 海康设备通过SDK获取和设置设备网络参数
- 积木开发系列----Blockly初体验
- Java 中 Boolean 和 boolean的默认值和修改默认值
- 百度地图行政区优化卡顿问题
热门文章
- 苹果手机之间怎么传照片_相机与手机之间传送RAW格式照片问题!
- vue-pdf预览乱码问题、打印乱码多一页空白问题
- 难道是“写时拷贝”?
- HTML5CSS3前端入门教程---从0开始通过一个商城实例手把手教你学习PC端和移动端页面开发第10章有路网PC端主页实战整合
- Fzu 2202 犯罪嫌疑人【逻辑推理思维题】好题!
- 中南c语言作业答案,中南民族大学10套计算机C语言期末考试复习试题及答案.doc...
- Camtasia Studio录制屏幕字迹不清晰的原因
- 教师和计算机平面设计图谁更好一些,计算机平面设计专业教学浅析
- 文本域多行文本回显换行问题
- Spring-尚硅谷-学习笔记