关键代码:tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',max_checkpoints=10, tensorboard_verbose=0,clip_gradients=0.)
snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.我的demo:
def get_model(width, height, classes=40):# TODO, modify modelnetwork = input_data(shape=[None, width, height, 3])  # if RGB, 224,224,3# Residual blocks  # 32 layers: n=5, 56 layers: n=9, 110 layers: n=18  n = 2net = tflearn.conv_2d(network, 16, 3, regularizer='L2', weight_decay=0.0001)net = tflearn.residual_block(net, n, 16)net = tflearn.residual_block(net, 1, 32, downsample=True)net = tflearn.residual_block(net, n-1, 32)net = tflearn.residual_block(net, 1, 64, downsample=True)net = tflearn.residual_block(net, n-1, 64)net = tflearn.batch_normalization(net)net = tflearn.activation(net, 'relu')net = tflearn.global_avg_pool(net)# Regression  net = tflearn.fully_connected(net, classes, activation='softmax')#mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)mom = tflearn.Momentum(0.01, lr_decay=0.1, decay_step=2000, staircase=True)net = tflearn.regression(net, optimizer=mom,loss='categorical_crossentropy')# Training  model = tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',max_checkpoints=10, tensorboard_verbose=0,clip_gradients=0.)return modeldef  main():trainX, trainY = image_preloader("data/train", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)testX, testY = image_preloader("data/test", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)#trainX = trainX.reshape([-1, width, height, 1])#testX = testX.reshape([-1, width, height, 1])print("sample data:")print(trainX[0])print(trainY[0])print(testX[-1])print(testY[-1])model = get_model(width, height, classes=3755)filename = 'tflearn_resnet/model.tflearn'# try to load model and resume trainingtry:#model.load(filename)model.load("model_resnet_cifar10-195804")print("Model loaded OK. Resume training!")except:passearly_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.94)try:      model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.show_metric=True, batch_size=1024, callbacks=early_stopping_cb, run_id='cnn_handwrite')except StopIteration as e:print("OK, stop iterate!Good!")model.save(filename)del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]filename = 'tflearn_resnet/model-infer.tflearn'model.save(filename)

转载于:https://www.cnblogs.com/bonelee/p/9006243.html

tflearn 在每一个epoch完毕保存模型相关推荐

  1. 在PyTorch训练一个epoch时,模型不能接着训练,Dataloader卡死

    笔者在训练模型的时候,突然偶遇这个问题,即训练一个epoch时,模型不能接着训练,只能通过Ctrl+C强制性暂停,见下图: Ctrl+C之后呈现的信息表明,这个bug是和多线程有关系. 经过笔者实验, ...

  2. python模型保存save_浅谈keras保存模型中的save()和save_weights()区别

    今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别. 我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5.同样是h5文件用save ...

  3. Pytorch学习 - 保存模型和重新加载

    Pytorch学习 - 保存和加载模型 1. 3个函数 2. 模型不同后缀名的区别 3. 保存和重载模型 参考资料: Pytorch官方文档链接 某博客 1. 3个函数 torch.save() : ...

  4. pytorch保存模型pth_Day159:模型的保存与加载

    网络结构和参数可以分开的保存和加载,因此,pytorch保存模型有两种方法: 保存 整个模型 (结构+参数) 只保存模型参数(官方推荐) # 保存整个网络torch.save(model, check ...

  5. Tensorflow保存模型详解(进阶版二):如何保存最近的.ckpt文件 及 如何分开保存.ckpt数据文件和.meta图文件

    在学会了如何有选择的保存变量后,我们来学习如何如何分开保存.ckpt数据文件和.meta图文件 和 如何 保存最近几轮的.ckpt数据文件. 直接上代码: import tensorflow as t ...

  6. Tensorflow详解保存模型(基础版)

    我们都知道tensorflow最后生成的模型文件含: checkpoint xxxxx.meta xxxxx.ckpt.data-xxx xxxxx.index 学习和使用tensorflow的小伙伴 ...

  7. tensorflow保存模型和加载模型的方法(Python和Android)

    tensorflow保存模型和加载模型的方法(Python和Android) 一.tensorflow保存模型的几种方法: (1) tf.train.saver()保存模型 使用 tf.train.s ...

  8. pytorch保存模型pth_Pytorch_trick_04

    科技猛兽:PyTorch 50.Pytorch模型保存与加载,并在加载的模型基础上继续训练​zhuanlan.zhihu.com 1.Pytorch 模型保存与加载,并在加载的模型基础上继续训练 只保 ...

  9. keras如何保存模型

    使用model.save(filepath)将Keras模型和权重保存在一个HDF5文件中,该文件将包含: 模型的结构,以便重构该模型 模型的权重 训练配置(损失函数,优化器等) 优化器的状态,以便于 ...

最新文章

  1. 使用TENSORRT和NVIDIA-DOCKER部署深部神经网络
  2. .NET2.0隐形的翅膀,正则表达式搜魂者【月儿原创】
  3. Python学习笔记(2)-Python执行方式、变量
  4. 【虚拟机】苹果虚拟机mac10.11.6+Xcode8.1
  5. Python基于nginx访问日志并统计IP访问量
  6. 59.排序好的大数据创建索引文件,并实现大文件的二分查找,根据索引百万数据秒读数据...
  7. linux中命令对c文件进行编译,Linux下C语言编译基础及makefile的编写
  8. HTTPBrowserCapabilities---在asp.net中显示浏览器属性
  9. java双机和集群的区别,java 分布式与集群的区别和联系
  10. Word两端对齐问题
  11. ML面试1000题系列(91-100)
  12. 离散数学 习题篇 —— 谓词公式练习
  13. oracle 亿级数据迁移,Oracle12c迁移-某风险报告类系统升级暨迁移至12c-3
  14. python读取udp数据包内容_python – 解析UDP数据包
  15. echarts社区地图、echart地图
  16. eclipse 搭建ARM开发环境
  17. WiFi慢不一定是信号不好,这几招让你上网更顺畅
  18. 联发科MT6893怎么样 联发科MT6893参数配置
  19. 25岁阿里120W年薪架构师推荐学习的750页微服务架构深度解析文档
  20. 【点宽专栏】破解波动性突破实盘系统

热门文章

  1. JAVA多线程中join()方法的详细分析
  2. java需要前台封装对象吗_javaEE之-----------类反射直接封装前台传过来的参数
  3. java word表格_Java 添加Word表格行或列
  4. gta5结局杀老崔我哭了_都已经2020年了,怎么还有人在买GTA5?
  5. 装鸡蛋的鞋子java代码_Java实现 LeetCode 887 鸡蛋掉落(动态规划,谷歌面试题,蓝桥杯真题)...
  6. java callback类_利用java8新特性实现类似javascript callback特性
  7. python阶乘匿名函数_python的高阶函数与匿名函数
  8. html点击旋转180,关于点击三角丝滑旋转180度css3 jq处理方法
  9. mysql服务实例配置_MySQL多实例配置(一)
  10. Android的自定义键盘颜色,android自定义键盘(解决弹出提示的字体颜色问题)