Keras实现mode.fit和model.fit_generator比较
模型部分
模型部分都一样,比如我这里使用AlexNet网络来做。我做的是一个二分类任务,所以结尾部分网络有改动。输入图片尺寸是256*256的,所以输出图片尺寸有一点改动。
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D, BatchNormalization
import os# AlexNet
model = Sequential()
#第一段
model.add(Conv2D(filters=96, kernel_size=(11,11),strides=(4,4), padding='valid',input_shape=(256, 256, 3),activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(3,3), strides=(2,2), padding='valid'))
#第二段
model.add(Conv2D(filters=256, kernel_size=(5,5), strides=(1,1), padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(3,3), strides=(2,2), padding='valid'))
#第三段
model.add(Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
model.add(Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
model.add(Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(3,3), strides=(2,2), padding='valid'))
#第四段
model.add(Flatten())
model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))model.add(Dense(1000, activation='relu'))
model.add(Dropout(0.5))# Output Layer
model.add(Dense(1))
model.add(Activation('sigmoid'))
使用model.fit_generator
使用这个方法,因为是生成器,所以会比较节省内存。但有一个问题是没有办法定义callback,有的地方不太方便。
train_dir = os.path.abspath(r"../train/")train_dir_good = os.path.abspath(r"../train/good/")
train_dir_wrong = os.path.abspath(r"../train/wrong/")validation_dir = os.path.abspath(r"../validation/")
validation_dir_good = os.path.abspath(r"../validation/good/")
validation_dir_wrong = os.path.abspath(r"../validation/wrong/")
from keras.preprocessing.image import ImageDataGeneratortrain_datagen = ImageDataGenerator(rescale = 1./255)
test_datagen = ImageDataGenerator(rescale = 1./255)train_generator = train_datagen.flow_from_directory(train_dir, batch_size = 20, class_mode= 'binary')validation_generator = test_datagen.flow_from_directory(validation_dir, batch_size = 20, class_mode = 'binary')
from keras import optimizersmodel.compile(loss= 'binary_crossentropy', optimizer= optimizers.RMSprop(lr= 1e-4, decay= 0.01/20), metrics= ['acc'])history = model.fit_generator(train_generator, steps_per_epoch = 100, epochs= 20,validation_data = validation_generator, validation_steps= 50)
model.save('AlexNet.h5')
使用model.fit
下面使用model.fit来做:
https://keras.io/zh/models/model/ 这个链接说明了model.fit怎么用,需要输入的numpy数组
from keras.callbacks import ModelCheckpointcheckpoint = ModelCheckpoint(save_best_only= True)
callbacks = [checkpoint]
model.fit(trainX, trainY, validation_data= (testX, testY), batch_size = 20, epoch = 15, callbacks = callbacks)
Keras实现mode.fit和model.fit_generator比较相关推荐
- model.fit以及model.fit_generator区别及用法
model.fit以及model.fit_generator区别及用法_猫爱吃鱼the的博客-CSDN博客
- 2020-12-11 keras通过model.fit_generator训练模型(节省内存)
keras通过model.fit_generator训练模型(节省内存) 前言 前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有2 ...
- [深度学习] Keras 如何使用fit和fit_generator
介绍 在本教程中,您将了解Keras .fit和.fit_generator函数的工作原理,包括它们之间的差异.为了帮助您获得实践经验,我已经提供了一个完整的示例,向您展示如何从头开始实现Keras数 ...
- keras中的fit函数参数_keras的fit_generator与callback函数
fit_generator函数 fit_generator函数 callback类 每一个epoch结束(on_epoch_end)时,都要调用callback函数,callback函数(类)都要集成 ...
- Keras之model.fit_generator()的使用
Keras之model.fit_generator()的使用 model.fit_generator()是利用生成器,分批次向模型送入数据的方式,可以有效节省单次内存的消耗 一.使用方式 1.引入库 ...
- 使用生成器和多线程为Keras训练模型的fit函数提供数据
#导入线程包 from multiprocessing.dummy import Pool as ThreadPool ############################## #定义需要放到线程 ...
- 【解决两个警告】Model.fit_generator` is deprecated and will be removed in a future version. Please use `Mode
在训练 经典卷积神经网络VGG时,因为版本问题,报了警告,下面来解决警告. 其实警告,大多来自前后版本的问题,可能你使用的这个版本里面对于一个方法是这个要求,下一个版本或者更新的版本,对于这个方法就是 ...
- 2020-12-08 tensorflow model.fit_generator()函数参数
model.fit_generator()函数参数 fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, callb ...
- 你在数据预处理上花费的时间,是否比机器学习还要多?
你在数据预处理上花费的时间,是否比机器学习还要多? 本文作者:三川 2017-05-31 19:05 导语:IBM 模型架构专家,向大家介绍一个新出世的 Python 数据预处理神器--nuts-ml ...
最新文章
- 在博客以及jupyter notebook 中编写数学公式
- 如何用OpenCV制作一个低成本的立体相机
- easyUI的combobox是否可用
- 图像处理中的跨度(stride)
- Java的Arrays.sort()良心总结
- 深度学习在图像超分辨率重建中的应用
- 第二冲刺阶段个人博客7
- 免费报名丨网易、腾讯、唯品会等100位名企超资深营销增长官,约你闭门“搞事情”...
- 卡尔曼滤波器学习笔记(一)
- PHP脚本占用内存太多,解决方案
- phoenix hbase java_java jdbc访问hbase phoenix
- 聚焦LS-MIMO的四大层面,浅谈5G关键技术
- 安装ipython(一分钟读懂)
- 管理造成的问题:京东商城后台语言改用java
- redis集群搭建管理入门
- python二手房使用教程_python爬取安居客二手房网站数据方法分享
- axure文本框添加水印_Axure如何给元件添加注释?
- Apache网页与安全优化
- JUL、JCL、Log4j、Slf4j各种日志框架的使用
- R语言用标准最小二乘OLS,广义相加模型GAM ,样条函数进行逻辑回归LOGISTIC分类...