自定义层函数需要继承layers.Layer,自定义网络需要继承keras.Model。
其内部需要定义两个函数:
1、__init__初始化函数,内部需要定义构造形式;
2、call函数,内部需要定义计算形式及返回值。

#self def layer
class MyDense(layers.Layer):#inherit layers.Layerdef __init__(self,input_dim,output_dim):#initsuper(MyDense,self).__init__()self.kernal = self.add_variable('w',[input_dim,output_dim])self.bias = self.add_variable('b',[output_dim])def call(self,inputs,training=None):#computeout = inputs @ self.kernal + self.biasreturn out
#self def network
class MyModel(keras.Model):#inherit keras.Modeldef __init__(self):#initsuper(MyModel,self).__init__()self.fc1 = MyDense(input_dim=28*28,output_dim=512)self.fc2 = MyDense(input_dim=512, output_dim=256)self.fc3 = MyDense(input_dim=256, output_dim=128)self.fc4 = MyDense(input_dim=128, output_dim=64)self.fc5 = MyDense(input_dim=64, output_dim=32)self.fc6 = MyDense(input_dim=32, output_dim=10)def call(self,inputs,training=None):#compute inputs.shape = [b,28*28]x = self.fc1(inputs)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)x = tf.nn.relu(x)x = self.fc6(x)return x

自定义的层和网络在使用上与正常一样,并无任何区别。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,Sequential,optimizers,datasets,metricsdef preprocess(x,y):x = tf.cast(tf.reshape(x,[-1]),dtype=tf.float32)/255.y = tf.cast(tf.one_hot(y,depth=10),dtype=tf.int32)return x,y#load_data
(x_train,y_train),(x_val,y_val) = datasets.mnist.load_data()
print('data: ',x_train.shape,y_train.shape,x_val.shape,y_val.shape)db = tf.data.Dataset.from_tensor_slices((x_train,y_train))
db = db.map(preprocess).shuffle(60000).batch(128)
db_val = tf.data.Dataset.from_tensor_slices((x_val,y_val))
db_val = db_val.map(preprocess).batch(128)#self def layer
class MyDense(layers.Layer):#inherit layers.Layerdef __init__(self,input_dim,output_dim):#initsuper(MyDense,self).__init__()self.kernal = self.add_variable('w',[input_dim,output_dim])self.bias = self.add_variable('b',[output_dim])def call(self,inputs,training=None):#computeout = inputs @ self.kernal + self.biasreturn out#self def network
class MyModel(keras.Model):#inherit keras.Modeldef __init__(self):#initsuper(MyModel,self).__init__()self.fc1 = MyDense(input_dim=28*28,output_dim=512)self.fc2 = MyDense(input_dim=512, output_dim=256)self.fc3 = MyDense(input_dim=256, output_dim=128)self.fc4 = MyDense(input_dim=128, output_dim=64)self.fc5 = MyDense(input_dim=64, output_dim=32)self.fc6 = MyDense(input_dim=32, output_dim=10)def call(self,inputs,training=None):#compute inputs.shape = [b,28*28]x = self.fc1(inputs)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)x = tf.nn.relu(x)x = self.fc6(x)return xnetwork = MyModel()
network.build(input_shape=[None,28*28])
network.summary()#build network
network = Sequential([layers.Dense(512,activation=tf.nn.relu),layers.Dense(256,activation=tf.nn.relu),layers.Dense(128,activation=tf.nn.relu),layers.Dense(64,activation=tf.nn.relu),layers.Dense(32,activation=tf.nn.relu),layers.Dense(10)
])
network.build(input_shape=[None,28*28])
network.summary()#input para
network.compile(optimizer=optimizers.Adam(lr=1e-2),loss = tf.losses.CategoricalCrossentropy(from_logits=True),metrics = ['accuracy'])#run network
network.fit(db,epochs=20,validation_data=db_val,validation_freq=1)

TensorFlow2.0:自定义层与自定义网络相关推荐

  1. Tensorflow2.0:使用Keras自定义网络实战

    tensorflow2.0建议使用tf.keras作为构建神经网络的高级API 接下来我就使用tensorflow实现VGG16去训练数据 背景介绍: 2012年 AlexNet 在 ImageNet ...

  2. 〖TensorFlow2.0笔记21〗自定义数据集(宝可精灵数据集)实现图像分类+补充:tf.where!

    自定义数据集(宝可精灵数据集)实现图像分类+补充:tf.where! 文章目录 一. 数据集介绍以及加载 1.1. 数据集简单描述 1.2. 程序实现步骤 1.3. 加载数据的格式 1.4. map函 ...

  3. Tensorflow2.0 之深度残差收缩网络 (DRSN)

    文章目录 DRSN 原理 残差网络 自注意力网络 软阈值化 代码实现 DRSN 原理 DRSN 由三部分组成:残差网络.自注意力网络和软阈值化. 残差网络 残差网络(或称深度残差网络.深度残差学习,英 ...

  4. tensorRT教程——tensor RT OP理解(实现自定义层,搭建网络)

    首先如果你的自定义操作可以通过一些矩阵操作来实现,那么你大可不必自己去通过plug in的方式实现,可以使用tensor RT 的OP来组合实现. 他的OP极其类似tensor flow的操作,如果看 ...

  5. 基于tensorflow2.0实现猫狗大战(搭建网络迁移学习)

    猫狗大战是kaggle平台上的一个比赛,用于实现猫和狗的二分类问题.最近在学卷积神经网络,所以自己动手搭建了几层网络进行训练,然后利用迁移学习把别人训练好的模型直接应用于猫狗分类这个数据集,比较一下实 ...

  6. Tensorflow2.0入门教程22:RNN网络实现文本分类

    RNN实现文本分类 import tensorflow as tf 下载数据集 imdb=tf.keras.datasets.imdb (train_x, train_y), (test_x, tes ...

  7. 【深度学习】(6) tensorflow2.0使用keras高层API

    各位同学好,今天和大家分享一下TensorFlow2.0深度学习中借助keras的接口减少神经网络代码量.主要内容有: 1. metrics指标:2. compile 模型配置:3. fit 模型训练 ...

  8. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(下)池化、Normalization

    <<小白学PyTorch>> 扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积.激活.初始化.正则 扩展之Tensorflow2.0 | 20 TF ...

  9. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积、激活、初始化、正则...

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 20 TF2的eager模式与求导 扩展之Tensorflow2.0 | ...

最新文章

  1. 【VMCloud云平台】私有云门户第一朵Web云(一)
  2. php curl 404,PHP使用curl判断404(网页是否存在)方法
  3. iOS开发笔记 -- 推送证书的创建及合并
  4. 泛型委托 Predicate/Func/Action
  5. python变量类型是动态的_Python 学习 第四篇:动态类型模型
  6. vs2017c语言单元测试,vs2017单元测试没反应,检测出错误,有关详细信息,请查看“测试输出”窗口...
  7. windows下github 出现Permission denied (publickey)
  8. Flask练手项目之通讯录
  9. goaheadlinux移植_goahead
  10. Linux中变量$#,$@,$0,$1,$2,$*,$$,$?的含义
  11. eclipse配置java环境_java环境搭建(Eclipse)
  12. LCD12864(ST7565P)字符汉字显示(STM32F103)
  13. 酷家乐面试经历(图形引擎渲染工程师)
  14. 微信公众号授权登录(应用免登陆)
  15. 2020年10月计算机语言排名,最新!2020年10月编程语言排行榜出炉
  16. 计算机英语单词怎么巧背,怎么快速背记英语单词
  17. 安卓如何调出软键盘_Android软键盘显示模式及打开和关闭方式(推荐)
  18. linux连接校园网wifi,Linux/Ubuntu 16.04 使用校园网客户端Dr.com DrClient 有线连网,同时开启WiFi热点...
  19. 形式语言与自动机 Part.6 图灵机
  20. 没有网络电脑计算机还能用吗,电脑连不上公用网络怎么办

热门文章

  1. python2 与python3 区别的总结 持续更新中......
  2. 第18章 Redis数据结构常用命令
  3. 2017.12.1T19_B2_1zuoye
  4. 201506110135陈若倩词法分析实验报告
  5. 【MVC 过滤器的应用】ASP.NET MVC 如何统计 Action 方法的执行时间
  6. GridView网格布局
  7. poj 1270 Following Orders
  8. 用 Python 分析上网记录,发现了很多不可思议的事
  9. 8.11 NOIP模拟测试17 入阵曲+将军令+星空
  10. JavaScript基础知识(三个判断、三个循环)