模型部分

模型部分都一样,比如我这里使用AlexNet网络来做。我做的是一个二分类任务,所以结尾部分网络有改动。输入图片尺寸是256*256的,所以输出图片尺寸有一点改动。

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D, BatchNormalization
import os# AlexNet
model = Sequential()
#第一段
model.add(Conv2D(filters=96, kernel_size=(11,11),strides=(4,4), padding='valid',input_shape=(256, 256, 3),activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(3,3), strides=(2,2), padding='valid'))
#第二段
model.add(Conv2D(filters=256, kernel_size=(5,5), strides=(1,1), padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(3,3), strides=(2,2), padding='valid'))
#第三段
model.add(Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
model.add(Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
model.add(Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(3,3), strides=(2,2), padding='valid'))
#第四段
model.add(Flatten())
model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))model.add(Dense(1000, activation='relu'))
model.add(Dropout(0.5))# Output Layer
model.add(Dense(1))
model.add(Activation('sigmoid'))

使用model.fit_generator

使用这个方法,因为是生成器,所以会比较节省内存。但有一个问题是没有办法定义callback,有的地方不太方便。

train_dir = os.path.abspath(r"../train/")train_dir_good = os.path.abspath(r"../train/good/")
train_dir_wrong = os.path.abspath(r"../train/wrong/")validation_dir = os.path.abspath(r"../validation/")
validation_dir_good = os.path.abspath(r"../validation/good/")
validation_dir_wrong = os.path.abspath(r"../validation/wrong/")
from keras.preprocessing.image import ImageDataGeneratortrain_datagen = ImageDataGenerator(rescale = 1./255)
test_datagen = ImageDataGenerator(rescale = 1./255)train_generator = train_datagen.flow_from_directory(train_dir, batch_size = 20, class_mode= 'binary')validation_generator = test_datagen.flow_from_directory(validation_dir, batch_size = 20, class_mode = 'binary')
from keras import optimizersmodel.compile(loss= 'binary_crossentropy', optimizer= optimizers.RMSprop(lr= 1e-4, decay= 0.01/20), metrics= ['acc'])history = model.fit_generator(train_generator, steps_per_epoch = 100, epochs= 20,validation_data = validation_generator, validation_steps= 50)
model.save('AlexNet.h5')

使用model.fit

下面使用model.fit来做:
https://keras.io/zh/models/model/ 这个链接说明了model.fit怎么用,需要输入的numpy数组

from keras.callbacks import ModelCheckpointcheckpoint = ModelCheckpoint(save_best_only= True)
callbacks = [checkpoint]
model.fit(trainX, trainY, validation_data= (testX, testY), batch_size = 20, epoch = 15, callbacks = callbacks)

Keras实现mode.fit和model.fit_generator比较相关推荐

  1. model.fit以及model.fit_generator区别及用法

    model.fit以及model.fit_generator区别及用法_猫爱吃鱼the的博客-CSDN博客

  2. 2020-12-11 keras通过model.fit_generator训练模型(节省内存)

    keras通过model.fit_generator训练模型(节省内存) 前言 前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有2 ...

  3. [深度学习] Keras 如何使用fit和fit_generator

    介绍 在本教程中,您将了解Keras .fit和.fit_generator函数的工作原理,包括它们之间的差异.为了帮助您获得实践经验,我已经提供了一个完整的示例,向您展示如何从头开始实现Keras数 ...

  4. keras中的fit函数参数_keras的fit_generator与callback函数

    fit_generator函数 fit_generator函数 callback类 每一个epoch结束(on_epoch_end)时,都要调用callback函数,callback函数(类)都要集成 ...

  5. Keras之model.fit_generator()的使用

    Keras之model.fit_generator()的使用 model.fit_generator()是利用生成器,分批次向模型送入数据的方式,可以有效节省单次内存的消耗 一.使用方式 1.引入库 ...

  6. 使用生成器和多线程为Keras训练模型的fit函数提供数据

    #导入线程包 from multiprocessing.dummy import Pool as ThreadPool ############################## #定义需要放到线程 ...

  7. 【解决两个警告】Model.fit_generator` is deprecated and will be removed in a future version. Please use `Mode

    在训练 经典卷积神经网络VGG时,因为版本问题,报了警告,下面来解决警告. 其实警告,大多来自前后版本的问题,可能你使用的这个版本里面对于一个方法是这个要求,下一个版本或者更新的版本,对于这个方法就是 ...

  8. 2020-12-08 tensorflow model.fit_generator()函数参数

    model.fit_generator()函数参数 fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, callb ...

  9. 你在数据预处理上花费的时间,是否比机器学习还要多?

    你在数据预处理上花费的时间,是否比机器学习还要多? 本文作者:三川 2017-05-31 19:05 导语:IBM 模型架构专家,向大家介绍一个新出世的 Python 数据预处理神器--nuts-ml ...

最新文章

  1. 在博客以及jupyter notebook 中编写数学公式
  2. 如何用OpenCV制作一个低成本的立体相机
  3. easyUI的combobox是否可用
  4. 图像处理中的跨度(stride)
  5. Java的Arrays.sort()良心总结
  6. 深度学习在图像超分辨率重建中的应用
  7. 第二冲刺阶段个人博客7
  8. 免费报名丨网易、腾讯、唯品会等100位名企超资深营销增长官,约你闭门“搞事情”...
  9. 卡尔曼滤波器学习笔记(一)
  10. PHP脚本占用内存太多,解决方案
  11. phoenix hbase java_java jdbc访问hbase phoenix
  12. 聚焦LS-MIMO的四大层面,浅谈5G关键技术
  13. 安装ipython(一分钟读懂)
  14. 管理造成的问题:京东商城后台语言改用java
  15. redis集群搭建管理入门
  16. python二手房使用教程_python爬取安居客二手房网站数据方法分享
  17. axure文本框添加水印_Axure如何给元件添加注释?
  18. Apache网页与安全优化
  19. JUL、JCL、Log4j、Slf4j各种日志框架的使用
  20. R语言用标准最小二乘OLS,广义相加模型GAM ,样条函数进行逻辑回归LOGISTIC分类...

热门文章

  1. Pat乙级 1038 统计同成绩学生
  2. 苹果xsmax怎么开机_苹果XSMAX进水不开机维修
  3. python模拟地面网管接收数据
  4. crontab添加定时任务
  5. JAVA架构师面试题and如何成为架构师
  6. SEO配置信息操作文档
  7. Eclipse安装 Activiti Designer插件
  8. Linux jdk配置
  9. 让MySQL支持Emoji表情 mysql 5.6
  10. 新增记录行(ecshop)