该项目中把tf的数据存储和读取抽取出两个函数,方便开发,思想和代码值得借迁

一.存储

def save_variables(save_path, variables=None, sess=None):import joblibsess = sess or get_session()variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)ps = sess.run(variables)save_dict = {v.name: value for v, value in zip(variables, ps)}dirname = os.path.dirname(save_path)if any(dirname):os.makedirs(dirname, exist_ok=True)joblib.dump(save_dict, save_path)
  • 第一次见 or 这样写,意思就是前一个不是None或者0,就取前一个,否则取后一个。
  • tf里,一个session就保存了各种训练的数据和计算图,所依直接把sess传过来,从tf自带的tf.GraphKeys.GLOBAL_VARIABLES取出其中的全局变量名。然后run()一下就能得到参数值,再放入一个字典容器
  • 根据路径存入joblib里面
  • 其中joblib是sklearn中的一个专门用于保存训练的模型的
    不知道的点这里

二.加载

def load_variables(load_path, variables=None, sess=None):import joblibsess = sess or get_session()variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)loaded_params = joblib.load(os.path.expanduser(load_path))restores = []if isinstance(loaded_params, list):assert len(loaded_params) == len(variables), 'number of variables loaded mismatches len(variables)'for d, v in zip(loaded_params, variables):restores.append(v.assign(d))else:for v in variables:restores.append(v.assign(loaded_params[v.name]))sess.run(restores)

跟上面相反,看代码就能明白

  • assign()是tf里的赋值函数,注意tf里的操作写完都要run()才能生效,不然它仅仅是图上的一个结点
  • isinstance()是python里比较两个对象是否相同,它具有继承关系,也就是说如果他的父类相同,也算同一类

强化学习 ---baseline项目之 TensorFlow的训练参数的存储和加载相关推荐

  1. tensorflow 保存训练loss_tensorflow2.0保存和加载模型 (tensorflow2.0官方教程翻译)

    最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html 英文版本: ...

  2. 【医疗人工智能论文】使用深度强化学习的腹腔镜机器人辅助训练

    Article 作者:Xiaoyu Tan , Chin-Boon Chng, Ye Su, Kah-Bin Lim, and Chee-Kong Chui 文献题目:Robot-Assisted T ...

  3. ROS开发笔记(10)——ROS 深度强化学习dqn应用之tensorflow版本(double dqn/dueling dqn/prioritized replay dqn)

    ROS开发笔记(10)--ROS 深度强化学习dqn应用之tensorflow版本(double dqn/dueling dqn/prioritized replay dqn) 在ROS开发笔记(9) ...

  4. RL之SARSA:利用强化学习之SARSA实现走迷宫—训练智能体走到迷宫(复杂陷阱迷宫)的宝藏位置

    RL之SARSA:利用强化学习之SARSA实现走迷宫-训练智能体走到迷宫(复杂陷阱迷宫)的宝藏位置 目录 输出结果 设计思路 实现代码 测试记录全过程 输出结果 设计思路 实现代码 后期更新-- 测试 ...

  5. 深度强化学习落地方法论(7)——训练篇

    目录 训练开始前 环境可视化 数据预处理 训练进行中 拥抱不确定性 DRL通用超参数 折扣因子 作用原理 选取方法 Frame Skipping 网络结构 网络类型 网络深度 DRL特色超参数 DQN ...

  6. 深度强化学习入门:用TensorFlow构建你的第一个游戏AI

    本文通过一种简单的 Catch 游戏介绍了深度强化学习的基本原理,并给出了完整的以 Keras 为前端的 TensorFlow 代码实现,是入门深度强化学习的不错选择. GitHub 链接:https ...

  7. MATLAB强化学习实战(七) 在Simulink中训练DDPG控制倒立摆系统

    在Simulink中训练DDPG控制倒立摆系统 倒立摆的Simscape模型 创建环境接口 创建DDPG智能体 训练智能体 DDPG智能体仿真 此示例显示了如何训练深度确定性策略梯度(DDPG)智能体 ...

  8. 无人机+强化学习开源项目、工具包汇总(二)

    1.IEEE无人机竞赛2022 https://github.com/engcang/ieee_uav_2022 相关论文: E. Lee.D. Lee.H. Lim.S. Song 和 H. Myu ...

  9. 强化学习入门项目 Spinning up OpenAI (1) installation

    Spinning up是openAI的一个入门RL学习项目,涵盖了从基础概念到各个baseline算法. 在此记录一下学习过程. Spining Up 需要python3, OpenAI Gym,和O ...

最新文章

  1. 干货|了解机器学习常用数据预处理
  2. linux shell mysql备份_linux shell 备份mysql 数据库
  3. JMeter基础之—录制脚本
  4. context:component-scan/和mvc:annotation-driven/的区别
  5. 20.if条件语句.rs
  6. 如何查看一个现有的keil工程之前由什么版本的keil IDE编译
  7. 【转】2:C#TPL探秘
  8. Windows监听进程是否退出C++
  9. css行内元素和块级元素
  10. android drawable资源调用使用心得
  11. 解决ORA-00054资源正忙的问题
  12. 编程语言的动态性(Dart和OC对比)
  13. rslinx连接linux教程,RSLinx Classic软件通讯配置教程
  14. 打开cmd 的方式和常用的cmd快捷键
  15. 【历史上的今天】10 月 17 日:微软发布 Windows 8.1;IMDb 成立;海盗湾创始人诞生
  16. 三步必杀(P4231)
  17. Navicat for mysql 在WIN10下导入SQL不成功解决办法
  18. 桌面计算器The C++ Programming Language程序解析
  19. 外贸老手告诉你:外贸实用工具
  20. 基于Javaweb的小项目(类似于qqzone)1——设计数据库

热门文章

  1. 连接SQL Server数据库
  2. python小项目之头像右上角加数字
  3. jquery操作checkbox最佳方法
  4. VS 2008 和 .NET 3.5 Beta 2 发布了
  5. 高并发编程知识体系阅读总结
  6. linux服务sendmail邮件服务
  7. Visual Studio 2015 初体验
  8. 行内元素垂直方向位置调整的一些感悟和困惑
  9. mybatis之xml中日期时间段查询的sql语句
  10. 正常使用 flex profiler 解决 Socket timeout