本例主要测试Tensorflow C++ API中的ops::BatchMatMul算子。
整体来说这个算子比较简单。但是难在官网没有例子。Tensorflow的单测也写得不到位。
话不多说,上代码。
代码结构如下,

conanfile.txt

 [requires]gtest/1.10.0glog/0.4.0protobuf/3.9.1eigen/3.4.0dataframe/1.20.0opencv/3.4.17boost/1.76.0abseil/20210324.0xtensor/0.23.10[generators]cmake

CMakeLists.txt

cmake_minimum_required(VERSION 3.3)project(test_math_ops)set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")set(CMAKE_CXX_STANDARD 17)
add_definitions(-g)include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)
pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)set(ARROW_INCLUDE_DIRS ${PKG_PARQUET_INCLUDE_DIRS} ${PKG_ARROW_INCLUDE_DIRS} ${PKG_ARROW_COMPUTE_INCLUDE_DIRS} ${PKG_ARROW_CSV_INCLUDE_DIRS} ${PKG_ARROW_DATASET_INCLUDE_DIRS} ${PKG_ARROW_FS_INCLUDE_DIRS} ${PKG_ARROW_JSON_INCLUDE_DIRS})set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${ARROW_INCLUDE_DIRS})set(ARROW_LIBS PkgConfig::PKG_PARQUET PkgConfig::PKG_ARROW PkgConfig::PKG_ARROW_COMPUTE PkgConfig::PKG_ARROW_CSV PkgConfig::PKG_ARROW_DATASET PkgConfig::PKG_ARROW_FS PkgConfig::PKG_ARROW_JSON)include_directories(${INCLUDE_DIRS})file( GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) file( GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/queue_runner.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/coordinator.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/status.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp)add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC ${ARROW_LIBS})foreach( test_file ${test_file_list} )file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})string(REPLACE ".cpp" "" file ${filename})add_executable(${file}  ${test_file})target_link_libraries(${file} PUBLIC ${PROJECT_NAME}_lib)
endforeach( test_file ${test_file_list})

tf_math2_test.cpp

#include <string>
#include <vector>
#include <glog/logging.h>
#include "death_handler/death_handler.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"using namespace tensorflow;int main(int argc, char** argv) {FLAGS_log_dir = "./";FLAGS_alsologtostderr = true;// 日志级别 INFO, WARNING, ERROR, FATAL 的值分别为0、1、2、3FLAGS_minloglevel = 0;Debug::DeathHandler dh;google::InitGoogleLogging("./logs.log");::testing::InitGoogleTest(&argc, argv);int ret = RUN_ALL_TESTS();return ret;
}TEST(TfArthimaticTests, BatchMatMul) {// BatchMatMul  测试// Refers to: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul// 2 * 1 * 2// 2 * 2 * 3// = // 2 * 1 * 3Scope root = Scope::NewRootScope();auto left_ = test::AsTensor<int>({1, 2, 3, 4}, {2, 1, 2});/*** @brief Left param* {{1, 2},*  {3, 4}}*/auto right_ = test::AsTensor<int>({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8}, {2, 2, 3});/*** @brief Right param*  {{{1, 2, 3}, {4, 1, 2}},*   {{3, 4, 5}, {6, 7, 8}}}*//*** @brief Result* {{9, 4, 7},*  {33, 40, 47}}*/auto batch_op = ops::BatchMatMul(root, left_, right_);ClientSession session(root);std::vector<Tensor> outputs;session.Run({batch_op.output}, &outputs);test::PrintTensorValue<int>(std::cout, outputs[0]);test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({9, 4, 7, 33, 40, 47}, {2, 1, 3}));
}TEST(TfArthimaticTests, BatchMatMulAdjXY) {// BatchMatMul  测试// Refers to: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul// 2 * 1 * 2// 2 * 2 * 3// = // 2 * 1 * 3Scope root = Scope::NewRootScope();auto left_ = test::AsTensor<int>({1, 2, 3, 4}, {2, 2, 1});/*** @brief Left param* {{{1}, *   {2}},*  {{3},*   {4}}}*/auto right_ = test::AsTensor<int>({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8}, {2, 3, 2});/*** @brief Right param*  {{{1, 2}, *   {3, 4}, *   {1, 2}}, **   {{3, 4}, *   {5, 6}, *   {7, 8}}  * }*//*** @brief Result* {{5, 11, 5},*  {25, 39, 53}}*/auto attrs = ops::BatchMatMul::AdjX(true).AdjY(true);auto batch_op = ops::BatchMatMul(root, left_, right_, attrs);ClientSession session(root);std::vector<Tensor> outputs;session.Run({batch_op.output}, &outputs);test::PrintTensorValue<int>(std::cout, outputs[0]);test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({5, 11, 5, 25, 39, 53}, {2, 1, 3}));
}

程序输出如下,代表两个算子均测试通过。

Tensorflow C++使用ops::BatchMatMul实现特征批量乘法相关推荐

  1. 针对AttributeError: ‘tensorflow.python.framework.ops.EagerTensor‘ ....no attribute ‘reshape‘问题的解决办法。

    操作系统:Win10,编译工具:notebook,语言:python 在学习Mnist手写数据集的时候,遇到这种问题.使用Tensorflow2.2-gpu版本 plt.figure(figsize= ...

  2. ‘tensorflow.python.framework.ops.EagerTensor‘ object has no attribute ‘reshape‘

    'tensorflow.python.framework.ops.EagerTensor' object has no attribute 'reshape' 可以将其用numpy读取后再reshap ...

  3. excel批量处理php,excel批量乘法怎么操作

    excel批量乘法怎么操作? 1.如下图,C.D.E列批量乘以10: 2.找一个单元格,输入被乘数值: 3.选定被乘数所在单元格,点右键,点复制: 4.选定需要批量进行乘法数据所在区域: 5.点右键, ...

  4. tensorflow中的ops(或者说op)的理解

    转自:https://www.cnblogs.com/tsiangleo/p/6145112.html 本文是在阅读官方文档后的一些个人理解. 官方文档地址:https://www.tensorflo ...

  5. 基于Python3+Scapy的数据包流量特征批量分析工具

    基于Python3+Scapy的网络数据包批量分析工具 项目源码 适用范围以及使用说明 背景 环境准备及运行说明 常见协议分析识别 TCP协议识别 UDP协议识别 输出TXT文档信息 SSL NAME ...

  6. TensorFlow 特征列介绍

    文 / TensorFlow 团队 欢迎阅读介绍 TensorFlow 数据集和估算器系列的第 2 部分(第一部分戳这里).我们将在这篇文章中介绍特征列 (Feature Column) - 一种说明 ...

  7. TensorFlow数据读取方式:Dataset API,以及如何查看dataset:DatasetV1Adapter的方法

    TensorFlow数据读取方式:Dataset API Datasets:一种为TensorFlow 模型创建输入管道的新方式.把数组.元组.张量等转换成DatasetV1Adapter格式 Dat ...

  8. TensorFlow 1.x 深度学习秘籍:1~5

    原文:TensorFlow 1.x Deep Learning Cookbook 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自[ApacheCN 深度学习 译文集],采用译后编辑(MTP ...

  9. TensorFlow 强化学习:1~5

    原文:Reinforcement Learning With TensorFlow 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自[ApacheCN 深度学习 译文集],采用译后编辑(MT ...

最新文章

  1. Linux C编程之一:Linux下c语言的开发环境
  2. 进阶必备:素数筛法(欧拉,埃氏筛法)
  3. 温州大学《深度学习》课程课件(九、目标检测)
  4. PowerDesigner15官方正式版+注册补丁
  5. Spring Boot定时任务-Quartz基本使用
  6. 2017年12月英语四级翻译预测
  7. Finding Structure in Time论文解读
  8. 怎么找到项目中所有同名的类_26岁转行程序员的成长历程--Day03从内存层面理解类和对象...
  9. 连载四:Oracle升级文章大全(完结篇)
  10. git忽略文件或者文件夹
  11. 20145326蔡馨熤《网络对抗》——信息搜集与漏洞扫描
  12. Merge Two Sorted Lists Leetcode
  13. 小程序 ---- (获取手机号码)
  14. springboot的配置文件加载的顺序,以及在不同位置配置下,加载的顺序
  15. matlab随机欠采样,欠采样技术
  16. 关于iPhone 5的适配
  17. XPDL学习与分享 二 XPDL整体结构
  18. TensorFlow2.0教程-使用keras训练模型
  19. 禁用键盘快捷键_如何在Windows中使用键盘快捷键临时禁用键盘
  20. MCS51延时程序分析

热门文章

  1. 一碗阳春面的故事--你还记得吗?
  2. nlp算法工程师英语
  3. 关于 ChatGPT 必看的 10 篇论文
  4. 厦门故事(三):枫叶随风飘落,重重地摔在了地面上
  5. [AcWing] 2058. 笨拙的手指(C++实现)秦九韶算法
  6. 2873-36-1,Gancidin W,CYCLO(L-LEU-L-PRO),cyclo-L-Leu-L-Pro,环(L-脯氨酰-L-亮氨酰)
  7. 2020-11-10总结
  8. IPTV与DTV:竞争还是共存?
  9. buck降压斩波电路
  10. 可执行的移动端网站seo技术