tflearn 在每一个epoch完毕保存模型
关键代码:tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',max_checkpoints=10, tensorboard_verbose=0,clip_gradients=0.)
snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.我的demo:
def get_model(width, height, classes=40):# TODO, modify modelnetwork = input_data(shape=[None, width, height, 3]) # if RGB, 224,224,3# Residual blocks # 32 layers: n=5, 56 layers: n=9, 110 layers: n=18 n = 2net = tflearn.conv_2d(network, 16, 3, regularizer='L2', weight_decay=0.0001)net = tflearn.residual_block(net, n, 16)net = tflearn.residual_block(net, 1, 32, downsample=True)net = tflearn.residual_block(net, n-1, 32)net = tflearn.residual_block(net, 1, 64, downsample=True)net = tflearn.residual_block(net, n-1, 64)net = tflearn.batch_normalization(net)net = tflearn.activation(net, 'relu')net = tflearn.global_avg_pool(net)# Regression net = tflearn.fully_connected(net, classes, activation='softmax')#mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)mom = tflearn.Momentum(0.01, lr_decay=0.1, decay_step=2000, staircase=True)net = tflearn.regression(net, optimizer=mom,loss='categorical_crossentropy')# Training model = tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',max_checkpoints=10, tensorboard_verbose=0,clip_gradients=0.)return modeldef main():trainX, trainY = image_preloader("data/train", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)testX, testY = image_preloader("data/test", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)#trainX = trainX.reshape([-1, width, height, 1])#testX = testX.reshape([-1, width, height, 1])print("sample data:")print(trainX[0])print(trainY[0])print(testX[-1])print(testY[-1])model = get_model(width, height, classes=3755)filename = 'tflearn_resnet/model.tflearn'# try to load model and resume trainingtry:#model.load(filename)model.load("model_resnet_cifar10-195804")print("Model loaded OK. Resume training!")except:passearly_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.94)try: model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.show_metric=True, batch_size=1024, callbacks=early_stopping_cb, run_id='cnn_handwrite')except StopIteration as e:print("OK, stop iterate!Good!")model.save(filename)del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]filename = 'tflearn_resnet/model-infer.tflearn'model.save(filename)
转载于:https://www.cnblogs.com/bonelee/p/9006243.html
tflearn 在每一个epoch完毕保存模型相关推荐
- 在PyTorch训练一个epoch时,模型不能接着训练,Dataloader卡死
笔者在训练模型的时候,突然偶遇这个问题,即训练一个epoch时,模型不能接着训练,只能通过Ctrl+C强制性暂停,见下图: Ctrl+C之后呈现的信息表明,这个bug是和多线程有关系. 经过笔者实验, ...
- python模型保存save_浅谈keras保存模型中的save()和save_weights()区别
今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别. 我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5.同样是h5文件用save ...
- Pytorch学习 - 保存模型和重新加载
Pytorch学习 - 保存和加载模型 1. 3个函数 2. 模型不同后缀名的区别 3. 保存和重载模型 参考资料: Pytorch官方文档链接 某博客 1. 3个函数 torch.save() : ...
- pytorch保存模型pth_Day159:模型的保存与加载
网络结构和参数可以分开的保存和加载,因此,pytorch保存模型有两种方法: 保存 整个模型 (结构+参数) 只保存模型参数(官方推荐) # 保存整个网络torch.save(model, check ...
- Tensorflow保存模型详解(进阶版二):如何保存最近的.ckpt文件 及 如何分开保存.ckpt数据文件和.meta图文件
在学会了如何有选择的保存变量后,我们来学习如何如何分开保存.ckpt数据文件和.meta图文件 和 如何 保存最近几轮的.ckpt数据文件. 直接上代码: import tensorflow as t ...
- Tensorflow详解保存模型(基础版)
我们都知道tensorflow最后生成的模型文件含: checkpoint xxxxx.meta xxxxx.ckpt.data-xxx xxxxx.index 学习和使用tensorflow的小伙伴 ...
- tensorflow保存模型和加载模型的方法(Python和Android)
tensorflow保存模型和加载模型的方法(Python和Android) 一.tensorflow保存模型的几种方法: (1) tf.train.saver()保存模型 使用 tf.train.s ...
- pytorch保存模型pth_Pytorch_trick_04
科技猛兽:PyTorch 50.Pytorch模型保存与加载,并在加载的模型基础上继续训练zhuanlan.zhihu.com 1.Pytorch 模型保存与加载,并在加载的模型基础上继续训练 只保 ...
- keras如何保存模型
使用model.save(filepath)将Keras模型和权重保存在一个HDF5文件中,该文件将包含: 模型的结构,以便重构该模型 模型的权重 训练配置(损失函数,优化器等) 优化器的状态,以便于 ...
最新文章
- 使用TENSORRT和NVIDIA-DOCKER部署深部神经网络
- .NET2.0隐形的翅膀,正则表达式搜魂者【月儿原创】
- Python学习笔记(2)-Python执行方式、变量
- 【虚拟机】苹果虚拟机mac10.11.6+Xcode8.1
- Python基于nginx访问日志并统计IP访问量
- 59.排序好的大数据创建索引文件,并实现大文件的二分查找,根据索引百万数据秒读数据...
- linux中命令对c文件进行编译,Linux下C语言编译基础及makefile的编写
- HTTPBrowserCapabilities---在asp.net中显示浏览器属性
- java双机和集群的区别,java 分布式与集群的区别和联系
- Word两端对齐问题
- ML面试1000题系列(91-100)
- 离散数学 习题篇 —— 谓词公式练习
- oracle 亿级数据迁移,Oracle12c迁移-某风险报告类系统升级暨迁移至12c-3
- python读取udp数据包内容_python – 解析UDP数据包
- echarts社区地图、echart地图
- eclipse 搭建ARM开发环境
- WiFi慢不一定是信号不好,这几招让你上网更顺畅
- 联发科MT6893怎么样 联发科MT6893参数配置
- 25岁阿里120W年薪架构师推荐学习的750页微服务架构深度解析文档
- 【点宽专栏】破解波动性突破实盘系统
热门文章
- JAVA多线程中join()方法的详细分析
- java需要前台封装对象吗_javaEE之-----------类反射直接封装前台传过来的参数
- java word表格_Java 添加Word表格行或列
- gta5结局杀老崔我哭了_都已经2020年了,怎么还有人在买GTA5?
- 装鸡蛋的鞋子java代码_Java实现 LeetCode 887 鸡蛋掉落(动态规划,谷歌面试题,蓝桥杯真题)...
- java callback类_利用java8新特性实现类似javascript callback特性
- python阶乘匿名函数_python的高阶函数与匿名函数
- html点击旋转180,关于点击三角丝滑旋转180度css3 jq处理方法
- mysql服务实例配置_MySQL多实例配置(一)
- Android的自定义键盘颜色,android自定义键盘(解决弹出提示的字体颜色问题)