还没看懂,晚上接着看

https://blog.csdn.net/fuwenyan/article/details/79902002

最近两天刚开始用mxnet,训练时发现log只有accuracy,没有loss,训练半天到跑验证的时候才发现loss为NAN了。

这样不能随时看到loss的变化而及时做出调整,比较浪费时间精力。

在python mxnet安装路径下有相关接口和文件。

我用的Anaconda2,路径为Anaconda2\Lib\site-packages\mxnet\metric.py

这是一个在线评价模块Online evaluation metric module。里面给出了一些精度、loss等的评价示例类,并且给出了详细直观的示例。这些评价类的基类都是class EvalMetric(object)。

如分类精度评价:

class Accuracy(EvalMetric):"""Computes accuracy classification score.The accuracy score is defined as.. math::\\text{accuracy}(y, \\hat{y}) = \\frac{1}{n} \\sum_{i=0}^{n-1}\\text{1}(\\hat{y_i} == y_i)Parameters----------axis : int, default=1The axis that represents classesname : strName of this metric instance for display.output_names : list of str, or NoneName of predictions that should be used when updating with update_dict.By default include all predictions.label_names : list of str, or NoneName of labels that should be used when updating with update_dict.By default include all labels.Examples-------->>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]>>> labels   = [mx.nd.array([0, 1, 1])]>>> acc = mx.metric.Accuracy()>>> acc.update(preds = predicts, labels = labels)>>> print acc.get()('accuracy', 0.6666666666666666)"""def __init__(self, axis=1, name='accuracy',output_names=None, label_names=None):super(Accuracy, self).__init__(name, axis=axis,output_names=output_names, label_names=label_names)self.axis = axisdef update(self, labels, preds):"""Updates the internal evaluation result.Parameters----------labels : list of `NDArray`The labels of the data.preds : list of `NDArray`Predicted values."""check_label_shapes(labels, preds)for label, pred_label in zip(labels, preds):if pred_label.shape != label.shape:pred_label = ndarray.argmax(pred_label, axis=self.axis)pred_label = pred_label.asnumpy().astype('int32')label = label.asnumpy().astype('int32')check_label_shapes(label, pred_label)self.sum_metric += (pred_label.flat == label.flat).sum()self.num_inst += len(pred_label.flat)

上述代码中,label是标注分类标签,preds是预测标签

如cross entropy loss:

@register
@alias('ce')
class CrossEntropy(EvalMetric):"""Computes Cross Entropy loss.The cross entropy over a batch of sample size :math:`N` is given by.. math::-\\sum_{n=1}^{N}\\sum_{k=1}^{K}t_{nk}\\log (y_{nk}),where :math:`t_{nk}=1` if and only if sample :math:`n` belongs to class :math:`k`.:math:`y_{nk}` denotes the probability of sample :math:`n` belonging toclass :math:`k`.Parameters----------eps : floatCross Entropy loss is undefined for predicted value is 0 or 1,so predicted values are added with the small constant.name : strName of this metric instance for display.output_names : list of str, or NoneName of predictions that should be used when updating with update_dict.By default include all predictions.label_names : list of str, or NoneName of labels that should be used when updating with update_dict.By default include all labels.Examples-------->>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]>>> labels   = [mx.nd.array([0, 1, 1])]>>> ce = mx.metric.CrossEntropy()>>> ce.update(labels, predicts)>>> print ce.get()('cross-entropy', 0.57159948348999023)"""def __init__(self, eps=1e-12, name='cross-entropy',output_names=None, label_names=None):super(CrossEntropy, self).__init__(name, eps=eps,output_names=output_names, label_names=label_names)self.eps = epsdef update(self, labels, preds):"""Updates the internal evaluation result.Parameters----------labels : list of `NDArray`The labels of the data.preds : list of `NDArray`Predicted values."""check_label_shapes(labels, preds)for label, pred in zip(labels, preds):label = label.asnumpy()pred = pred.asnumpy()label = label.ravel()assert label.shape[0] == pred.shape[0]prob = pred[numpy.arange(label.shape[0]), numpy.int64(label)]self.sum_metric += (-numpy.log(prob + self.eps)).sum()self.num_inst += label.shape[0]

你还可以根据自己定义的评价标准按照上面的示例去写自己的评价接口。

使用的时候

    import mxnet as mx#-----------...-----------------eval_metrics = mx.metric.CompositeEvalMetric()metric1 = mx.metric.Accuracy()metric2 = mx.metric.CrossEntropy()metric3 = mx.metric.MSE()for child_metric in [metric1, metric2, metric3]:eval_metrics.add(child_metric)
然后再mod.fit中使用eval_metrics,如下:mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,batch_end_callback=batch_end_callback,optimizer='sgd', optimizer_params=optimizer_params,arg_params=args, aux_params=auxs, begin_epoch=begin_epoch, num_epoch=end_epoch)

insightface中使用:

class AccMetric(mx.metric.EvalMetric):def __init__(self):self.axis = 1super(AccMetric, self).__init__('acc', axis=self.axis,output_names=None, label_names=None)self.losses = []self.count = 0def update(self, labels, preds):self.count+=1if args.loss_type>=2 and args.loss_type<=4 and args.margin_verbose>0:if self.count%args.ctx_num==0:mbatch = self.count//args.ctx_numif mbatch==1 or mbatch%args.margin_verbose==0:a = 0.0b = 0.0if len(preds)>=4:a = preds[-2].asnumpy()[0]b = preds[-1].asnumpy()[0]elif len(preds)==3:a = preds[-1].asnumpy()[0]b = aprint('[%d][MARGIN]%f,%f'%(mbatch,a,b))#loss = preds[2].asnumpy()[0]#if len(self.losses)==20:#  print('ce loss', sum(self.losses)/len(self.losses))#  self.losses = []#self.losses.append(loss)preds = [preds[1]] #use softmax outputfor label, pred_label in zip(labels, preds):#print(pred_label)#print(label.shape, pred_label.shape)if pred_label.shape != label.shape:pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)pred_label = pred_label.asnumpy().astype('int32').flatten()label = label.asnumpy().astype('int32').flatten()#print(label)#print('label',label)#print('pred_label', pred_label)assert label.shape==pred_label.shapeself.sum_metric += (pred_label.flat == label.flat).sum()self.num_inst += len(pred_label.flat)
class LossValueMetric(mx.metric.EvalMetric):def __init__(self):self.axis = 1super(LossValueMetric, self).__init__('lossvalue', axis=self.axis,output_names=None, label_names=None)self.losses = []def update(self, labels, preds):loss = preds[-1].asnumpy()[0]self.sum_metric += lossself.num_inst += 1.0gt_label = preds[-2].asnumpy()#print(gt_label)

在上面LossValueMetric中,

声明输出标签 name:lossvalue

输出值:self.sum_metric

同时输出acc 和loss:

    if args.loss_type<10:_metric = AccMetric()else:_metric = LossValueMetric()
eval_metrics = [mx.metric.create(_metric),mx.metric.create(LossValueMetric())]

更多详细代码解释可以参考

https://blog.csdn.net/u014380165/article/details/78311231

修改后可以定制输出loss,如下

————————————————
版权声明:本文为CSDN博主「非文艺小燕儿_Vivien」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/fuwenyan/article/details/79902002

mxnet输出训练loss相关推荐

  1. 神经网络训练loss不下降原因集合

    train loss与test loss结果分析 train loss 不断下降,test loss不断下降,说明网络仍在学习; train loss 不断下降,test loss趋于不变,说明网络过 ...

  2. 机器学习中的训练与损失 test and loss (训练loss不下降原因集合)

    train loss 不断下降,test loss不断下降,说明网络仍在学习; train loss 不断下降,test loss趋于不变,说明网络过拟合; train loss 趋于不变,test ...

  3. 在服务器上远程使用tensorboard查看训练loss和准确率

    本人使用的是vscode 很简单 from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter('./logs')w ...

  4. mxnet自定义训练日志

    batch训练回调函数: def _batch_callback(param):#global global_stepglobal_step[0]+=1mbatch = global_step[0]f ...

  5. pytorch-分类任务训练loss不变

    分类任务遇到loss不变情况,使用交叉熵损失 先说结论,原因是pytorch的交叉熵里自带softmax,如果在网络输出上手动加softmax就会出错. 现象 在分类数为3663的分类任务中出现los ...

  6. MXNet的训练基础脚本:base_module.py

    写在前面:在MXNet中有一个很重要的脚本:base_module.py,这个脚本中的BaseModule类定义了和模型实现相关的框架.另外还有一个脚本module.py会在另外一篇博客中讲,这个类继 ...

  7. 验证和训练loss和acc多种情况分析

    首先要明确一点刚开始的时候loss一定是个不为0且合理的值,接近为0那么就意味着你这个网络已经梯度消失,可以换网络或者是改网络了,要么就是数据集有问题,要么就是loss函数有问题. 1,假设你的数据集 ...

  8. 训练loss不下降原因集合

    11年it研发经验,从一个会计转行为算法工程师,学过C#,c++,java,android,php,go,js,python,CNN神经网络,四千多篇博文,三千多篇原创,只为与你分享,共同成长,一起进 ...

  9. 【模型训练-loss】模型训练过程中train, test loss的关系及原因

    网上查找了一些资料,避免忘记了,做个笔记供以后参考 train loss 不断下降,test loss不断下降,说明网络仍在学习; train loss 不断下降,test loss趋于不变,说明网络 ...

最新文章

  1. Tensorflow—Droupout
  2. Python一行代码实现快速排序
  3. YOLOv3在OpenCV4.0.0/OpenCV3.4.2上的C++ demo实现
  4. rabbitmq 取消消息_SpringBoot整合RabbitMQ实现延迟消息
  5. Noise噪音halcon算子,持续更新
  6. JVM(六)为什么新生代有两个Survivor分区? 1
  7. Java使用自定义包
  8. 【2016年第4期】研究(国家自然科学基金项目成果集萃)导读
  9. MySQL 常用命令大全
  10. 6.4信号灯(Semaphores)
  11. 【中秋福利】Linux系统从入门到精通推荐的书籍——中秋限时送书活动
  12. Fisher discrimination criterion (费舍尔判别准则)
  13. 推荐系统(四)——因果效应uplift model系列模型S-Learner,T-Learner,X-Learner
  14. win10安装apache环境
  15. Wifi热点java_用笔记本电脑开启热点Java小程序
  16. matlab bsxfun memory,[转]matlab函数 bsxfunarrayfun
  17. 如何查询计算机com口使用
  18. 台式电脑主板插线步骤图_电脑主板跳线插法 装机接线详细图解教程
  19. linux下go语言入门,Go语言入门之旅(二):环境搭建-Linux篇
  20. 注意力机制QKV理解

热门文章

  1. shell获取时间戳
  2. 用google代替CSDN的博客搜索功能
  3. 内核模式下的注册表操作
  4. WinDbg 脚本实例,可以显示 SSDT
  5. Android 通过http协议数据交互
  6. Linux概念架构的理解
  7. 计算机网络本地连接被禁用说明什么,win10系统网络被禁用重新启用本地连接的设置技巧...
  8. 深信服5月26日笔试
  9. 扩展 日历_2021少林日历 | 以最美的方式记录时光
  10. java控制语句练习题_[Java初探实例篇02]__流程控制语句知识相关的实例练习