cpp的例子

device_malloc

  • cpp没有用具体数值初始化 float *d_from_tensor = NULL;device_malloc(&d_from_tensor, batch_size * seq_len * hidden_dim);
  • https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/sample/cpp/transformer_fp32.cc#L35-L38 直接用的cudaMalloc
void device_malloc(float** ptr, int size) // cudaMalloc函数为什么是二级指针的解释https://blog.csdn.net/CaiYuxingzzz/article/details/121112273
{cudaMalloc((void**)ptr, sizeof(float) * size);
}

allocator

  • allocator用于分配attr_out_buf_
https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/bert_encoder_transformer.h#L131-L135
buf_ = reinterpret_cast<DataType_*>(allocator_.malloc(sizeof(DataType_) * buf_size * 6));
  • 然后将这些参数和encoder_param打包成multi_head_init_param
    在初始化(encoder_transformer_->initialize)时传给attention_->initialize(multi_head_init_param);
    attention_->initialize则只需将传入的参数初始化给attention对象的参数,等forward时调用自己的参数
接口包含两个方法malloc,free
class IAllocator{public:virtual void* malloc(size_t size) const = 0;virtual void free(void* ptr) const = 0;
};
//AllocatorTypeyouenum class AllocatorType{CUDA, TF}; 用的应该是CUDA的
template<>
class Allocator<AllocatorType::CUDA> : public IAllocator{const int device_id_;public:Allocator(int device_id): device_id_(device_id){}void* malloc(size_t size) const {void* ptr = nullptr;int o_device = 0;check_cuda_error(get_set_device(device_id_, &o_device));check_cuda_error(cudaMalloc(&ptr, size));check_cuda_error(get_set_device(o_device));return ptr;}void free(void* ptr) const {int o_device = 0;check_cuda_error(get_set_device(device_id_, &o_device));check_cuda_error(cudaFree(ptr));check_cuda_error(get_set_device(o_device));return;}
};
fastertransformer::Allocator<AllocatorType::CUDA> allocator(0); // 0是device_id_

encoder_param

  • EncoderInitParam encoder_param; //init param here 包含参数的结构体,成员记录了GPU数据的地址

initialize

  BertEncoderTransformer<EncoderTraits_> *encoder_transformer_ = new BertEncoderTransformer<EncoderTraits_>(allocator, batch_size, from_seq_len, to_seq_len, head_num, size_per_head);encoder_transformer_->initialize(encoder_param);

trt_plugin的例子

将数值放入vector

  • https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/sample/tensorRT/transformer_trt.cc#L108-L136
  • 先分配地址
    host_malloc(&h_attr_kernel_Q, hidden_dim * hidden_dim);
  • 然后进行赋值
    h_attr_kernel_Q[i] = 0.001f;
   std::vector<T* > layer_param;layer_param.push_back(h_attr_kernel_Q);将值打包params.push_back(layer_param);}cudaStream_t stream;cudaStreamCreate(&stream);TRT_Transformer<T>* trt_transformer = new TRT_Transformer<T>(batch_size, seq_len, head_num, hidden_dim, layers);trt_transformer->build_engine(params);trt_transformer->do_inference(batch_size, h_from_tensor, h_attr_mask, h_transformer_out, stream);delete trt_transformer;
  • 构建TRT_Transformer时会调用算子插件,权重在void build_engine(std::vector<std::vector<T* > > &weights)时传入
    https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/trt_plugin/trt_model.h#L75-L77
auto plugin = new TransformerPlugin<T>(hidden_dim_, head_num_, seq_len_, batch_size_, point2weight(weights[i][0], hidden_dim_ * hidden_dim_),
  • 创建TransformerPlugin实例时会传入权重
TransformerPlugin(int hidden_dim, int head_num, int seq_len, int max_batch_size,const nvinfer1::Weights &w_attr_kernel_Q,...
  • 这里就是和cpp例子的不同了,其使用权重w_attr_kernel_Q
  • https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/trt_plugin/bert_transformer_plugin.h#L103
cudaMallocAndCopy(d_attr_kernel_Q_, w_attr_kernel_Q, hidden_dim * hidden_dim);
  • cudaMallocAndCopy定义在https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/trt_plugin/bert_transformer_plugin.h#L338-L352
    static void cudaMallocAndCopy(T *&dpWeight, const nvinfer1::Weights &w, int nValue) {assert(w.count == nValue);check_cuda_error(cudaMalloc(&dpWeight, nValue * sizeof(T)));check_cuda_error(cudaMemcpy(dpWeight, w.values, nValue * sizeof(T), cudaMemcpyHostToDevice));T* data = (T*)malloc(sizeof(T) * nValue);cudaMemcpy(data, dpWeight, sizeof(T) * nValue, cudaMemcpyDeviceToHost);}static void cudaMallocAndCopy(T*&dpWeight, const T *&dpWeightOld, int nValue) {check_cuda_error(cudaMalloc(&dpWeight, nValue * sizeof(T)));check_cuda_error(cudaMemcpy(dpWeight, dpWeightOld, nValue * sizeof(T), cudaMemcpyDeviceToDevice));}

cg

  • https://github.com/NVIDIA/TensorRT/blob/release/8.5/demo/Diffusion/models.py

FasterTransformer 005 初始化:如何将参数传给模型?相关推荐

  1. 通过BeanShell获取UUID并将参数传递给Jmeter

    有些HTTPS请求报文的报文体中包含由客户端生成的UUID,在用Jmeter做接口自动化测试的时候,因为越过了客户端,直接向服务器端发送报文,所以,需要在Jmeter中通过beanshell获取UUI ...

  2. 数组作为函数的参数传参时,数组名会退化为指针

    1.数组作为函数的参数传参时,数组名会退化为指针 数组作为函数的参数传参时,数组名会退化为指针,数值传参时,需要把数值的长度一起传过去,另外,sizeof()运算符包含字符串的哨兵'/0',而strl ...

  3. java+hadoop配置参数_将Hadoop参数传递给Java代码

    我有一个Uber jar执行一些级联ETL任务. jar的执行方式如下: hadoop jar munge-data.jar 我希望在作业启动时将参数传递给jar,例如 hadoop jar mung ...

  4. 如何将命令行参数传递给Node.js程序?

    我有一个用Node.js编写的Web服务器,我想使用一个特定的文件夹启动. 我不确定如何在JavaScript中访问参数. 我正在像这样运行节点: $ node server.js folder 这是 ...

  5. GoJS超详细入门(插件使用无非:引包、初始化、配参数(json)、引数据(json)四步)...

    GoJS超详细入门(插件使用无非:引包.初始化.配参数(json).引数据(json)四步) 一.总结 一句话总结:插件使用无非:引包.初始化.配参数(json).引数据(json)四步. 1.goj ...

  6. DL之DNN优化技术:采用三种激活函数(sigmoid、relu、tanh)构建5层神经网络,权重初始值(He参数初始化和Xavier参数初始化)影响隐藏层的激活值分布的直方图可视化

    DL之DNN优化技术:采用三种激活函数(sigmoid.relu.tanh)构建5层神经网络,权重初始值(He参数初始化和Xavier参数初始化)影响隐藏层的激活值分布的直方图可视化 目录

  7. DL之DNN优化技术:自定义MultiLayerNet【5*100+ReLU】对MNIST数据集训练进而比较三种权重初始值(Xavier参数初始化、He参数初始化)性能差异

    DL之DNN优化技术:自定义MultiLayerNet[5*100+ReLU]对MNIST数据集训练进而比较三种权重初始值(Xavier参数初始化.He参数初始化)性能差异 导读 #思路:观察不同的权 ...

  8. DL之DNN优化技术:DNN中参数初始化【Lecun参数初始化、He参数初始化和Xavier参数初始化】的简介、使用方法详细攻略

    DL之DNN优化技术:DNN中参数初始化[Lecun参数初始化.He参数初始化和Xavier参数初始化]的简介.使用方法详细攻略 导读:现在有很多学者认为,随着BN层的提出,权重初始化可能已不再那么紧 ...

  9. python get请求 url传参_用Python-get方法向页面发起请求,参数传不进去是怎么回事...

    源自:4-1 接口测试工具-python-get接口实战 用Python-get方法向页面发起请求,参数传不进去是怎么回事 #-*-coding:utf-8-*- import urllib impo ...

最新文章

  1. Basic Level 1023. 组个最小数 (20)
  2. 修改pip的源repository
  3. 开源的pop3和smtp组件(支持中文及SSL)
  4. inside-the-linux-kernel-full
  5. linux下u盘病毒msdos,浅谈U盘病毒——MS-DOS.com 以及做最便民的杀毒软件
  6. android get width单位是什么意思,浅析Android中getWidth()和getMeasuredWidth()的区别
  7. java访问hdfs_HDFS的java访问接口
  8. 编码人员和美工的配合问题
  9. arcgis 合并名字相同的要素_【转】ArcGIS中各种合并要素(Union、Merge、Append、Dissolve)的异同点分析...
  10. uos系统安装教程_统一操作系统UOS下载&安装图文教程:尝鲜记(一)
  11. SaltStack之数据系统
  12. Lightweight OpenPose
  13. 触发器referencing old as old new as new
  14. CryEngine 动态添加模型
  15. linux---finger命令
  16. voicewo在线语音识别转换jQuery插件
  17. 概率论————思维导图(上岸必备)(多维随机变量及其分布)
  18. Ubuntu 16.04 LTS安装sogou输入法详解
  19. 云端身份证识别OCr
  20. 管家婆云辉煌快速实现远程云打印

热门文章

  1. IM即时通讯开发数据库用NoSQL还是SQL?
  2. 性能为王,科视 1DLP 激光投影机以“质”服人
  3. 【音视频基础】(十三):YUV颜色空间之YUV和YCbCr
  4. 单片机应用系统设计技术——可预设电压的数控电源
  5. 漏洞复现|(CVE-2019-3396)Confluence文件读取远程命令执行
  6. Hbase学习笔记(一)
  7. SpringBoot2.1.4整合log4j2保存日志到MySQL中
  8. php+yii框架,【Yii框架 1 】PHP框架,Yii概述
  9. [英语阅读]乌克兰一村庄拟改名为“杰克逊”
  10. 为什么会出来山寨版春晚呢|春晚的缺点