torchnet package (2)

torchnet
torch7

Dataset Iterators

尽管是用for loop语句很容易处理Dataset,但有时希望以on-the-fly manner或者在线程中读取数据,这时候Dataset Iterator就是个好的选择
注意,iterators是用于特殊情况的,一般情况下还是使用Dataset比较好
Iteartor 的两个主要方法:
* run() 返回一个Lua 迭代器,也可以使用()操作符,因为iterator源码中定义了__call事件
* exec(funcname,...) 在指定的dataset上执行funcname方法,funcname是dataset自己的方法,比如size

  • tnt.DatasetIterator(self,dataset[,perm][,filter][,transform])
    The default dataset iterator
    perm(idx), 实现shuffle功能,即对idx进行变换,更复杂的变换可以使用ShuffleDataset
    filter(sample), 闭包函数,筛选样本是否用于迭代,返回bool值
    transform(sample),闭包函数,实现对样本的变换,更复杂的变换可以结合TransformDataset和transform.compose等实现

  1. ldata = tnt.ListData{list=torch.range(1,10):long(),load = function(x) return {x,x+1} end}


  2. dIter = tnt.DatasetIterator{dataset = ldata,filter = function(x) if x[1]<2 then return false else return true end end} 

  3. for v in dIter:run() 

  4. print(v) 

  5. end 

  • tnt.ParallelDatasetIterator(self[,init],closure,nthread[,perm][,filter][,transform][,ordered])
    这个才是迭代器的重点,用于以多线程方式迭代数据。

The purpose of this class is to have a zero pre-processing cose. when reading datasets on the fly from disk(not loading thenm fully in memory), or performing complex pre-processing this canbe of interest.

nthreads 指定了线程的个数
init(threadid) 闭包函数,指定了线程threadid的初始化工作,如果啥都不做可以省略
closure(threadid) 每个线程的job,返回的必须时tnt.Dataset的一个实例
perm(idx) 用于shuffle
filter(sample) 闭包函数,指定哪些样本不用于迭代
transform(sample) 对样本进行变换,在filter之前执行
order 线程之间数据的处理是否有序,主要是为了程序的可重现性,当order=true时,多次执行程序,顺序是相同的

  1. tnt=require'torchnet'


  2. local list=torch.Tensor{{2,2},{2,2},{2,2},{2,2}}:long() 

  3. ldata = tnt.ListDataset{list=list,load=function(x) return torch.Tensor(x[1],x[2]) end} 

  4. local bdata = tnt.BatchDataset{batchsize=2,dataset = tnt.TransformDataset{dataset = ldata,transform=function(x) return 2*x end}} 

  5. Padata = tnt.ParallelDatasetIterator{ 

  6. nthread = 4, 

  7. init = function(tid) 

  8. print ('init thread id: '.. tid) 

  9. tnt=require'torchnet' 

  10. end, 

  11. closure = function(tid) 

  12. print('closure of threadid: '.. tid) 

  13. return bdata 

  14. end 

  15. }  

尤其需要注意的是,closure中的所有upvalues都必须是可序列化的,最好是避免使用upvalues,并保证closure中使用的package都在init中require

tnt.Engine

在网络训练的过程中,都是计算前向误差,误差反传,更新权重这些过程,只是模型,数据和评价函数不同而已,所以Engine给训练过程提供了一个模板,该模板建立了model,DatasetIterator,Criterion和Meter之间的联系

engine=tnt.Engine()包含两个主要方法
* engine:train() 在数据集上训练数据
* engine:test() 评估模型,可选
Engine不仅实现了训练和评估的一般模板,还提供了许多接口,用于控制训练过程

  • tnt.SGDEngine
    SGDEngine 模块在train过程中使用Stochastic Gradient Descent方法训练,模块包含数据采样,前向传递,反向传递,参数更新等,还有一些钩子函数
    hooks = {
    ['onStart'] = function() end, --用于训练开始前的设置和初始化
    ['onStartEpoch'] = function() end, -- 每一个epoch前的操作
    ['onSample'] = function() end, -- 每次采样一个样本之后的操作
    ['onForward'] = function() end, -- 在model:forward()之后的操作
    ['onForwardCriterion'] = function() end, -- 前向计算损失函数之后的操作
    ['onBackwardCriterion'] = function() end, -- 反向计算损失误差之后的操作
    ['onBackward'] = function() end, -- 反向传递误差之后的操作
    ['onUpdate'] = function() end, -- 权重参数更新之后的操作
    ['onEndEpoch'] = function() end, -- 每一个epoch结束时的操作
    ['onEnd'] = function() end, -- 整个训练过程结束后的收拾现场
    }
    可以发现Engine给的hook函数还是很全面的,几乎训练过程的每一个节点都允许用户制定操作,使用hook函数

  1. local engine = SGDEngine()


  2. local meter = tnt.AverageValueMeter() 

  3. engine.hooks.onStartEpoch = function(state) meter:reset() end 

一般而言,训练过程最少应该知道训练模型,损失函数,数据和学习率,这里学习方法已经知道了SGD,Engine用到的数据是tnt.DatasetIterator类型的。 评估过程只需要数据和模型就可以了

外部可以通过state变量与Engine训练过程交互
state = {
['network'] = network, --设置了model
['criterion'] = criterion, -- 设置损失函数
['iterator'] = iterator, -- 数据迭代器
['lr'] = lr, -- 学习率
['lrcriterion'] = lrcriterion, --
['maxepoch'] = maxepoch, --最大epoch数
['sample'] = {}, -- 当前采集的样本,可以在onSample中通过该阈值查看采样样本
['epoch'] = 0 , -- 当前的epoch
['t'] = 0, -- 已经训练样本的个数
['training'] = true -- 训练过程
}

评估时需要指定:
state = {
['netwrok'] = network
['iterator'] = iterator
['criterion'] = criterion
}

  • tnt.OptimEngine
    这个方法和SGDEngine的最大的区别在于封装了optim中的多种优化方法。在训练开始的时候,engine会通过getParameters获取model的参数
    train需要附加两个量:

    • optimMethod 优化方法,比如optim.sgd

    • config 优化方法对应的参数
      Example:

  1. local engine = tnt.OptimEngine{


  2. network = network, 

  3. criterion=criterion, 

  4. iterator = iterator, 

  5. optimMethod = optim.sgd, 

  6. config = { 

  7. learningRate = 0.1, 

  8. momentum = 0.9, 

  9. }, 



tnt.Meter

和Engine配合使用,用于measure the model.
几乎所有的meters都会有3个方法:
* add() 给待统计的meter添加一个观测值,其输入参数一般形式为(output,value),output为model的输出,target为真实值
* value() 获得待统计的meter的当前值
* reset() 重新计数
Meter的使用示例:

  1. local meter = tnt.<Measure>Meter() -- <Measure> 可以选择具体的度量


  2. for state,event in tnt.<Optimization>Engine:train{ --定义Engine 

  3. network = network, 

  4. criterion=criterion, 

  5. iterator=iterator, 

  6. } do 

  7. if state == 'start-epoch' then  

  8. meter:reset() -- reset meter 

  9. elseif state == 'forward-criterion' then 

  10. meter:add(state.network.output,sample.target) 

  11. elseif state == 'end-epoch' then 

  12. print('value of meter:) .. meter:value()) 

  13. end 

  14. end 

  • tnt.APMeter(self)
    评估每一类的平均正确率
    APMeter的操作对象是一个的Tensor,表示N个样本对应在K类中的值,另外可选的一个的 Tensor表示每个样本的权重

  1. target = torch.Tensor{


  2. {0,0,0,1},{0,0,1,0},{0,1,0,0},{1,0,0,0},{1,0,0,0}} 

  3. apm = tnt.APMeter() 

  4. for i=1,5 do 

  5. apm:add{output=torch.rand(1,4),target=target[i]:size(1,4)} -- 注意N*K的Tensor 

  6. end 

  7. print(apm:value()) 

  • tnt.AverageValueMeter(self)
    用于统计任意添加的变量的方差和均值,可以用来测量平均损失等
    add()的输入必须时number类型,另外在add的时候可以有一个可选的参数n,表示对应值的权重

  1. avm = tnt.AverageValueMeter()


  2. for i=1,10 do  

  3. avm:add(i,10-i) 

  4. end 

  5. print(avm:value()) -- 输出 4 2.4720... 

  • tnt.AUCMeter(self)
    对于二分类问题计算Area Under Curve (AUC).
    AUCMeter操作的变量是1D的tensor

  • tnt.ConfusionMeter(self,k[,nirmalized])
    多类之间的混淆矩阵,注意不是多类多标签问题,多标签是指一个类的实例可能分配多个标签,这类问题参见tnt.MultiLabelConfusionMeter
    初始化的时候,需要指定类别数k,normalized指定是否将confuse matrix 归一化,归一化之后输出的是百分比,否则是数值
    add(output,target) 输入都是的tensor,这里为什么每次都是N个样本一起输入呢?这是因为往往训练模型都是Batch模式处理的,target可以是N的tensor,每个值表示对应类别标号,也可以时NK的tensor表示类别的one-hot vector
    value()返回K
    K的混淆矩阵行表示groundtruth,列表示predicted targets

  • tnt.mAPMeter(self)
    统计所有类别之间的平均正确率,和APMeter参数完全一致,不同的时value()返回的是多个类别总的正确率

  • tnt.MovingAverageValueMeter(self,windowsize)
    该meter和AverageValueMeter非常类似,输入的也是number,不同在于他统计的不是所有的number的均值和方差,而是往前windowsize时间窗内的numbers的均值和方差,windowsize在初始化时需要指定

  • tnt.MultiLabelConfusionMeter(self,k[,normalized])
    多类多标签混淆矩阵,这个没接触过,不知道理解对不对,先放这吧,需要的时候再看

The tnt.MultiLabelConfusionMeter constructs a confusion matrix for multi- label, multi-class classification problems. In constructing the confusion matrix, the number of positive predictions is assumed to be equal to the number of positive labels in the ground-truth. Correct predictions (that is, labels in the prediction set that are also in the ground-truth set) are added to the diagonal of the confusion matrix. Incorrect predictions (that is, labels in the prediction set that are not in the ground-truth set) are equally divided over all non-predicted labels in the ground-truth set.

At initialization time, the k parameter that indicates the number of classes in the classification problem under consideration must be specified. Additionally, an optional parameter normalized (default = false) may be specified that determines whether or not the confusion matrix is normalized (that is, it contains percentages) or not (that is, it contains counts).

The add(output, target) method takes as input an NxK tensor output that contains the output scores obtained from the model for N examples and K classes, and a corresponding NxK-tensor target that provides the targets for the N examples using one-hot vectors (that is, vectors that contain only zeros and a single one at the location of the target value to be encoded).

  • tnt.ClassErrorMeter(self[,topk][,accuracy])
    参数: topk = table
    accuracy = boolean
    该meter用于统计分类误差,topk是一个table指定分别统计前k类预测误差,如ImageNet Competition中的Top5类误差,accuracy表示返回的是正确了还是错误率,accuracy=true,返回的就是1-error
    add(output,target),output是一个的tensor,target可以使一个N的tensor也可以是一个的tensor,参考之前的AUCMeter
    value()返回的时topk误差,value(k)返回的是第topk类误差

  • tnt.TimeMeter(self[,unit])
    这个Meter用于统计events之间的时间,也可以用来统计batch数据的平均处理数据。她很特别!
    unit在初始的时候给定,是一个布尔值,默认false,当设置为true时,返回值将会被incUnit()值平均,计算平均时间消耗。
    tnt.TimeMeter提供的方法有:

    • reset() 重置timer,unit counter

    • stop() stop the timer

    • resume() 唤醒timer

    • incUnit() uint+1

    • value() 返回从reset()到现在的时间消耗

  • tnt.PrecisionAtKMeter(self[,topk][,dim][,online])

待补充
  • tnt.RecallMeter(self[,threshold][,preclass])
    统计threshold下的召回率,threshold是一个table类型,每个元素是一个阈值,默认值为0.5. perclass是一个布尔值,表示是单独统计每一类的召回率还是统计整个召回率,默认值是false
    add(output,target) output是N*K的概率矩阵,行和为1;target是NK的二值矩阵,不一定行和为1,如{0,1,0,1}
    value()返回的是table值,对应的是threshold table中指定阈值下的召回率,如果perclass = true,那么table的每个元素就是一个table

  • tnt.PrecisionMeter(self[,threshold][,perclass])
    参考RecallMeter,这里计算的是正确率

  • tnt.NDCGMeter(self[,K])
    计算normalized discounted cumulative gain,没使用过。。。。

tnt.Log

Log是一个由sting key索引的table,这些keys必须在构造函数中指定,有一个特殊的键 __status__可以在log:status()函数中设置用于记录一些基本的messages

Log中提供的一些closures以及对应attached events
* onSet(log,key,value) 对应着给键赋值 log:set{}
* onGet(log,key) 对应着读取key对应的值 log:get()
* onFlush(log) 对应着清空log log:flush()
* onClose(log) 对应log:close() 关闭log

示例:

  1. tnt = require'torchnet'


  2. logtext = require 'torchnet.log.view.text' 

  3. logstatus = require 'torchnet.log.view.status' 

  4. log = tnt.log{ 

  5. keys = {'loss','accuracy'} 

  6. onFlush = { 

  7. -- write out all keys in "log" file 

  8. logtext{filename='log.txt', keys={"loss", "accuracy"}, format={"%10.5f", "%3.2f"}}, 

  9. -- write out loss in a standalone file 

  10. logtext{filename='loss.txt', keys={"loss"}}, 

  11. -- print on screen too 

  12. logtext{keys={"loss", "accuracy"}}, 

  13. }, 

  14. onSet = { 

  15. -- add status to log 

  16. logstatus{filename='log.txt'}, 

  17. -- print status to screen 

  18. logstatus{}, 






  19. -- set values 

  20. log:set{ 

  21. loss = 0.1, 

  22. accuracy = 97 




  23. -- write some info 

  24. log:status("hello world") 


  25. -- flush out log 

  26. log:flush() 



后面我们来看一个具体的例子,以VGG16为例实现一个Siamese CNN网络计算patch之间的相似度


转载于:https://www.cnblogs.com/YiXiaoZhou/p/6774806.html

torchnet package (2)相关推荐

  1. go build 编译报错 missing go.sum entry for module providing package

    go build 编译报错 missing go.sum entry for module providing package 解决方法 // 移除未使用的依赖 go mod tidy 再次编译,就可 ...

  2. cannot find package “github.com/json-iterator/go“cannot find package “github.com/modern-go/reflect2“

    1. 问题现象 ../github.com/coreos/etcd/client/json.go:18:2: cannot find package "github.com/json-ite ...

  3. 关于python导入模块和package的一些深度思考

    背景 在python中有导入模块和导入package一说,这篇文章主要介绍导入模块和package的一些思考. 首先什么是模块?什么是package? 模块:用来从逻辑上组织python代码(变量,函 ...

  4. 关于Activity class {package/class} does not exist

    当你写好程序,或者修改程序运行的时候出Activity class {package/class,这里面是你的路劲} does not exist 这个问题就是你的R文件冲突了,也就是几个不同的工程使 ...

  5. error: Error: No resource found for attribute ‘layout_scrollFlags’ in package‘包名’

    遇到error: Error: No resource found for attribute 'layout_scrollFlags' in package'包名' 这个问题时候刚开始自己也是感觉到 ...

  6. cannot find package “github.com/coreos/go-systemd/journal”

    1. 问题现象 使用 golang etcd 导入包 github.com/coreos/etcd/clientv3 库时有如下错误: ../github.com/coreos/etcd/pkg/lo ...

  7. PyCharm中Directory与Python package的区别

    对于Python而言,有一点是要认识明确的,python作为一个相对而言轻量级的,易用的脚本语言(当然其功能并不仅限于此,在此只是讨论该特点),随着程序的增长,可能想要把它分成几个文件,以便逻辑更加清 ...

  8. Sublime Text 3 及Package Control 安装(附上一个3103可用的Key)

    一.Sublime Text 3 下载. 官方下载地址:http://www.sublimetext.com/ 二.Sublime Text 3  安装. 打开安装包,进行傻瓜式安装. 三.注册. 点 ...

  9. Udacity机器人软件工程师课程笔记(十)-ROS-Catkin-包(package)和gazebo

    包和gazebo仿真 1.添加包 (1)克隆simple_arm包 克隆现有的包并将其添加到我们新创建的工作区. 首先导航到src目录,然后从其github仓库克隆本课程 simple_arm 的包. ...

最新文章

  1. 3年工作必备 装饰器模式
  2. C# WinForm编程之System.Windows.Forms.DataGridViewRow.DataBoundItem Property
  3. tuple object is not callable解决方案
  4. 论文笔记:Git Loss
  5. mqtt客户端_初次接触MQTT
  6. java oscache 使用_OScache的使用(Java对象)
  7. string.h包含哪些函数_Excel进行数据分析常用方法及函数汇总—【杏花开生物医药统计】...
  8. Python开发基础-day1
  9. POJ2942 Knights of the Round Table 点双连通分量,逆图,奇圈
  10. 【BZOJ3050】Seating,线段树
  11. 在unity向量空间内绘制几何(4): 利用平面几何知识画像素直线
  12. 我们应该改变Linux的二十四件事
  13. 无人机图像的目标检测的学习
  14. python的迭代器_python迭代器详解
  15. PyTorch绘制训练过程的accuracy和loss曲线
  16. oa系统怎么安装服务器配置,OA系统安装配置及维护手册-金蝶在线服务中心.DOC
  17. 一文图解自定义修改el-table样式
  18. 标准粒子群算法(PSO)
  19. dota2显示连接不上服务器没有响应,Win10登录不上dota2提示“无法与任何服务器建立连接”怎么办?...
  20. phpnow运行本地php文件,使用PHPnow搭建本地wordpress

热门文章

  1. 用JavaScript玩转计算机图形学(二)基本光源
  2. CUDA系列学习(一)An Introduction to GPU and CUDA
  3. 程序员面试题精选100题(19)-反转链表[数据结构]
  4. 从CVPR 2013看计算机视觉的研究领域和趋势 [CVPR 2013] Three Trending Computer Vision Research Areas
  5. 编程之美-字符串移位问题方法整理
  6. 【OpenCV3】cv::divide()使用详解
  7. (转)OpenNLP进行中文命名实体识别(下:载入模型识别实体)
  8. android studio修改包名
  9. 解决ARC下performselector-may-cause-a-leak-because-its-selector-is-unknown 警告
  10. LVS 配置Iptables防火墙及故障解决