上一篇文章讲解如何使用tf.keras 快速搭建网络,这篇讲解自定义一个神经网络结构。
使用 Sequential 可以快速搭建网络结构,但是如果网络包含跳连等其他复杂网络结构, Sequential 就无法表示了。 这就需要使用 class 来声明网络结构。。大部分代码一样,只是Sequential改为自定义的class。
使用 class 类封装网络结构,如下所示是一个 class 模板, MyModel 表示声明的神经网络的名字,括号中的 Model 表示创建的类需要继承 tensorflow 库中的 Model 类。 类中需要定义两个函数, init()函数为类的构造函数用于初始化类的参数, spuer(MyModel,self).init()这行表示初始化父类的参数。 之后便可初始化网络结构,搭建出神经网络所需的各种网络结构块。 call()函数中调用__init__()函数中完成初始化的网络块,实现前向传播并返回推理值。

class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='sigmoid', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return ymodel = IrisModel()

完整代码:

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='sigmoid', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return ymodel = IrisModel()model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()

深度学习TensorFlow学习-自定义网络相关推荐

  1. 深度学习---TensorFlow学习笔记:搭建CNN模型

    转载自:http://jermmy.xyz/2017/02/16/2017-2-16-learn-tensorflow-build-cnn-model/ 最近跟着 Udacity 上的深度学习课程学了 ...

  2. 动手学深度学习(tensorflow)---学习笔记整理(一、预备知识篇)

    学习视频来源为b站动手学深度学习系列视频:https://space.bilibili.com/209599371/channel/detail?cid=23541 由于上述视频为MXNet/Gluo ...

  3. 【深度学习】(7) 交叉验证、正则化,自定义网络案例:图片分类,附python完整代码

    各位同学好,今天和大家分享一下TensorFlow2.0深度学习中的交叉验证法和正则化方法,最后展示一下自定义网络的小案例. 1. 交叉验证 交叉验证主要防止模型过于复杂而引起的过拟合,找到使模型泛化 ...

  4. 【深度学习】生成对抗网络(GAN)的tensorflow实现

    [深度学习]生成对抗网络(GAN)的tensorflow实现 一.GAN原理 二.GAN的应用 三.GAN的tensorflow实现 参考资料 GAN( Generative Adversarial ...

  5. 【深度学习】使用tensorflow实现VGG19网络

    转载注明出处:http://blog.csdn.net/accepthjp/article/details/70170217 接上一篇AlexNet,本文讲述使用tensorflow实现VGG19网络 ...

  6. 深度学习-Tensorflow2.2-预训练网络{7}-迁移学习基础针对小数据集-19

    使用预训练网络(迁移学习) 预训练网络是一个保存好的之前已在大型数据集(大规模图像分类任务)上训练好的卷积神经网络 如果这个原始数据集足够大且足够通用,那么预训练网络学到的特征的空间层次结构可以作为有 ...

  7. 深度学习之生成对抗网络(8)WGAN-GP实战

    深度学习之生成对抗网络(8)WGAN-GP实战 代码修改 完整代码 WGAN WGAN_train 代码修改  WGAN-GP模型可以在原来GAN代码实现的基础上仅做少量修改.WGAN-GP模型的判别 ...

  8. 深度学习框架tensorflow学习与应用——代码笔记11(未完成)

    11-1 第十周作业-验证码识别(未完成) #!/usr/bin/env python # coding: utf-8# In[1]:import os import tensorflow as tf ...

  9. [转] 介绍深度学习和长期记忆网络

    机器学习,深度学习 101 IBM Power Systems 入门 Beth Hoffman 和 Rupashree Bhattacharya 2017 年 7 月 04 日发布 WeiboGoog ...

最新文章

  1. StartSSL申请全过程 让网站拥有免费SSL证书
  2. 【 MATLAB 】filter 函数介绍(一维数字滤波器)
  3. BZOJ1975 [Sdoi2010]魔法猪学院 k短路
  4. mysql查询包含字符串的记录_MySQL查询字符串中包含字符的记录
  5. 用模板类实现shared_ptr和unique_ptr
  6. CentOS7默认安装PHP不支持mysql的办法
  7. 2021年最新C语言教程入门,C语言自学教程(最全整理)
  8. linux之vmlinux、vmlinuz、System.map和/proc/kallsyms简介
  9. 成功在fedora 13 上安装 了libfetion
  10. 2022-2028年中国长租公寓行业市场运行格局及发展策略分析报告
  11. Eclipse Mars2中Augular2开发环境的搭建过程记录
  12. 京东大图在服务器哪个文件夹,京东图片管理在哪里?怎么使用?
  13. 截止频率计算公式wc_计算截止频率Wc的快速方法
  14. 只需三步!使用3DCG软件Blender制作时尚图片
  15. SHELLEXECUTEINFO控制外部进程
  16. 图像形状及数量识别(matlab实现)
  17. 交换机连接控制器_干货丨FIT控制器与eMotion LV1的配置场景介绍
  18. 【独行秀才】macOS Big Sur 11.5.1 正式版(20G80)原版镜像
  19. 电子元器件的分类有哪些?
  20. 思博伦设备修改接口速率的三种方式

热门文章

  1. 微软云计算操作系统Windows Azure 平台——云+端全面攻略
  2. 第十一章——电子商务网站用户行为分析及服务推荐
  3. [引擎搭建记录] 分块/分簇延迟渲染
  4. CSS 文章段落样式
  5. 【办公-WORD】修改字母大小写的热键或快捷键
  6. 双11专栏 | EdgeRec:电商信息流的端上推荐系统
  7. 电脑监控软件应如何选择?
  8. 设计师常用的Windows软件合集
  9. 如何轻松地导出照片的EXIF信息
  10. php 图片压缩 保留exif,Android Bitmap小技巧 - 压缩时保留图片的Exif信息