之前一段时间里,学习过tensorflow和Pytorch也写了点心得,目前是因为项目原因用了一段时间Keras,觉得很不错啊,至少从入门来说对新手极度友好,由于keras是基于tensoflow的基础,相当于tensorflow的高级API吧!(如果理解有错可以在下方评论纠正博主哈!)

一、安装

安装很简单,也是基于tensorflow的前提下(如果有疑问,请参考博主之前的博客),

pip install keras

二、应用示例

# -*- coding: utf-8 -*-
# =============================================================================
# #在线下载
# from keras.datasets import mnist
# (X_train,Y_train),(X_test,Y_test) = mnist.load_data()
#
# print("train:%d imgs"%len(X_train))
# print("test:%d imgs"%len(X_test))
# =============================================================================
online_or_not =False
#本地读取
from tensorflow.examples.tutorials.mnist import input_data
import numpy as npimport matplotlib.pyplot as plt
from PIL import Imagedef show_mnist(train_image,train_labels):n = 3m = 3for i in range(n):for j in range(m):plt.subplot(n,m,i*n+j+1)#plt.subplots_adjust(wspace=0.2, hspace=0.8)index = i * n + j #当前图片的标号img_array = train_image[index]*255img = Image.fromarray(img_array)plt.title(train_labels[index])plt.imshow(img,cmap='Greys')plt.show()#show_mnist(x_train, y_train)# coding: utf-8
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense, Activation, Convolution2D, MaxPooling2D, Flatten
from keras.optimizers import Adam
np.random.seed(1337)"""mnist数据集的label本身进行了one-hot标签化处理"""
if online_or_not:# download the mnist(X_train, Y_train), (X_test, Y_test) = mnist.load_data()# data pre-processingX_train = X_train.reshape(-1, 1, 28, 28)/255X_test = X_test.reshape(-1, 1, 28, 28)/255Y_train = np_utils.to_categorical(Y_train, num_classes=10)Y_test = np_utils.to_categorical(Y_test, num_classes=10)
else:X_train = []X_test = []print("train:%d imgs"%len(X_train))
print("test:%d imgs"%len(X_test))# build CNN
model = Sequential()# conv layer 1 output shape(32, 28, 28)
model.add(Convolution2D(filters=32,kernel_size=5,strides=1,padding='same',batch_input_shape=(None, 1, 28, 28),data_format='channels_first'))
model.add(Activation('relu'))# pooling layer1 (max pooling) output shape(32, 14, 14)
model.add(MaxPooling2D(pool_size=2, strides=2, padding='same', data_format='channels_first'))# conv layer 2 output shape (64, 14, 14)
model.add(Convolution2D(64, 5, strides=1, padding='same', data_format='channels_first'))
model.add(Activation('relu'))# pooling layer 2 (max pooling) output shape (64, 7, 7)
model.add(MaxPooling2D(2, 2, 'same', data_format='channels_first'))# full connected layer 1 input shape (64*7*7=3136), output shape (1024)
model.add(Flatten())
model.add(Dense(1024))
model.add(Activation('relu'))# full connected layer 2 to shape (10) for 10 classes
model.add(Dense(10))
model.add(Activation('softmax'))model.summary()
# define optimizer
adam = Adam(lr=1e-4)
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])# training
print ('Training')
model.fit(X_train, Y_train, epochs=1, batch_size=128)# testing
print ('Testing')
loss, accuracy = model.evaluate(X_test, Y_test)
print ('loss, accuracy: ', (loss, accuracy))

三、模型可视化

from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense, Activation, Convolution2D, MaxPooling2D, Flatten
from keras.optimizers import Adam
#模型可视化引入
import keras.callbacks
import tensorflow as tf
np.random.seed(1337)(X_train, Y_train), (X_test, Y_test) = mnist.load_data()# data pre-processing
X_train = X_train.reshape(-1, 1, 28, 28)/255
X_test = X_test.reshape(-1, 1, 28, 28)/255
Y_train = np_utils.to_categorical(Y_train, num_classes=10)
Y_test = np_utils.to_categorical(Y_test, num_classes=10)# build CNN
def build_model():model = Sequential()# conv layer 1 output shape(32, 28, 28)model.add(Convolution2D(filters=32,kernel_size=5,strides=1,padding='same',batch_input_shape=(None, 1, 28, 28),data_format='channels_first'))model.add(Activation('relu'))# pooling layer1 (max pooling) output shape(32, 14, 14)model.add(MaxPooling2D(pool_size=2, strides=2, padding='same', data_format='channels_first'))# conv layer 2 output shape (64, 14, 14)model.add(Convolution2D(64, 5, strides=1, padding='same', data_format='channels_first'))model.add(Activation('relu'))# pooling layer 2 (max pooling) output shape (64, 7, 7)model.add(MaxPooling2D(2, 2, 'same', data_format='channels_first'))# full connected layer 1 input shape (64*7*7=3136), output shape (1024)model.add(Flatten())model.add(Dense(1024))model.add(Activation('relu'))# full connected layer 2 to shape (10) for 10 classesmodel.add(Dense(10))model.add(Activation('softmax'))model.summary()return model
model = build_model()
# define optimizer
adam = Adam(lr=1e-4)
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])#模型可视化
tb_config = keras.callbacks.TensorBoard(log_dir= 'E:/python/kerascode/mnist_cnn/logs', write_images= 1, histogram_freq= 0)
cbks = [tb_config]# training
print ('Training')
#模型可视化需要加入“callbacks= cbks”
model.fit(X_train, Y_train, epochs=1,  callbacks= cbks, batch_size=512)# testing
print ('Testing')
loss, accuracy = model.evaluate(X_test, Y_test)
print ('loss, accuracy: ', (loss, accuracy))

会在 log_dir= 'E:/python/kerascode/mnist_cnn/logs'该文件夹处生成events.out.tfevents.1566199816.ZD 训练日志,打开Anaconda Prompt,激活对应的环境,定位到logs文件的上一级目录,如下图所示,输入

tensorboard --logdir=AB

(此处AB即上文中的logs文件夹名称)。

将网址复制到谷歌浏览器中即可,结果如下图所示:

四、模型可视化踩坑

from keras.utils.vis_utils import plot_model
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTMmodel = Sequential()
model.add(Embedding(input_dim=1024, output_dim=256, input_length=50))
model.add(LSTM(128))  # try using a GRU instead, for fun
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))plot_model(model, to_file='model1.png', show_shapes=True)

这个过程会报错,

raise ImportError('Failed to import `pydot`. ''Please install `pydot`. ''For example with `pip install pydot`.')

作为一个菜鸡,只能查到是导入pydot发生错误,很尴尬!还好有个博客大佬,教会如何改错。

打开vis_utils.py文件,将相应地方注释并加入下列代码

# `pydot` is an optional dependency,
# see `extras_require` in `setup.py`.
# =============================================================================
# try:
#     import pydot
# except ImportError:
#     pydot = None
# =============================================================================
try:import pydot_ng as pydot
except ImportError:try:import pydotplus as pydotexcept ImportError:try:import pydotexcept ImportError:pydot=None

参考文献:

https://blog.csdn.net/xu_haim/article/details/84981284

深度学习之keras (一) 初探相关推荐

  1. 【深度学习】Keras加载权重更新模型训练的教程(MobileNet)

    [深度学习]Keras加载权重更新模型训练的教程(MobileNet) 文章目录 1 重新训练 2 keras常用模块的简单介绍 3 使用预训练模型提取特征(口罩检测) 4 总结 1 重新训练 重新建 ...

  2. 【深度学习】Keras和Tensorflow框架使用区别辨析

    [深度学习]Keras和Tensorflow框架使用区别辨析 文章目录 1 概述 2 Keras简介 3 Tensorflow简介 4 使用tensorflow的几个小例子 5 Keras搭建CNN ...

  3. 【深度学习】Keras实现回归和二分类问题讲解

    [深度学习]Keras实现回归和二分类问题讲解 文章目录 [深度学习]Keras实现回归和二分类问题讲解 1 回归问题1.1 波士顿房价预测数据集1.2 构建基准模型1.3 数据预处理1.4 超参数 ...

  4. DL框架之Keras:深度学习框架Keras框架的简介、安装(Python库)、相关概念、Keras模型使用、使用方法之详细攻略

    DL框架之Keras:深度学习框架Keras框架的简介.安装(Python库).相关概念.Keras模型使用.使用方法之详细攻略 目录 Keras的简介 1.Keras的特点 2.Keras四大特性 ...

  5. 【深度学习】Keras vs PyTorch vs Caffe:CNN实现对比

    作者 | PRUDHVI VARMA 编译 | VK 来源 | Analytics Indiamag 在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署.这些深度 ...

  6. 简易的深度学习框架Keras代码解析与应用

    北京 | 深度学习与人工智能研修12月23-24日 再设经典课程 重温深度学习阅读全文> 正文约12690个字,22张图,预计阅读时间:32分钟. 总体来讲keras这个深度学习框架真的很&qu ...

  7. Machine Learning Mastery 博客文章翻译:深度学习与 Keras

    目录 Keras 中神经网络模型的 5 步生命周期 在 Python 迷你课程中应用深度学习 Keras 深度学习库的二元分类教程 如何用 Keras 构建多层感知器神经网络模型 如何在 Keras ...

  8. 深度学习框架Keras的安装

    原文链接:https://blog.csdn.net/qingzhuochenfu/article/details/51187603 本人已经将最新博客更新转移至个人网站了,欢迎来访~~ SCP-17 ...

  9. 深度学习:Keras基础--序贯模型(sequential)

    深度学习:Keras入门(一)之基础篇 1.Keras搭建神经网络: Keras有两种类型的模型,序贯模型(Sequential)和函数式模型(Model),函数式模型应用更为广泛,序贯模型是函数式模 ...

最新文章

  1. HBASE_API的应用
  2. GIS叠加分析功能学习
  3. android 之Fragment(轻量级的Activity)详解
  4. java泛型(二)、泛型的内部原理:类型擦除以及类型擦除带来的问题
  5. 如何使用ABAP异步RFC调用提升应用性能
  6. 图像特征提取与描述_角点特征02:SIFT算法+SURF算法
  7. python画正方形的代码_Python编程练习:使用 turtle 库完成正方形的绘制
  8. wifi怎么设置找不到服务器,无线网 登入ip找不到服务器
  9. 程序员不是神,心态决定一切
  10. bootstrap tab 组合表头
  11. JAVA RMI远程方法调用简单实例
  12. oracle根据中文获取拼音全拼函数
  13. java pdf 转tif_使用java对pdf转成tiff文件
  14. openwrt安装aliddns使用阿里云ddns
  15. 【GPU精粹与Shader编程】(八) 《GPU Pro 1》全书核心内容提炼总结
  16. 技术管理岗岗位职责总结
  17. 如何找到最新的RFC文档
  18. 携程2016校园招聘笔试题分析
  19. 2021-08-15 minikube在阿里云centos系统上的安装实践
  20. UVA 10242 Fourth Point

热门文章

  1. php四种基础算法:冒泡,选择,插入和快速排序法
  2. 考研数学:【以错补错】 降低做题出错率
  3. 【转载】在C#中运用SQLDMO备份和恢复Microsoft SQL Server数据库
  4. 从零开始实现ASP.NET Core MVC的插件式开发(五) - 插件的删除和升级
  5. [CF1082E] Increasing Frequency
  6. 水木告白工作室:Java从零入门之模仿头条资讯(一)
  7. Windows vs Linux:\r\n 与 \r
  8. 云计算(Cloud Computing) 培训总结
  9. 为人示弱,做事留余 | 摸鱼系列
  10. SpringBoot的配置项