tensorflow笔记系列文章均参考自中国大学Mooc上北京大学软件与微电子学院曹建老师的《Tensorflow笔记2》课程。曹建老师讲的非常棒,受益良多,强烈建议tensorflow初学者学习。

使用keras搭建神经网络通常有以下六步:

  • import 相关模块
  • train, test 说明训练集和测试集
  • model = tf.keras.models.Sequential 搭建网络结构,逐层描述每层网络
  • model.compile 配置训练方法,例如优化器、损失函数、评测指标
  • model.fit 执行训练过程
  • model.summary 打印出网络的结构和参数统计

Sequential()用法

Sequential()可以认为是一个容器,里面封装了一个神经网络的结构,在Sequential()要描述从输入层到输出层每一层的网络结构,语法如下:

model = tf.keras.models.Sequential([网络结构])

网络结构举例:

拉直层tf.keras.layers.Flatten(),这一层不含计算,只是将输入数据拉直变成一维数组

全连接层tf.kears.layers.Dense(神经元个数,activation="激活函数",kernel_regularizer=哪种正则化),其中激活函数可选:relu、softmax、sigmoid、tanh,正则化可选:tf.kears.regularizers.l1()、tf.keras.regularizers.l2()

卷积层tf.keras.layers.Conv2D(filters=卷积核个数,kernel_size=卷积核尺寸,strides=卷积步长,padding="valid"或者"same")

LSTM层:tf.keras.layers.LSTM()

compile()用法

compile指定网络的优化器、损失函数以及评测指标,语法如下:

model.compile(optimizer=优化器,loss=损失函数, metrics=["准确率"])

Optimizer可选

  • ‘sgd’ or tf.keras.optimizer.SGD(lr=学习率,momentum=动量参数)
  • ‘adagrad’ or tf.keras.optimizers.Adagrad(lr=学习率)
  • ‘adadelta’ or tf.kears.optimizers.Adadelta(lr=学习率)
  • ‘adam’ or tf.keras.optimizers.Adam(lr=学习率,beta_1=0.9,beta_2=0.999)

loss可选

  • ‘mse’ or tf.keras.losses.MeanSquaredError()
  • ‘sparse_categorical_crossentropy’ or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),需要注意的是from_logits这个参数代表输出时是否符合概率分布,如果不符合概率分布,那么为true,否则为false

Metrics可选

  • ‘accuracy’ : y_和y都是数值,如y__=[1] y=[1]
  • ‘categorical_accuracy’ : y_和y都是独热码形式
  • ‘sparse_categorical_accuracy’ : y_是数值,y是独热码

fit()用法

fit()函数描述训练的的过程,语法如下:

model.fit(训练集的输入特征,训练集的标签,batch_size= , epochs= , validation_data=(测试集的输入特征,测试集的标签),validation_split=从训练集划分多少比例给测试集,validation_freq=多少次epoch测试一次)

validation_data和`validation_split二者选择其一即可

keras搭建神经网络实现鸢尾花数据集例子

# import相关模块
import tensorflow as tf
from sklearn import datasets
import numpy as np# 加载训练数据
x_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
np.random.seed(116)# 搭载网络结构
model = tf.keras.models.Sequential([tf.keras.layers.Dense(3, activation="softmax", kernel_regularizer=tf.keras.regularizers.l2())
])# 配置训练方法
model.compile(optimizer='sgd',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])# 使用fit进行训练
model.fit(x_train, y_train, batch_size=32,epochs=500, validation_split=0.2,validation_freq=20)# 使用Summary函数打印网络的结构和参数统计
model.summary()

Tensorflow搭建神经网络八股及实现鸢尾花数据集分类相关推荐

  1. 一层神经网络实现鸢尾花数据集分类

    一层神经网络实现鸢尾花数据集分类 1.数据集介绍 2.程序实现 2.1 数据集导入 2.2 数据集乱序 2.3 数据集划分成永不相见的训练集和测试集 3.4 配成[输入特征,标签]对,每次喂入一小撮( ...

  2. (决策树,朴素贝叶斯,人工神经网络)实现鸢尾花数据集分类

    from sklearn.datasets import load_iris # 导入方法类iris = load_iris() #导入数据集iris iris_feature = iris.data ...

  3. 利用神经网络对鸢尾花数据集分类

    利用神经网络对鸢尾花数据集分类 详细实现代码请见:https://download.csdn.net/download/weixin_43521269/12578696 一.简介 一个人工神经元网络是 ...

  4. Python实现鸢尾花数据集分类问题——基于skearn的SVM(有详细注释的)

    Python实现鸢尾花数据集分类问题--基于skearn的SVM 代码如下: 1 # !/usr/bin/env python2 # encoding: utf-83 __author__ = 'Xi ...

  5. Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression

    Python实现鸢尾花数据集分类问题--基于skearn的LogisticRegression 一. 逻辑回归 逻辑回归(Logistic Regression)是用于处理因变量为分类变量的回归问题, ...

  6. 用逻辑回归实现鸢尾花数据集分类(1)

    鸢尾花数据集的分类问题指导 -- 对数几率回归(逻辑回归)问题研究 (1) 这一篇Notebook是应用对数几率回归(Logit Regression)对鸢尾花数据集进行品种分类的.首先会带大家探索一 ...

  7. 基于Adaboost实现鸢尾花数据集分类

    写在之前 提交内容分为两大部分: 一为Adaboost算法实现,代码在文件夹<算法实现>中,<提升方法笔记>为个人学习笔记. 二为基于Adaboost模型实现鸢尾花数据集分类, ...

  8. 实验一:鸢尾花数据集分类

    实验一:鸢尾花数据集分类 一.问题描述 利用机器学习算法构建模型,根据鸢尾花的花萼和花瓣大小,区分鸢尾花的品种.实现一个基础的三分类问题. 二.数据集分析 Iris 鸢尾花数据集内包含 3 种类别,分 ...

  9. orange实现逻辑回归_分别用逻辑回归和决策树实现鸢尾花数据集分类

    学习了决策树和逻辑回归的理论知识,决定亲自上手尝试一下.最终导出决策树的决策过程的图片和pdf.逻辑回归部分参考的是用逻辑回归实现鸢尾花数据集分类,感谢原作者xiaoyangerr 注意:要导出为pd ...

最新文章

  1. antv g2字体阴影_antv g2的理解总结
  2. mysql中Bname表示什么_《MY SQL实用教程》期末考试题
  3. Qt 获取文件夹下所有文件
  4. 顺序查找的基本原理及实现
  5. 共享SQL语句减少硬解析
  6. Proof-of-Stake (POS) outperforms Bitcoin’s Proof-of-Work (POW)
  7. JavaScript中为何要使用prototype
  8. php zhegnze_php 正则表达式
  9. java 边界_Java泛型中的上下边界的理解
  10. MATLAB基本介绍(1)
  11. 《动手学深度学习》(PyTorch版)
  12. Sqlite可视化工具sqliteman安装
  13. c语言输入10个数从小,C语言中,从键盘输入10个数,从小到大排列输出,怎
  14. SuperMap GIS 9D(2019)产品白皮书_V2018Q4R1
  15. 2020赚钱机会总结,拾元富另附10个副业赚钱必备的工具与平台,看看你到底错过了多少钱!
  16. ipad分屏功能怎么开启_英雄联盟手游设置怎么调最合适?英雄联盟手游设置方法与新手开启功能解析...
  17. 计算机网络必须包括,计算机网络硬件包括( )等几个方面。
  18. 置换群的基本概念与题目
  19. PPT怎么添加到公众号文章
  20. js显示格式化代码并高亮(vue中实现代码高亮)

热门文章

  1. 【云原生】初识云原生
  2. matlab 工频干扰去除,单片机应用系统中去除工频干扰的快速实现
  3. RISCV-ISA软件开发记录
  4. 大悲咒的发音[java]
  5. 开源ERP系统Odoo中国发展史
  6. 爱托才会赢,车托常用的技术揭秘
  7. Linux OpenGL 实践篇-11-shadow
  8. Java游戏心得 - 一口咬不到馅儿
  9. 21天混入数据科学家队伍(下)
  10. 在Word中如何直接计算加减乘除?