欢迎关注WX公众号:【程序员管小亮】

专栏——TensorFlow学习笔记

文章目录

一、神经网络的基本单位:神经元

如果把神经网络的基本单位:神经元和真实的神经细胞(神经元)进行比较的话,会发现在结构上是有一些类似的。

神经网络的神经元示意图如下:

神经细胞模式图如下:

二、卷积神经网络(CNN)

关于理论方面的介绍,可以看一下这个 高赞博客——大话卷积神经网络CNN(干货满满)。

CNN 主要包含:一个或多个卷积层、池化层和全连接层。大部分 CNN 主要是进行不同层的排列组合,构成一个网络结构,来解决实际问题,比如经典的 LeNet-5 就是解决手写数字识别问题的。

三、基于 tf2.0 实现 LeNet

其实 CNN 的实例实现和 TensorFlow2.0 学习笔记(二):多层感知机(MLP) 的多层感知机在代码结构上很类似,不同之处在于新加入了一些层,所以这里的 CNN 网络结构并不是唯一的,可以通过增加、删除卷积层和池化层还有全连接层,或者调整学习率、训练轮数、训练数据集大小以及其他超参数,以期达到更佳的效果和更好的性能。

如下便是刚才所说的 LeNet-5,网络结构如下:


代码如下:

class CNN(tf.keras.Model):def __init__(self):super().__init__()self.conv1 = tf.keras.layers.Conv2D(filters=6,             # 卷积层神经元(卷积核)数目kernel_size=[5, 5],     # 感受野大小padding='valid',         # padding策略(vaild 或 same)strides=(1, 1),activation=tf.nn.relu   # 激活函数)self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)self.conv2 = tf.keras.layers.Conv2D(filters=16,kernel_size=[5, 5],padding='valid',strides=(1, 1),activation=tf.nn.relu)self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)self.flatten = tf.keras.layers.Flatten()# 等价于self.flatten = tf.keras.layers.Reshape(target_shape=(4 * 4 * 16,))self.dense1 = tf.keras.layers.Dense(units=120, activation=tf.nn.relu)self.dense2 = tf.keras.layers.Dense(units=84, activation=tf.nn.relu)self.dense3 = tf.keras.layers.Dense(units=10)def call(self, inputs):x = self.conv1(inputs)                  # [batch_size, 24, 24, 6]x = self.pool1(x)                       # [batch_size, 12, 12, 6]x = self.conv2(x)                       # [batch_size, 8, 8, 16]x = self.pool2(x)                       # [batch_size, 4, 4, 16]x = self.flatten(x)                     # [batch_size, 5 * 5 * 16]x = self.dense1(x)                      # [batch_size, 120]x = self.dense2(x)                      # [batch_size, 84]x = self.dense3(x)                      # [batch_size, 10]output = tf.nn.softmax(x)return output

输出结果:

# 测试了五次
test accuracy: 0.980700
test accuracy: 0.987200
test accuracy: 0.988100
test accuracy: 0.989000
test accuracy: 0.987100


将 TensorFlow2.0 学习笔记(二):多层感知机(MLP) 的 model = MLP() 更换成 model = CNN() ,可以注意到,基于 LeNet 可以达到 98% 左右的准确率,比之前的多层感知机要高出 1%!这是一个非常显著的提高!事实上,通过改变模型的网络结构(比如加入 Dropout 层防止过拟合),准确率还有进一步提升的空间。

另外要注意,这是没有调参的输出结果,所以不是最佳性能。完整代码如下:

import tensorflow as tf
import numpy as npclass MNISTLoader():def __init__(self):mnist = tf.keras.datasets.mnist(self.train_data, self.train_label), (self.test_data,self.test_label) = mnist.load_data()# MNIST中的图像默认为uint8(0-255的数字)# 以下代码将其归一化到0-1之间的浮点数,并在最后增加一维作为颜色通道self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0,axis=-1)                                             # [60000, 28, 28, 1]self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0,axis=-1)                                            # [10000, 28, 28, 1]self.train_label = self.train_label.astype(np.int32)    # [60000]self.test_label = self.test_label.astype(np.int32)      # [10000]self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]def get_batch(self, batch_size):# 从数据集中随机取出batch_size个元素并返回index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)return self.train_data[index, :], self.train_label[index]class CNN(tf.keras.Model):def __init__(self):super().__init__()self.conv1 = tf.keras.layers.Conv2D(filters=6,             # 卷积层神经元(卷积核)数目kernel_size=[5, 5],     # 感受野大小padding='valid',         # padding策略(vaild 或 same)strides=(1, 1),activation=tf.nn.relu   # 激活函数)self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)self.conv2 = tf.keras.layers.Conv2D(filters=16,kernel_size=[5, 5],padding='valid',strides=(1, 1),activation=tf.nn.relu)self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)self.flatten = tf.keras.layers.Flatten()# 等价于self.flatten = tf.keras.layers.Reshape(target_shape=(4 * 4 * 16,))self.dense1 = tf.keras.layers.Dense(units=120, activation=tf.nn.relu)self.dense2 = tf.keras.layers.Dense(units=84, activation=tf.nn.relu)self.dense3 = tf.keras.layers.Dense(units=10)def call(self, inputs):x = self.conv1(inputs)                  # [batch_size, 24, 24, 6]x = self.pool1(x)                       # [batch_size, 12, 12, 6]x = self.conv2(x)                       # [batch_size, 8, 8, 16]x = self.pool2(x)                       # [batch_size, 4, 4, 16]x = self.flatten(x)                     # [batch_size, 5 * 5 * 16]x = self.dense1(x)                      # [batch_size, 120]x = self.dense2(x)                      # [batch_size, 84]x = self.dense3(x)                      # [batch_size, 10]output = tf.nn.softmax(x)return outputnum_epochs = 5          # 训练轮数
batch_size = 50            # 批大小
learning_rate = 0.001  # 学习率model = CNN()                                                     # 实例化模型
data_loader = MNISTLoader()                                            # 数据载入
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 实例化优化器num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
for batch_index in range(num_batches):# 随机取一批训练数据X, y = data_loader.get_batch(batch_size)with tf.GradientTape() as tape:# 计算模型预测值y_pred = model(X)# 计算损失函数loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch %d: loss %f" % (batch_index, loss.numpy()))# 计算模型变量的导数grads = tape.gradient(loss, model.variables)# 优化器的使用optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))# 评估器
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
# 迭代轮数
num_batches = int(data_loader.num_test_data // batch_size)
for batch_index in range(num_batches):start_index, end_index = batch_index * \batch_size, (batch_index + 1) * batch_size# 模型预测的结果y_pred = model.predict(data_loader.test_data[start_index: end_index])sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
print("test accuracy: %f" % sparse_categorical_accuracy.result())

推荐阅读

  • TensorFlow2.0 学习笔记(一):TensorFlow 2.0 的安装和环境配置以及上手初体验
  • TensorFlow2.0 学习笔记(二):多层感知机(MLP)
  • TensorFlow2.0 学习笔记(三):卷积神经网络(CNN)
  • TensorFlow2.0 学习笔记(四):迁移学习(MobileNetV2)
  • TensorFlow2.0 学习笔记(五):循环神经网络(RNN)

参考文章

  • TensorFlow 官方文档
  • 简单粗暴 TensorFlow 2.0

TensorFlow2.0 学习笔记(三):卷积神经网络(CNN)相关推荐

  1. 深度学习(DL)与卷积神经网络(CNN)学习笔记随笔-04-基于Python的LeNet之MLP

    原文地址可以查看更多信息 本文主要参考于:Multilayer Perceptron  python源代码(github下载 CSDN免费下载) 本文主要介绍含有单隐层的MLP的建模及实现.建议在阅读 ...

  2. 深度学习(DL)与卷积神经网络(CNN)学习笔记随笔-03-基于Python的LeNet之LR

    原地址可以查看更多信息 本文主要参考于:Classifying MNIST digits using Logistic Regression  python源代码(GitHub下载 CSDN免费下载) ...

  3. 【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理(1)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  4. 深度学习笔记:卷积神经网络的可视化--卷积核本征模式

    目录 1. 前言 2. 代码实验 2.1 加载模型 2.2 构造返回中间层激活输出的模型 2.3 目标函数 2.4 通过随机梯度上升最大化损失 2.5 生成滤波器模式可视化图像 2.6 将多维数组变换 ...

  5. Tensorflow2.0学习笔记(一)北大曹健老师教学视频1-4讲

    Tensorflow2.0学习笔记(一)北大曹健老师教学视频1-4讲 返回目录 这个笔记现在是主要根据北京大学曹健老师的视频写的,这个视频超级棒,非常推荐. 第一讲 常用函数的使用(包含了很多琐碎的函 ...

  6. 机器学习笔记三—卷积神经网络与循环神经网络

    系列文章目录 机器学习笔记一-机器学习基本知识 机器学习笔记二-梯度下降和反向传播 机器学习笔记三-卷积神经网络与循环神经网络 机器学习笔记四-机器学习可解释性 机器学习笔记五-机器学习攻击与防御 机 ...

  7. [人工智能-深度学习-24]:卷积神经网络CNN - CS231n解读 - 卷积神经网络基本层级

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:[人工智能-深度学习-23]:卷积神经网络CNN - CS231n解读 - 卷积神经网络基本层级_ ...

  8. [人工智能-深度学习-33]:卷积神经网络CNN - 常见分类网络- LeNet网络结构分析与详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  9. Tensorflow2.0学习笔记(二)

    Tensorflow2.0学习笔记(二)--Keras练习 文章目录 Tensorflow2.0学习笔记(二)--Keras练习 前言 二.使用步骤 1.实现步骤及代码 2.下载 Fashion MN ...

最新文章

  1. 鼠标点击触发事件python_如何在鼠标按下的情况下触发tkinter的“Enter”事件?...
  2. asp.net 拦截html,关于c#:如何在-ASPNET-Core-中实现全局异常拦截
  3. 使用jQuery Treeview插件实现树状结构效果
  4. 关于 someone could be eavesdropping on you right now (man-in-the-middle attack) ssh的解决办法
  5. Spring Data JPA和分页
  6. Vue3 --- 安装和使用echarts
  7. Bootstrap3 滚动监听插件的调用方式
  8. SMP、NUMA、MPP(Teradata)体系结构介绍
  9. oracle里的定时器,oracle中创建定时器
  10. Python实现快乐的数字
  11. nullnullvc中加花
  12. java垃圾回收算法
  13. Java爬虫入门教程 开篇
  14. Java + OpenCV 实现图片人脸检测
  15. 淘宝为什么放弃SpringCloud、Dubbo,选择了这个牛逼的神仙框架!贼爽
  16. 面试官问:你的缺点是什么,这么回答漂亮!(真实案例)
  17. java后端应届生面试题,附答案解析
  18. .NET 结构体 Struck、类
  19. linux 设备树 usb控制器,linux 设备树中 dwc3 节点的phys参数含义
  20. 大数据分析师的沟通技巧

热门文章

  1. 发票 ocr java_OCR识别技术—增值税发票识别
  2. Qt插件:QPluginLoader
  3. idea的java类图标C不见,取而代之是J标识,且写代码无提示。
  4. 赏红叶,是金秋心旷神怡之事
  5. 蒸发器,冷凝器面积过大
  6. 计算机毕业设计springboot小组学习系统
  7. ZigBee 3.0理论教程-通用-1-04:协议与架构-媒体访问控制层(MAC)
  8. Facebook Surround360 学习笔记--(4)色彩/视差不一致问题
  9. 病毒币骗局:号称募资生产口罩,研究疫苗
  10. 二维码这把利刃,产品应该用到极致