MAML-RL Pytorch 代码解读 (6) – maml_rl/envs/bandit.py

文章目录

  • MAML-RL Pytorch 代码解读 (6) -- maml_rl/envs/bandit.py
    • 基本介绍
      • 源码链接
      • 文件路径
    • `import` 包
    • `BernoulliBanditEnv()` 类
    • `GaussianBanditEnv()` 类

基本介绍

在网上看到的元学习 MAML 的代码大多是跟图像相关的,强化学习这边的代码比较少。

因为自己的思路跟 MAML-RL 相关,所以打算读一些源码。

MAML 的原始代码是基于 tensorflow 的,在 Github 上找到了基于 Pytorch 源码包,学习这个包。

源码链接

https://github.com/dragen1860/MAML-Pytorch-RL

文件路径

./maml_rl/envs/bandit.py

import

import numpy as npimport gym
from gym import spaces
from gym.utils import seeding

BernoulliBanditEnv()

阅读这部分代码不太明白 self._tasktask 具体指带的是什么信息。

class BernoulliBanditEnv(gym.Env):#### 这个类的环境是伯努力分布下的多臂赌博机问题。智能体推动其中 “K” 个手臂,说了 “i” 信息,收到从伯努力分布中采样得到的具有 “pi” 信息的奖励,从[0,1]均匀分布中采样得到的pi被保证。"""Multi-armed bandit problems with Bernoulli observations, as describedin [1].At each time step, the agent pulls one of the `k` possible arms (actions), say `i`, and receives a reward sampled from a Bernoulli distribution with parameter `p_i`. The multi-armed bandit tasks are generated by sampling the parameters `p_i` from the uniform distribution on [0, 1].[1] Yan Duan, John Schulman, Xi Chen, Peter L. Bartlett, Ilya Sutskever,Pieter Abbeel, "RL2: Fast Reinforcement Learning via Slow ReinforcementLearning", 2016 (https://arxiv.org/abs/1611.02779)"""#### 初始类信息。当前环境继承gym.Env()类,k表示维度;那么action_space就是离散的k维度动作空间;observation_space中的状态信息就是一个数据,但是不明白这里为何high=0而不是1;_task接受task字典;_means应该就是将task字典里面的'mean'用元素全部是0.5的K维度向量赋予而成;最后设置随机数。def __init__(self, k, task={}):super(BernoulliBanditEnv, self).__init__()self.k = kself.action_space = spaces.Discrete(self.k)self.observation_space = spaces.Box(low=0, high=0,shape=(1,), dtype=np.float32)self._task = taskself._means = task.get('mean', np.full((k,), 0.5, dtype=np.float32))self.seed()#### 跳转到源码上是numpy的,意思就是获得一组随机数。第一个self.np_random是一个与随机数相关的实例,seed是随机数种子。def seed(self, seed=None):self.np_random, seed = seeding.np_random(seed)return [seed]#### 因为元学习需要采用很多的任务,因此means得到的是num_tasks(任务数)行和self.k列的随机数矩阵作为均值,tasks对这个矩阵按行拆分成一个个字典,字典的键是"mean"值是每一行作为一个array数组。相当于初始化每个任务的均值是随机数array。def sample_tasks(self, num_tasks):means = self.np_random.rand(num_tasks, self.k)tasks = [{'mean': mean} for mean in means]return tasks#### 因为元学习有很多的任务,需要重置任务实现切换。用_task形参task,意味着在程序运行中task的初始化可能更原本的初始化不一样。self._means用task字典的”mean“来替代。def reset_task(self, task):self._task = taskself._means = task['mean']#### 这个reset实现每个任务的重置,重置成初始状态。def reset(self):return np.zeros(1, dtype=np.float32)#### assert关键字的作用是检查输出的动作信息在不在动作空间内。mean的意思应该是从action中获得这个赌博机的期望?这里没搞懂。reward的意思是从概率是mean的一重二项分布中抽取奖励信息。观测信息就是一个元素是0的array。def step(self, action):assert self.action_space.contains(action)mean = self._means[action]reward = self.np_random.binomial(1, mean)observation = np.zeros(1, dtype=np.float32)return observation, reward, True, self._task

GaussianBanditEnv()

class GaussianBanditEnv(gym.Env):#### 这个类的环境是正态分布下的多臂赌博机问题。智能体推动其中 “K” 个手臂,说了 “i” 信息,收到从正态分布中采样得到的均值是 “pi” 且标准差是 ”std“ 且在任务中固定的信息的奖励,从[0,1]均匀分布中采样得到的pi被保证。"""Multi-armed problems with Gaussian observations.At each time step, the agent pulls one of the `k` possible arms (actions),say `i`, and receives a reward sampled from a Normal distribution withmean `p_i` and standard deviation `std` (fixed across all tasks). Themulti-armed bandit tasks are generated by sampling the parameters `p_i`from the uniform distribution on [0, 1]."""#### 初始类信息。当前环境继承gym.Env()类,k表示维度;std表示标准差;那么action_space就是离散的k维度动作空间;observation_space中的状态信息就是一个数据,但是不明白这里为何high=0而不是1;_task接受task字典;_means应该就是将task字典里面的'mean'用元素全部是0.5的K维度向量赋予而成;最后设置随机数。def __init__(self, k, std=1.0, task={}):super(GaussianBanditEnv, self).__init__()self.k = kself.std = stdself.action_space = spaces.Discrete(self.k)self.observation_space = spaces.Box(low=0, high=0,shape=(1,), dtype=np.float32)self._task = taskself._means = task.get('mean', np.full((k,), 0.5, dtype=np.float32))self.seed()#### 跳转到源码上是numpy的,意思就是获得一组随机数。第一个self.np_random是一个与随机数相关的实例,seed是随机数种子。def seed(self, seed=None):self.np_random, seed = seeding.np_random(seed)return [seed]#### 因为元学习需要采用很多的任务,因此means得到的是num_tasks(任务数)行和self.k列的随机数矩阵作为均值,tasks对这个矩阵按行拆分成一个个字典,字典的键是"mean"值是每一行作为一个array数组。相当于初始化每个任务的均值是随机数array。def sample_tasks(self, num_tasks):means = self.np_random.rand(num_tasks, self.k)tasks = [{'mean': mean} for mean in means]return tasks#### 因为元学习有很多的任务,需要重置任务实现切换。用_task形参task,意味着在程序运行中task的初始化可能更原本的初始化不一样。self._means用task字典的”mean“来替代。def reset_task(self, task):self._task = taskself._means = task['mean']#### 这个reset实现每个任务的重置,重置成初始状态。def reset(self):return np.zeros(1, dtype=np.float32)#### assert关键字的作用是检查输出的动作信息在不在动作空间内。mean的意思应该是从action中获得这个赌博机的期望?这里没搞懂。reward的意思是从概率是mean的一重二项分布中抽取奖励信息。观测信息就是一个元素是0的array。def step(self, action):assert self.action_space.contains(action)mean = self._means[action]reward = self.np_random.normal(mean, self.std)observation = np.zeros(1, dtype=np.float32)return observation, reward, True, self._task

MAML-RL Pytorch 代码解读 (6) -- maml_rl/envs/bandit.py相关推荐

  1. Faceboxes pytorch代码解读(一) box_utils.py(上篇)

    Faceboxes pytorch代码解读(一) box_utils.py(上篇) 有幸读到Shifeng Zhang老师团队的人脸检测论文,感觉对自己的人脸学习论文十分有帮助.通过看别人的paper ...

  2. TSN算法的PyTorch代码解读(训练部分)

    这篇博客来读一读TSN算法的PyTorch代码,总体而言代码风格还是不错的,多读读优秀的代码对自身的提升还是有帮助的,另外因为代码内容较多,所以分训练和测试两篇介绍,这篇介绍训练代码,介绍顺序为代码运 ...

  3. 对抗自编码器AAE——pytorch代码解读试验

    AAE网络结构基本框架如论文中所示: 闲话不多说,直接来学习一下加了注释和微调的基本AAE的代码(初始代码链接github): aae_pytorch_basic.py #!/usr/bin/env ...

  4. Resnet的pytorch官方实现代码解读

    Resnet的pytorch官方实现代码解读 目录 Resnet的pytorch官方实现代码解读 前言 概述 34层网络结构的"平原"网络与"残差"网络的结构图 ...

  5. mapbox 修改初始位置_一行代码教你如何随心所欲初始化Bert参数(附Pytorch代码详细解读)...

    微信公众号:NLP从入门到放弃 微信文章在这里(排版更漂亮,但是内置链接不太行,看大家喜欢哪个点哪个看吧): 一行代码带你随心所欲重新初始化bert的参数(附Pytorch代码详细解读)​mp.wei ...

  6. ResNet及其变种的结构梳理、有效性分析与代码解读(PyTorch)

    点击我爱计算机视觉标星,更快获取CVML新技术 本文来自知乎,作者费敬敬,现为同济大学计算机科学与技术硕士. https://zhuanlan.zhihu.com/p/54289848 温故而知新,理 ...

  7. 说话人识别损失函数的PyTorch实现与代码解读

    概述 说话人识别中的损失函数分为基于多类别分类的损失函数,和端到端的损失函数(也叫基于度量学习的损失函数),关于这些损失函数的理论部分,可参考说话人识别中的损失函数 本文主要关注这些损失函数的实现,此 ...

  8. 实验并解读github上三个DeepDream的Pytorch代码

    实验并解读github上三个DeepDream的Pytorch代码 今天在学习DeepDream的有关内容,关于论文的翻译已经在启发主义--深入神经网络(Inceptionism: Going Dee ...

  9. python实现胶囊网络_Capsule Network胶囊网络解读与pytorch代码实现

    本文是论文<Dynamic Routing between Capsules>的论文解读与pytorch代码实现. 如需转载本文或代码请联系作者 @Riroaki 并声明. 众所周知,卷积 ...

最新文章

  1. 液晶模块 LM6063A接口转接
  2. php中文乱码问号,如何解决PHP中文乱码问题?
  3. 12、索引在什么情况下不会被使用?
  4. 系统仿真基础与计算机实现,计算机综合仿真实验系统的研究与开发
  5. Aspx 页面生命周期
  6. oracle 还原dmp时_报错的值太大,基于oracle数据库的CLOUD备份恢复测试
  7. mysql外部排序_深入浅出MySQL优先队列(你一定会踩到的order by limit 问题)
  8. 复制的python代码格式错误_新手常见6种的python报错及解决方法
  9. NYOJ题目1170-最大的数
  10. 各省生活资料PPI数据(2009-2018年)
  11. 什么是elastic-job(持续更新)
  12. ps批量修改名片文字_pS如何在图中添加和修改文字
  13. 教育技术学就业方向_教育技术学专业就业方向与就业前景
  14. speedoffice文档如何在方框内打钩
  15. 美团、飞猪基础架构组实习经历分享
  16. jsp四大作用域和九大内置对象
  17. AttributeError: ‘FigureCanvasTkAgg‘ object has no attribute ‘set_window_title‘
  18. MySQL错误:ERROR 1819 (HY000): Your password does not satisfy the current policy requirements
  19. 圆形头像、图片显示效果
  20. 愚人节看到的两则IT界的玩笑

热门文章

  1. java使用jacob将word,excel,ppt转成html
  2. 模型小常识,C4D扫描的使用
  3. 学习Python后能找什么工作
  4. trove mysql 镜像_centos7下手动制作trove镜像
  5. eclipse 创建maven项目 出现Could not calculate build plan错误解决
  6. CorelDRAWX4的VBA插件开发(三十一)使用C++制作动态连接库DLL辅助VBA构键强大功能-(5)在VBA中动态调用DLL文件
  7. 解决Apache/2.4.39 (Win64) PHP/7.2.18 Server at localhost Port 80问题
  8. Marvin java图像处理
  9. 新手小白也能会的从淘宝口令到下载完淘宝直播回放视频的步骤详情
  10. centos7 shell脚本开机自启动(亲测可用)