深度学习TensorFlow学习-自定义网络
上一篇文章讲解如何使用tf.keras 快速搭建网络,这篇讲解自定义一个神经网络结构。
使用 Sequential 可以快速搭建网络结构,但是如果网络包含跳连等其他复杂网络结构, Sequential 就无法表示了。 这就需要使用 class 来声明网络结构。。大部分代码一样,只是Sequential改为自定义的class。
使用 class 类封装网络结构,如下所示是一个 class 模板, MyModel 表示声明的神经网络的名字,括号中的 Model 表示创建的类需要继承 tensorflow 库中的 Model 类。 类中需要定义两个函数, init()函数为类的构造函数用于初始化类的参数, spuer(MyModel,self).init()这行表示初始化父类的参数。 之后便可初始化网络结构,搭建出神经网络所需的各种网络结构块。 call()函数中调用__init__()函数中完成初始化的网络块,实现前向传播并返回推理值。
class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='sigmoid', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return ymodel = IrisModel()
完整代码:
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='sigmoid', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return ymodel = IrisModel()model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()
深度学习TensorFlow学习-自定义网络相关推荐
- 深度学习---TensorFlow学习笔记:搭建CNN模型
转载自:http://jermmy.xyz/2017/02/16/2017-2-16-learn-tensorflow-build-cnn-model/ 最近跟着 Udacity 上的深度学习课程学了 ...
- 动手学深度学习(tensorflow)---学习笔记整理(一、预备知识篇)
学习视频来源为b站动手学深度学习系列视频:https://space.bilibili.com/209599371/channel/detail?cid=23541 由于上述视频为MXNet/Gluo ...
- 【深度学习】(7) 交叉验证、正则化,自定义网络案例:图片分类,附python完整代码
各位同学好,今天和大家分享一下TensorFlow2.0深度学习中的交叉验证法和正则化方法,最后展示一下自定义网络的小案例. 1. 交叉验证 交叉验证主要防止模型过于复杂而引起的过拟合,找到使模型泛化 ...
- 【深度学习】生成对抗网络(GAN)的tensorflow实现
[深度学习]生成对抗网络(GAN)的tensorflow实现 一.GAN原理 二.GAN的应用 三.GAN的tensorflow实现 参考资料 GAN( Generative Adversarial ...
- 【深度学习】使用tensorflow实现VGG19网络
转载注明出处:http://blog.csdn.net/accepthjp/article/details/70170217 接上一篇AlexNet,本文讲述使用tensorflow实现VGG19网络 ...
- 深度学习-Tensorflow2.2-预训练网络{7}-迁移学习基础针对小数据集-19
使用预训练网络(迁移学习) 预训练网络是一个保存好的之前已在大型数据集(大规模图像分类任务)上训练好的卷积神经网络 如果这个原始数据集足够大且足够通用,那么预训练网络学到的特征的空间层次结构可以作为有 ...
- 深度学习之生成对抗网络(8)WGAN-GP实战
深度学习之生成对抗网络(8)WGAN-GP实战 代码修改 完整代码 WGAN WGAN_train 代码修改 WGAN-GP模型可以在原来GAN代码实现的基础上仅做少量修改.WGAN-GP模型的判别 ...
- 深度学习框架tensorflow学习与应用——代码笔记11(未完成)
11-1 第十周作业-验证码识别(未完成) #!/usr/bin/env python # coding: utf-8# In[1]:import os import tensorflow as tf ...
- [转] 介绍深度学习和长期记忆网络
机器学习,深度学习 101 IBM Power Systems 入门 Beth Hoffman 和 Rupashree Bhattacharya 2017 年 7 月 04 日发布 WeiboGoog ...
最新文章
- StartSSL申请全过程 让网站拥有免费SSL证书
- 【 MATLAB 】filter 函数介绍(一维数字滤波器)
- BZOJ1975 [Sdoi2010]魔法猪学院 k短路
- mysql查询包含字符串的记录_MySQL查询字符串中包含字符的记录
- 用模板类实现shared_ptr和unique_ptr
- CentOS7默认安装PHP不支持mysql的办法
- 2021年最新C语言教程入门,C语言自学教程(最全整理)
- linux之vmlinux、vmlinuz、System.map和/proc/kallsyms简介
- 成功在fedora 13 上安装 了libfetion
- 2022-2028年中国长租公寓行业市场运行格局及发展策略分析报告
- Eclipse Mars2中Augular2开发环境的搭建过程记录
- 京东大图在服务器哪个文件夹,京东图片管理在哪里?怎么使用?
- 截止频率计算公式wc_计算截止频率Wc的快速方法
- 只需三步!使用3DCG软件Blender制作时尚图片
- SHELLEXECUTEINFO控制外部进程
- 图像形状及数量识别(matlab实现)
- 交换机连接控制器_干货丨FIT控制器与eMotion LV1的配置场景介绍
- 【独行秀才】macOS Big Sur 11.5.1 正式版(20G80)原版镜像
- 电子元器件的分类有哪些?
- 思博伦设备修改接口速率的三种方式