欢迎关注”生信修炼手册”!

本文是对tensforflow官方入门教程的学习和翻译,展示了创建一个基础的卷积神经网络模型来解决图像分类问题的过程。具体步骤如下

1. 加载数据集

tensorflow集成了keras这个框架,提供了CIFAR10数据集,该数据集包含了10个类别共6万张彩色图片,加载方式如下

>>> import tensorflow as tf
>>> from tensorflow.keras import datasets,layers, models
>>> import matplotlib.pyplot as plt
>>> (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 33s 0us/step
>>> train_images, test_images = train_images / 255.0, test_images / 255.0

可以通过如下代码来查看部分图片

>>> for i in range(25):
...     plt.subplot(5, 5, i + 1)
...     plt.xticks([])
...     plt.yticks([])
...     plt.grid(False)
...     plt.imshow(train_images[i], cmap = plt.cm.binary)
...     plt.xlabel(class_names[train_labels[i][0]])
...
>>> plt.show()

可视化效果如下

2. 构建卷积神经网络

通过keras的Sequential API来构建卷积神经网络,依次添加卷积层,池化层,全连接层,代码如下

>>> model = models.Sequential()
>>> model.add(layers.Conv2D(32, (3, 3), activation = "relu", input_shape = (32, 32, 3)))
>>> model.add(layers.MaxPooling2D((2, 2)))
>>> model.add(layers.Conv2D(64, (3,3), activation = "relu"))
>>> model.add(layers.MaxPooling2D((2, 2)))
>>> model.add(layers.Conv2D(64, (3, 3), activation = "relu"))
>>> model.add(layers.Flatten())
>>> model.add(layers.Dense(64, activation = "relu"))
>>> model.add(layers.Dense(10))
>>> model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 30, 30, 32)        896
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 15, 15, 32)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 13, 13, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 6, 6, 64)          0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 4, 4, 64)          36928
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0
_________________________________________________________________
dense_1 (Dense)              (None, 64)                65600
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________

3. 编译模型

模型在训练之前,必须对其进行编译,主要是确定损失函数,优化器以及评估分类效果好坏的指标,代码如下

>>> model.compile(optimizer = 'adam', loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics = ['accuracy'])

4. 训练模型

使用训练集训练模型,代码如下

>>> history = model.fit(train_images, train_labels, epochs = 10, validation_data = (test_images, test_labels))
2021-06-23 10:59:43.386592: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/10
1563/1563 [==============================] - 412s 203ms/step - loss: 1.5396 - accuracy: 0.4380 - val_loss: 1.2760 - val_accuracy: 0.5413
Epoch 2/10
1563/1563 [==============================] - 94s 60ms/step - loss: 1.1637 - accuracy: 0.5850 - val_loss: 1.1193 - val_accuracy: 0.6084
Epoch 3/10
1563/1563 [==============================] - 95s 61ms/step - loss: 1.0210 - accuracy: 0.6398 - val_loss: 0.9900 - val_accuracy: 0.6556
Epoch 4/10
1563/1563 [==============================] - 88s 56ms/step - loss: 0.9186 - accuracy: 0.6781 - val_loss: 0.9399 - val_accuracy: 0.6687
Epoch 5/10
1563/1563 [==============================] - 95s 61ms/step - loss: 0.8472 - accuracy: 0.7023 - val_loss: 0.8984 - val_accuracy: 0.6868
Epoch 6/10
1563/1563 [==============================] - 85s 55ms/step - loss: 0.7917 - accuracy: 0.7220 - val_loss: 0.8896 - val_accuracy: 0.6888
Epoch 7/10
1563/1563 [==============================] - 88s 56ms/step - loss: 0.7450 - accuracy: 0.7381 - val_loss: 0.8843 - val_accuracy: 0.6974
Epoch 8/10
1563/1563 [==============================] - 87s 55ms/step - loss: 0.7024 - accuracy: 0.7530 - val_loss: 0.8403 - val_accuracy: 0.7089
Epoch 9/10
1563/1563 [==============================] - 92s 59ms/step - loss: 0.6600 - accuracy: 0.7676 - val_loss: 0.8512 - val_accuracy: 0.7095
Epoch 10/10
1563/1563 [==============================] - 91s 58ms/step - loss: 0.6240 - accuracy: 0.7790 - val_loss: 0.8483 - val_accuracy: 0.7119

通过比较训练集和验证集的准确率曲线,可以判断模型训练是否有过拟合等问题,代码如下

>>> plt.plot(history.history['accuracy'], label='accuracy')
[<matplotlib.lines.Line2D object at 0x000001AAC62A7B08>]
>>> plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
[<matplotlib.lines.Line2D object at 0x000001AAC28F8988>]
>>> plt.xlabel('Epoch')
Text(0.5, 0, 'Epoch')
>>> plt.ylabel('Accuracy')
Text(0, 0.5, 'Accuracy')
>>> plt.ylim([0.5, 1])
(0.5, 1.0)
>>> plt.legend(loc='lower right')
<matplotlib.legend.Legend object at 0x000001AAC62A7688>
>>> plt.show()

结果如下

当模型过拟合时,会看到accuracy非常高,而val_accuracy较低,两条线明显偏离。从上图中看到,两个准确率比较接近,没有明显的分离现象,而且值都比较低,模型存在欠拟合的问题。

5. 评估模型

用测试集评估模型效果,结果如下

>>> test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)
313/313 - 7s - loss: 0.8483 - accuracy: 0.7119>>> print(test_acc)
0.711899995803833

准确率达到了70%,对于一个由几行代码快速构建的初步卷积神经网络模型而言,这个效果还可以接受。后续可以考虑数据增强,模型改进,调整学习率等方式,来提高模型的准确率。

·end·

—如果喜欢,快分享给你的朋友们吧—

原创不易,欢迎收藏,点赞,转发!生信知识浩瀚如海,在生信学习的道路上,让我们一起并肩作战!

本公众号深耕耘生信领域多年,具有丰富的数据分析经验,致力于提供真正有价值的数据分析服务,擅长个性化分析,欢迎有需要的老师和同学前来咨询。

更多精彩

  • KEGG数据库,除了pathway你还知道哪些

  • 全网最完整的circos中文教程

  • DNA甲基化数据分析专题

  • 突变检测数据分析专题

  • mRNA数据分析专题

  • lncRNA数据分析专题

  • circRNA数据分析专题

  • miRNA数据分析专题

  • 单细胞转录组数据分析专题

  • chip_seq数据分析专题

  • Hi-C数据分析专题

  • HLA数据分析专题

  • TCGA肿瘤数据分析专题

  • 基因组组装数据分析专题

  • CNV数据分析专题

  • GWAS数据分析专题

  • 机器学习专题

  • 2018年推文合集

  • 2019年推文合集

  • 2020推文合集

写在最后

转发本文至朋友圈,后台私信截图即可加入生信交流群,和小伙伴一起学习交流。

扫描下方二维码,关注我们,解锁更多精彩内容!

一个只分享干货的

生信公众号

使用tensorflow构建一个卷积神经网络相关推荐

  1. 使用tensorflow构建简单卷积神经网络

    一 概要 CIFAR-10分类问题是机器学习领域的一个通用基准,其问题是将32X32像素的RGB图像分类成10种类别:飞机,手机,鸟,猫,鹿,狗,青蛙,马,船和卡车.  更多信息请移步CIFAR-10 ...

  2. DeepDream、反向运行一个卷积神经网络在 DeepDream和卷积神经网络的可视化 中的应用

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 反向运行一个卷积神经网络在 卷积神经网络的可视化 中的应用 D ...

  3. Tensorflow使用CNN卷积神经网络以及RNN(Lstm、Gru)循环神经网络进行中文文本分类

    Tensorflow使用CNN卷积神经网络以及RNN(Lstm.Gru)循环神经网络进行中文文本分类 本案例采用清华大学NLP组提供的THUCNews新闻文本分类数据集的一个子集进行训练和测试http ...

  4. 基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络

    基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络 所用工具 文件结构: 数据: 代码: 结果: 改进思路 拓展 本文是一个基于pytorch使用CNN在生物信息学上进行位 ...

  5. 基于tensorflow、keras利用emnist数据集构建CNN卷积神经网络进行手写字母识别

    EMNIST 数据集是一个包含手写字母,数字的数据集,它具有和MNIST相同的数据格式.The EMNIST Dataset | NIST 引用模块介绍: import tensorflow as t ...

  6. Tensorflow之 CNN卷积神经网络的MNIST手写数字识别

    点击"阅读原文"直接打开[北京站 | GPU CUDA 进阶课程]报名链接 作者,周乘,华中科技大学电子与信息工程系在读. 前言 tensorflow中文社区对官方文档进行了完整翻 ...

  7. Python基于TensorFlow深度学习卷积神经网络自动识别网站验证码设计

    开发环境: Pycharm + Python3.7 + Django2.2 + sqlite数据库 + TensorFlow深度学习框架 + selenium自动化测试 "基于深度网络的网站 ...

  8. 【TensorFlow实战】TensorFlow实现经典卷积神经网络之ResNet

    ResNet ResNet(Residual Neural Network)通过使用Residual Unit成功训练152层深的神经网络,在ILSVRC 2015比赛中获得冠军,取得3.57%的to ...

  9. Tensorflow学习之 卷积神经网络 (一)什么是卷积?

    这一节回顾一下卷积神经网络 第一张图是单通道的一张照片,在RGB中只有一个通道,即一个0-255的值来表示其灰度: 第二张图就是一张彩色的图片了,这里具有三个通道,也就是我们常见的RGB三个0-255 ...

最新文章

  1. 2021年Facebook博士生奖研金名单公布!一半获奖者是华人博士生
  2. new是不是c语言运算符优先级表,C语言运算符优先级列表(超详细)
  3. ElasticSearch服务器操作命令
  4. 科技巨头开发人工智能 稀缺品种或成A股香饽饽
  5. rtsp协议_基于libVLC的视频播放器(支持RTSP协议)
  6. php根本自学不了,PHP开发自学还是培训?
  7. 使用C++和LIBSVM实现机器学习+样本分类
  8. Android ListView 代码1
  9. wap(dopra linux )命令,自行更换HG8321R千兆光猫记录
  10. 分享这位的WPF界面设计系列视频
  11. 使用CSS控制表格设计出课程表实验内容:编写一个网页,内容为本学期本班的课程表,并使用CSS设计课程表的显示样式
  12. linux编译lame,lame mp3 源码 分析
  13. 安吉县人力资源和社会保障局数据中心容灾备份项目
  14. [llvm cookbook] 1、LLVM设计与使用
  15. 安装“tampermonkey”脚本下载知网文献PDF格式
  16. '.'和'..'还有'./'和'../'
  17. python如何生成等差数列_python numpy函数中的linspace创建等差数列详解
  18. 中级运维这么学才有意思
  19. 企业级大数据平台应用场景介绍
  20. 基于jQuery的jsp表格动态合并

热门文章

  1. 运用特征脸方法的基于Opencv的猫脸检测实现
  2. AD与AAD区别和联系
  3. Poseidon(海神号)
  4. [jQuery学习系列四 ]4-Jquery学习四-事件操作
  5. 中国电子学会2022年12月份青少年软件编程Python等级考试试卷四级真题(含答案)
  6. 阿里云 ECS 构建集群
  7. 在1705年第一个电灯泡是如何被发明的?
  8. 【Markdown使用技巧总结】-如何在Markdown文档中插入空格?
  9. 软件测试(六)——缺陷以及总结
  10. 用于实时大数据处理的Lambda架构