Keras自定义可训练参数是在自定义层中实现的,因此需要我们自己编写一个层来实现我们需要的功能。话不多说,直接上实例。

假设我们需要自定义一个可学习的权重矩阵来对某一层的数据进行转换,则可以通过下面代码实现:

from keras import backend as K
from keras.layers import Layerclass MyLayer(Layer):def __init__(self, output_dim, **kwargs):self.output_dim = output_dimsuper(MyLayer, self).__init__(**kwargs)def build(self, input_shape):# 创建一个可训练的权重变量矩阵self.kernel = self.add_weight(name='kernel',shape=(input_shape[1], self.output_dim),  # 假设输入tensor只有一个维度(不算batch的维度)initializer='uniform',trainable=True)      # 如果要定义可训练参数这里一定要选择Truesuper(MyLayer, self).build(input_shape)  # 这行代码一定要加上,super主要是调用MyLayer的父类(Layer)的build方法。def call(self, x):return K.dot(x, self.kernel)     # 该层要实现的功能def compute_output_shape(self, input_shape):return (input_shape[0], self.output_dim)   # 指定输出维度

其中,类MyLayer中各个方法的说明如下:

build(input_shape): 可学习变量在这里定义。

call(x): 你打算用这个层实现什么功能要在这里定义,参数x就是输入该层的tensor。一般是只需要考虑x就行。

compute_output_shape(input_shape): 如果该层修改了输入tensor的维度,则需要在这里指定要返回的tensor维度。

Keras自定义可训练参数相关推荐

  1. keras + tensorflow —— 训练参数数目的计算

    1. RNN 模型 Embedding Embedding(input_dim, output_dim,input_length) input_dim 表示字典的大小: outpu_dim 则表示嵌入 ...

  2. Keras自定义损失函数出现:ValueError: Unknown loss function: focal_loss

    Keras自定义损失函数出现:ValueError: Unknown loss function: focal_loss 1.软件环境 2.问题描述 3.解决方法 4.结果预览 1.软件环境 Wind ...

  3. keras自定义loss

    loss是model.compile编译时所需的参数之一,可以用损失函数名或者 TensorFlow 符号函数: #损失函数名 model.compile(loss='mean_squared_err ...

  4. Tensorflow |(5)模型保存与恢复、自定义命令行参数

    Tensorflow |(1)初识Tensorflow Tensorflow |(2)张量的阶和数据类型及张量操作 Tensorflow |(3)变量的的创建.初始化.保存和加载 Tensorflow ...

  5. 深度学习每层的通道数如何计算_深度学习基础系列(一)| 一文看懂用kersa构建模型的各层含义(掌握输出尺寸和可训练参数数量的计算方法)...

    我们在学习成熟网络模型时,如VGG.Inception.Resnet等,往往面临的第一个问题便是这些模型的各层参数是如何设置的呢?另外,我们如果要设计自己的网路模型时,又该如何设置各层参数呢?如果模型 ...

  6. keras 自定义评估函数和损失函数loss训练模型后加载模型出现ValueError: Unknown metric function:fbeta_score

    keras分类回归的损失函数与评价指标 目标函数 (1)mean_squared_error / mse 均方误差,常用的目标函数,公式为((y_pred-y_true)**2).mean() (2) ...

  7. pytorch Dataset, DataLoader产生自定义的训练数据

    pytorch Dataset, DataLoader产生自定义的训练数据 目录 pytorch Dataset, DataLoader产生自定义的训练数据 1. torch.utils.data.D ...

  8. YOLOv5自定义数据集训练

    YOLOv5自定义数据集训练 简介 本文介绍如何在自己的VOC格式数据集上训练YOLO5目标检测模型. VOC数据集格式 首先,先来了解一下Pascal VOC数据集的格式,该数据集油5个部分组成,文 ...

  9. 09_keras_Tuner使用keras Tuner调整超参数(超参数优化)

    """ 09_keras_Tuner使用keras Tuner调整超参数 """import tensorflow as tf from t ...

最新文章

  1. 关于:项“ConnectionString”已添加
  2. c语言搜索关键字吗,c语言-以关键字搜索程序
  3. 为什么要importmodulepython_python – 为什么“import”这样实现?
  4. 【原创】StreamInsight查询系列(十九)——查询模式之检测异常
  5. 常见的通配符_8、数据库常见操作
  6. Leetcode 124.二叉树中的最大路径和
  7. Xshell连接mysql数据库乱码问题解决思路总结
  8. opencv contourArea() 计算面积(转)
  9. 网络_简单实现远程唤醒与远程控制(Teamviewer)
  10. 图像传感器(智能相机技术)
  11. TeraTerm下载方法
  12. python处理xps文件_WFP: 读取XPS文件或将word、txt文件转化为XPS文件
  13. 数字孪生中的人工智能——技术现状、挑战和未来研究课题
  14. 金桔蓝牙LoRa主被动一体定位系统原理
  15. 都市调频广播 2009年节目广告运行表
  16. 平板游戏交互式设计的10大规则
  17. ctfshow终极考核(一键通关脚本)
  18. bom成本分析模型_如何计算一台汽车的BOM成本?
  19. 【仙女踩坑实录】Macbook修改文件创建时间
  20. TIA博途SCL编程学习16_歌德巴赫猜想验证

热门文章

  1. 服装行业ERP系统有哪些基本功能?
  2. 专业精神-希波克拉底的誓言(转载)
  3. android notification应用之自定义来电通知
  4. 山东不符合申报高新技术企业的条件
  5. 康蒂尼药业再次冲刺港股:9个月营收4.4亿 龙磐创投是股东
  6. 什么是集成测试?集成测试方法有哪些?
  7. 学生消费记录管理系统(C语言 结构体, 链表)
  8. 最后半天时间,支付宝等第三方支付机构备付金必须100%上交
  9. 音质好的蓝牙耳机有哪些?音质好的蓝牙耳机测评
  10. 全排列__正月点灯笼视频笔记