tensorflow实现宝可梦数据集迁移学习
目录
一、迁移学习简介
二、构建预训练模型
1、调用内置模型
2、修改模型
3、构建模型
三、导入数据和预处理
1、设置batch size
2、读取训练数据
3、读取验证数据
4、读取测试数据
5、预处理
四、模型训练
1、设置early_stopping
2、模型编译
3、模型设置
4、模型评估
5、保存训练权重
五、模型预测
1、构建预测模型
2、导入权重
3、预测
4、对比分析
一、迁移学习简介
迁移学习就是把预先定义好的模型,以及该模型在对应数据集上训练得到的参数迁移到新的模型,用来帮助新模型训练。通过迁移学习我们可以将模型已经学到的参数,分享给新模型从而加快并优化模型的学习效率,从而不用像大多数网络那样从零开始学习。对于小样本学习的也可以减少过拟合或者欠拟合问题。
迁移学习的几种实现方式:
Transfer Learning:冻结预训练模型的全部卷积层,只训练自己定制的全连接层。
Extract Feature Vector:先计算出预训练模型的卷积层对所有训练和测试数据的特征向量,然后抛开预训练模型,只训练自己定制的简配版全连接网络。
Fine-tuning:冻结预训练模型的部分卷积层(通常是靠近输入的多数卷积层,因为这些层保留了大量底层信息)甚至不冻结任何网络层,训练剩下的卷积层(通常是靠近输出的部分卷积层)和全连接层。
二、构建预训练模型
1、调用内置模型
调用tensorflow内置VGG19模型,下载该模型在"imagenet"数据集上预训练权重
net = keras.applications.VGG19(weights='imagenet', include_top=False,
pooling='max')
2、修改模型
冻结卷积层,将全连接层修改为自定义数据集对应分类数。
net.trainable = False
newnet = keras.Sequential([
net,
layers.Dense(5)
])
3、构建模型
newnet.build(input_shape=(4,224,224,3))
newnet.summary()
三、导入数据和预处理
1、设置batch size
根据模型参数量和硬件环境设定batch size大小
batchsz = 128
2、读取训练数据
images, labels, table = load_pokemon('pokemon',mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
3、读取验证数据
images2, labels2, table = load_pokemon('pokemon',mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
4、读取测试数据
images3, labels3, table = load_pokemon('pokemon',mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)
5、预处理
def preprocess(x,y):
# x: 图片的路径,y:图片的数字编码
x = tf.io.read_file(x)
x = tf.image.decode_jpeg(x, channels=3)
x = tf.image.resize(x, [244, 244])
x = tf.image.random_flip_up_down(x)
x = tf.image.random_crop(x, [224,224,3])
x = tf.cast(x, dtype=tf.float32) / 255.
x = normalize(x)
y = tf.convert_to_tensor(y)
y = tf.one_hot(y, depth=5)
return x, y
四、模型训练
1、设置early_stopping
为防止过拟合,这里使用early_stopping,当模型在验证集上精度变化在min_delta以内,并且持续次数达到patience以后,模型训练即停止。
early_stopping = EarlyStopping(
monitor='val_accuracy',
min_delta=0.001,
patience=5
)
2、模型编译
设置优化器,损失函数和精度衡量标准
newnet.compile(optimizer=optimizers.Adam(lr=1e-3),
loss=losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
3、模型设置
设置训练集,验证集,验证频率,迭代次数以及回调函数
newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=20,
callbacks=[early_stopping])
4、模型评估
训练结束后,使用evaluate函数进行模型评估,了解模型最终精度情况。
newnet.evaluate(db_test)
5、保存训练权重
newnet.save_weights('weights.ckpt')
五、模型预测
1、构建预测模型
net = keras.applications.VGG19(weights='imagenet', include_top=False,
pooling='max')
net.trainable = False
model= keras.Sequential([
net,
layers.Dense(5)
])
model.build(input_shape=(4,224,224,3))
2、导入权重
model.load_weights('weights.ckpt')
3、预测
logits = newnet.predict(x)
prob = tf.nn.softmax(logits, axis=1)
print(prob)
max_prob_index = np.argmax(prob, axis=-1)[0]
prob = prob.numpy()
max_prob = prob[0][max_prob_index]
print(max_prob)
max_index = np.argmax(logits, axis=-1)[0]
name = ['妙蛙种子', '小火龙', '超梦', '皮卡丘', '杰尼龟']
print(name[max_index])
测试图像:
预测结果:
tf.Tensor([[0.78470963 0.09179451 0.03650109 0.01834733 0.06864741]], shape=(1, 5), dtype=float32)
0.78470963
妙蛙种子
4、对比分析
使用同样测试图像在没有进行迁移学习训练的模型上进行测试,输出结果:
tf.Tensor([[0.46965462 0.0470721 0.20003504 0.11915307 0.16408516]], shape=(1, 5), dtype=float32)
0.46965462
妙蛙种子
从结果上看,两个模型都能准确预测,但输出的分类概率(迁移学习0.7847,非迁移学习0.4696),两者存在明显差别,可以看出使用迁移学习能够达到更好的拟合效果。
tensorflow实现宝可梦数据集迁移学习相关推荐
- 深度学习tensorflow实现宝可梦图像分类
目录 一.数据集简介 二.数据预处理 三.构建卷积神经网络 四.模型训练 五.预测 六.分析与优化 一.数据集简介 宝可梦数据集(共1168张图像):bulbasaur(妙蛙种子,234).charm ...
- 宝可梦数据集分析及预测
前言 以下内容为本人学习过程中记录,仅用于学习,如有错误或者纰漏,请留言指正,谢谢. 数据集和代码下载 – 百度云链接:https://pan.baidu.com/s/1RFUEVcD85J2AQ3_ ...
- 精灵宝可梦数据集与动漫头像数据集
精灵宝可梦数据集 链接:https://pan.baidu.com/s/1O-YBLBeqDpui_FhspnwY3g 提取码:r1ze 动漫头像数据集 链接:https://pan.baidu.c ...
- PyTorch实现基于ResNet18迁移学习的宝可梦数据集分类
一.实现过程 1.数据集描述 数据集分为5类,分别如下: 皮卡丘:234 超梦:239 杰尼龟:223 小火龙:238 妙蛙种子:234 自取链接:https://pan.baidu.com/s/1b ...
- 阿里云天池大数据:【入门】精灵宝可梦数据集分析
目的 学习,实践,不同机器学习算法 使用的包及安装 pip install numpy pip install Pandas 数据获取 阿里云天池大数据竞赛官网获取 莫某 引入包 import pan ...
- python3下tensorflow练习(八)之迁移学习
这周帮同学做了一个CNN的分类任务,因为赶时间所以直接用InceptionV3的参数进行迁移学习,只替换最后一层全连接层,然后对自己的数据集进行4分类的训练.在最后这一层全连接层之前的网络层称之为瓶颈 ...
- Python入门(10)——宝可梦数据集探索
数据时代的到来刷新了人们探索未知的方式,本文就通过使用数据分析的方式来帮助我更好的了解宝可梦这种神奇的生物,然后再选择最经济实惠,简单好抓的宝可梦来挑战联盟.通过使用搜索引擎,找到了一份包含着从第一代 ...
- 【入门】精灵宝可梦数据集分析
数据集下载 !wget -O pokemon_data.csv https://pai-public-data.oss-cn-beijing.aliyuncs.com/pokemon/pokemon. ...
- 【TensorFlow】官方教程—如何快速迁移学习训练自己的模型。How to Retrain an Image Classifier for New Categories
如何训练图像的新的类别分类 How to Retrain an Image Classifier for New Categories [https://www.tensorflow.org/hub/ ...
最新文章
- getServletPath与getRequestURI
- [转]NLog学习笔记二:深入学习
- python 重写__repr__与__str__函数
- listener.ora--sqlnet.ora--tnsnames.ora的关系以及手工配置举例(转载:http://blog.chinaunix.net/uid-83572-id-5510.ht)
- ASP.NET页面之间传值Application(5)
- gdb x命令_gdb基本命令
- 17年数据分析经验告诉你大数据行业的门道
- jenkins X 和k8s CI/CD
- 要把人工智能提速50倍的ARM,却依然坚持做“通用的计算架构”
- ubuntu14.04 LTS Visual Studio Code 编辑器推荐
- PHP——下载图片到本地代码
- 【LeetCode】026. Remove Duplicates from Sorted Array
- javascript 自建立对象
- opmanager邮件告警配置
- uno牌的玩法图解_UNO基本玩法和技巧
- 使用网上成熟的【MySqlBackup】组件,通过WEB网页操作,备份远程计算机中的数据库到C:\inetpub\wwwroot文件夹下,系统汇报错误(访问被拒绝),该如何解决呢?
- 非常规手段免疫U盘病毒(Autorun.inf)
- 【软件群英会】QQ群 12月1日晚上聊天记录
- OBB包围盒及其碰撞检测算法(一)
- ReactiveUI 入门
热门文章
- 【解题报告】Leecode911. 在线选举——Leecode每日一题系列
- 最全的时间类解析。 SimpleDateFormat + Date() 和 DateTimeFormatter + LocalDate()的区别与使用场景
- 【已解决】Exception in thread “Thread-0“ redis.clients.jedis.exceptions.JedisConnectionException: java.n
- HashMap 1.7 死循环过程
- js函数提示 vscode_工欲善其事,必先利其器,VSCode高效插件
- 如何设定vs2012用linux文件格式,Visual Studio 2012发布网站详细步骤
- shell脚本编译规范(编写第一个脚本,脚本变量的作用,类型 ,了解read命令,let命令,环境变量和预定义变量)
- Linux系统弱口令检测和网络端口扫描方法(JR、NMAP)
- 详解DNS正向解析实验(有图有实验)
- list对oracle结果集排序了_详解SQL窗口函数和分组排序函数