11 years of IT development experience; an accountant turned algorithm engineer. I have studied C#, C++, Java, Android, PHP, Go, JS, Python, and CNNs, and written more than 4,000 blog posts (3,000+ original), all to share and grow together with you. Follow me for more practical knowledge!

https://github.com/bleakie/MaskInsightface/blob/master/src/train_softmax.py

The script below is close to that file, though not exactly identical:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport os
import sys
import math
import random
import logging
import pickle
import numpy as npimport mfn_55
from common import verification, spherenet
from common.noise_sgd import NoiseSGD
from image_iter import FaceImageIter,FaceImageIterList
import mxnet as mx
from mxnet import ndarray as nd
import argparse
import mxnet.optimizer as optimizerfrom mx_2d_jz import get_networksys.path.append(os.path.join(os.path.dirname(__file__), 'common'))
import face_image
# from noise_sgd import NoiseSGD
sys.path.append(os.path.join(os.path.dirname(__file__), 'eval'))
sys.path.append(os.path.join(os.path.dirname(__file__), 'symbols'))
# import fresnetimport sklearn
sys.path.append(os.path.join(os.path.dirname(__file__), 'losses'))logger = logging.getLogger()
logger.setLevel(logging.INFO)class AccMetric(mx.metric.EvalMetric):def __init__(self):self.axis = 1super(AccMetric, self).__init__('acc', axis=self.axis,output_names=None, label_names=None)self.losses = []self.count = 0def update(self, labels, preds):self.count+=1if args.loss_type>=2 and args.loss_type<=7 and args.margin_verbose>0:if self.count%args.ctx_num==0:mbatch = self.count//args.ctx_num_verbose = args.margin_verboseif mbatch==1 or mbatch%_verbose==0:a = 0.0b = 0.0if len(preds)>=4:a = preds[-2].asnumpy()[0]b = preds[-1].asnumpy()[0]elif len(preds)==3:a = preds[-1].asnumpy()[0]b = aprint('[%d][MARGIN]%f,%f'%(mbatch,a,b))if args.logits_verbose>0:if self.count%args.ctx_num==0:mbatch = self.count//args.ctx_num_verbose = args.logits_verboseif mbatch==1 or mbatch%_verbose==0:a = 0.0b = 0.0if len(preds)>=3:v = preds[-1].asnumpy()v = np.sort(v)num = len(v)//10a = np.mean(v[0:num])b = np.mean(v[-1*num:])print('[LOGITS] %d,%f,%f'%(mbatch,a,b))#loss = preds[2].asnumpy()[0]#if len(self.losses)==20:#  print('ce loss', sum(self.losses)/len(self.losses))#  self.losses = []#self.losses.append(loss)preds = [preds[1]] #use softmax outputfor label, pred_label in zip(labels, preds):#print(pred_label)#print(label.shape, pred_label.shape)if pred_label.shape != label.shape:pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)pred_label = pred_label.asnumpy().astype('int32').flatten()label = label.asnumpy()if label.ndim==2:label = label[:,0]label = label.astype('int32').flatten()#print(label)#print('label',label)#print('pred_label', pred_label)assert label.shape==pred_label.shapeself.sum_metric += (pred_label.flat == label.flat).sum()self.num_inst += len(pred_label.flat)class LossValueMetric(mx.metric.EvalMetric):def __init__(self):self.axis = 1super(LossValueMetric, self).__init__('lossvalue', axis=self.axis,output_names=None, label_names=None)self.losses = []def update(self, labels, preds):loss = preds[-1].asnumpy()[0]self.sum_metric += lossself.num_inst += 1.0gt_label = 
preds[-2].asnumpy()#print(gt_label)def parse_args():parser = argparse.ArgumentParser(description='Train face network')# generalparser.add_argument('--data_dir', default='/data3/data/lbg/ms1m-retinaface-t1', help='training set directory')parser.add_argument('--prefix', default='../model/model', help='directory to save model.')# parser.add_argument('--pretrained', default='', help='pretrained model to load')parser.add_argument('--pretrained', default='../model/model,31', help='pretrained model to load')parser.add_argument('--ckpt', type=int, default=1, help='checkpoint saving option. 0: discard saving. 1: save when necessary. 2: always save')parser.add_argument('--network', default='mfnv2', help='specify network')parser.add_argument('--version-se', type=int, default=0, help='whether to use se in network')parser.add_argument('--version-input', type=int, default=1, help='network input config')parser.add_argument('--version-output', type=str, default='E', help='network embedding output config')parser.add_argument('--version-unit', type=int, default=3, help='resnet unit config')parser.add_argument('--version-act', type=str, default='prelu', help='network activation config')parser.add_argument('--end-epoch', type=int, default=100000, help='training epoch size.')parser.add_argument('--noise-sgd', type=float, default=0.0, help='')parser.add_argument('--lr', type=float, default=0.001, help='start learning rate')parser.add_argument('--wd', type=float, default=0.0005, help='weight decay')parser.add_argument('--mom', type=float, default=0.9, help='momentum')parser.add_argument('--emb-size', type=int, default=512, help='embedding length')parser.add_argument('--per-batch_size', type=int, default=240, help='batch size in each context')parser.add_argument('--margin-m', type=float, default=0.5, help='')parser.add_argument('--margin-s', type=float, default=64.0, help='')parser.add_argument('--margin-a', type=float, default=0.0, help='')parser.add_argument('--margin-b', type=float, 
default=0.0, help='')parser.add_argument('--easy-margin', type=int, default=0, help='')parser.add_argument('--margin-verbose', type=int, default=0, help='')parser.add_argument('--logits-verbose', type=int, default=0, help='')parser.add_argument('--c2c-threshold', type=float, default=0.0, help='')parser.add_argument('--c2c-mode', type=int, default=-10, help='')parser.add_argument('--output-c2c', type=int, default=0, help='')parser.add_argument('--train-limit', type=int, default=0, help='')parser.add_argument('--margin', type=int, default=4, help='')parser.add_argument('--beta', type=float, default=1000., help='')parser.add_argument('--beta-min', type=float, default=5., help='')parser.add_argument('--beta-freeze', type=int, default=0, help='')parser.add_argument('--gamma', type=float, default=0.12, help='')parser.add_argument('--power', type=float, default=1.0, help='')parser.add_argument('--scale', type=float, default=0.9993, help='')parser.add_argument('--center-alpha', type=float, default=0.5, help='')parser.add_argument('--center-scale', type=float, default=0.003, help='')parser.add_argument('--images-per-identity', type=int, default=0, help='')parser.add_argument('--triplet-bag-size', type=int, default=3600, help='')parser.add_argument('--triplet-alpha', type=float, default=0.3, help='')parser.add_argument('--triplet-max-ap', type=float, default=0.0, help='')parser.add_argument('--verbose', type=int, default=2000, help='')parser.add_argument('--loss_type', type=int, default=4, help='')parser.add_argument('--incay', type=float, default=0.0, help='feature incay')parser.add_argument('--use-deformable', type=int, default=0, help='')parser.add_argument('--rand-mirror', type=int, default=1, help='')parser.add_argument('--cutoff', type=int, default=0, help='')parser.add_argument('--patch', type=str, default='0_0_96_112_0',help='')parser.add_argument('--lr-steps', type=str, default='', help='')parser.add_argument('--max-steps', type=int, default=0, 
help='')parser.add_argument('--target', type=str, default='lfw,agedb_30,cfp_fp', help='')args = parser.parse_args()return argsdef get_symbol(args, arg_params, aux_params):data_shape = (args.image_channel,args.image_h,args.image_w)image_shape = ",".join([str(x) for x in data_shape])margin_symbols = []fixed_param_names=[]if args.network[0]=='m':print('init mobilenet', args.num_layers)if args.num_layers==1:embedding = fmobilenet.get_symbol(args.emb_size,version_se=args.version_se, version_input=args.version_input,version_output=args.version_output, version_unit=args.version_unit)else:# embedding = mfn_55.get_symbol(512)embedding=get_network(512,is_train=True)for node in embedding.get_internals():if "_qf_fixed" in node.name:fixed_param_names.append(node.name)# embedding = fmobilenetv2.get_symbol(args.emb_size)elif args.network[0]=='i':print('init inception-resnet-v2', args.num_layers)embedding = finception_resnet_v2.get_symbol(args.emb_size,version_se=args.version_se, version_input=args.version_input,version_output=args.version_output, version_unit=args.version_unit)elif args.network[0]=='x':print('init xception', args.num_layers)embedding = fxception.get_symbol(args.emb_size,version_se=args.version_se, version_input=args.version_input,version_output=args.version_output, version_unit=args.version_unit)elif args.network[0]=='n':print('init nasnet', args.num_layers)embedding = fnasnet.get_symbol(args.emb_size)elif args.network[0]=='s':print('init spherenet', args.num_layers)embedding = spherenet.get_symbol(args.emb_size, args.num_layers)else:print('init resnet', args.num_layers)embedding = fresnet.get_symbol(args.emb_size, args.num_layers,version_se=args.version_se, version_input=args.version_input,version_output=args.version_output, version_unit=args.version_unit,version_act=args.version_act)all_label = mx.symbol.Variable('softmax_label')if not args.output_c2c:gt_label = all_labelelse:gt_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)gt_label = 
mx.symbol.reshape(gt_label, (args.per_batch_size,))c2c_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)c2c_label = mx.symbol.reshape(c2c_label, (args.per_batch_size,))assert args.loss_type>=0extra_loss = Noneif args.loss_type==0: #softmax_weight = mx.symbol.Variable('fc7_weight')_bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')elif args.loss_type==1: #sphere_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)_weight = mx.symbol.L2Normalization(_weight, mode='instance')fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,weight = _weight,beta=args.beta, margin=args.margin, scale=args.scale,beta_min=args.beta_min, verbose=1000, name='fc7')elif args.loss_type==8: #centerloss, TODO_weight = mx.symbol.Variable('fc7_weight')_bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')print('center-loss', args.center_alpha, args.center_scale)extra_loss = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss', op_type='centerloss', \num_class=args.num_classes, alpha=args.center_alpha, scale=args.center_scale, batchsize=args.per_batch_size)elif args.loss_type==2:s = args.margin_sm = args.margin_m_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)_weight = mx.symbol.L2Normalization(_weight, mode='instance')if s>0.0:nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*sfc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')if m>0.0:if args.margin_verbose>0:zy = mx.sym.pick(fc7, gt_label, axis=1)cos_t = zy/smargin_symbols.append(mx.symbol.mean(cos_t))s_m = s*mgt_one_hot = mx.sym.one_hot(gt_label, 
depth = args.num_classes, on_value = s_m, off_value = 0.0)fc7 = fc7-gt_one_hotif args.margin_verbose>0:new_zy = mx.sym.pick(fc7, gt_label, axis=1)new_cos_t = new_zy/smargin_symbols.append(mx.symbol.mean(new_cos_t))else:fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')if m>0.0:body = embedding*embeddingbody = mx.sym.sum_axis(body, axis=1, keepdims=True)body = mx.sym.sqrt(body)body = body*mgt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)body = mx.sym.broadcast_mul(gt_one_hot, body)fc7 = fc7-bodyelif args.loss_type==3:s = args.margin_sm = args.margin_massert args.margin==2 or args.margin==4_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)_weight = mx.symbol.L2Normalization(_weight, mode='instance')nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*sfc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')zy = mx.sym.pick(fc7, gt_label, axis=1)cos_t = zy/sif args.margin_verbose>0:margin_symbols.append(mx.symbol.mean(cos_t))if m>1.0:t = mx.sym.arccos(cos_t)t = t*mbody = mx.sym.cos(t)new_zy = body*sif args.margin_verbose>0:new_cos_t = new_zy/smargin_symbols.append(mx.symbol.mean(new_cos_t))diff = new_zy - zydiff = mx.sym.expand_dims(diff, 1)gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)body = mx.sym.broadcast_mul(gt_one_hot, diff)fc7 = fc7+bodyelif args.loss_type==4:s = args.margin_sm = args.margin_massert s>0.0assert m>=0.0assert m<(math.pi/2)_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)_weight = mx.symbol.L2Normalization(_weight, mode='instance')nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*sfc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, 
num_hidden=args.num_classes, name='fc7')zy = mx.sym.pick(fc7, gt_label, axis=1)cos_t = zy/sif args.margin_verbose>0:margin_symbols.append(mx.symbol.mean(cos_t))if args.output_c2c==0:cos_m = math.cos(m)sin_m = math.sin(m)mm = math.sin(math.pi-m)*m#threshold = 0.0threshold = math.cos(math.pi-m)if args.easy_margin:cond = mx.symbol.Activation(data=cos_t, act_type='relu')else:cond_v = cos_t - thresholdcond = mx.symbol.Activation(data=cond_v, act_type='relu')body = cos_t*cos_tbody = 1.0-bodysin_t = mx.sym.sqrt(body)new_zy = cos_t*cos_mb = sin_t*sin_mnew_zy = new_zy - bnew_zy = new_zy*sif args.easy_margin:zy_keep = zyelse:zy_keep = zy - s*mmnew_zy = mx.sym.where(cond, new_zy, zy_keep)else:#set c2c as cosm^2 in data.pycos_m = mx.sym.sqrt(c2c_label)sin_m = 1.0-c2c_labelsin_m = mx.sym.sqrt(sin_m)body = cos_t*cos_tbody = 1.0-bodysin_t = mx.sym.sqrt(body)new_zy = cos_t*cos_mb = sin_t*sin_mnew_zy = new_zy - bnew_zy = new_zy*sif args.margin_verbose>0:new_cos_t = new_zy/smargin_symbols.append(mx.symbol.mean(new_cos_t))diff = new_zy - zydiff = mx.sym.expand_dims(diff, 1)gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0, name='one_hot0')body = mx.sym.broadcast_mul(gt_one_hot, diff)fc7 = fc7+bodyelif args.loss_type==5:s = args.margin_sm = args.margin_massert s>0.0#assert m>=0.0#assert m<(math.pi/2)_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)_weight = mx.symbol.L2Normalization(_weight, mode='instance')nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*sfc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')zy = mx.sym.pick(fc7, gt_label, axis=1)cos_t = zy/st = mx.sym.arccos(cos_t)if args.margin_verbose>0:margin_symbols.append(mx.symbol.mean(t))if args.margin_a>0.0:t = t*args.margin_aif args.margin_m>0.0:t = t+args.margin_mbody = mx.sym.cos(t)if args.margin_b>0.0:body = body - args.margin_bnew_zy 
= body*sif args.margin_verbose>0:margin_symbols.append(mx.symbol.mean(t))diff = new_zy - zydiff = mx.sym.expand_dims(diff, 1)gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)body = mx.sym.broadcast_mul(gt_one_hot, diff)fc7 = fc7+bodyelif args.loss_type==6:s = args.margin_sm = args.margin_massert s>0.0assert m>=0.0assert m<(math.pi/2)_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)_weight = mx.symbol.L2Normalization(_weight, mode='instance')nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*sfc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')zy = mx.sym.pick(fc7, gt_label, axis=1)cos_t = zy/st = mx.sym.arccos(cos_t)if args.margin_verbose>0:margin_symbols.append(mx.symbol.mean(t))t_min = mx.sym.min(t)ta = mx.sym.broadcast_div(t_min, t)a1 = args.margin_ar1 = ta-a1r1 = mx.symbol.Activation(data=r1, act_type='relu')r1 = r1+a1r2 = mx.symbol.zeros(shape=(args.per_batch_size,))cond = t-1.0cond = mx.symbol.Activation(data=cond, act_type='relu')r = mx.sym.where(cond, r2, r1)var_m = r*mt = t+var_mbody = mx.sym.cos(t)new_zy = body*sif args.margin_verbose>0:#new_cos_t = new_zy/s#margin_symbols.append(mx.symbol.mean(new_cos_t))margin_symbols.append(mx.symbol.mean(t))diff = new_zy - zydiff = mx.sym.expand_dims(diff, 1)gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)body = mx.sym.broadcast_mul(gt_one_hot, diff)fc7 = fc7+bodyelif args.loss_type==7:s = args.margin_sm = args.margin_massert s>0.0assert m>=0.0assert m<(math.pi/2)_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)_weight = mx.symbol.L2Normalization(_weight, mode='instance')nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*sfc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, 
num_hidden=args.num_classes, name='fc7')zy = mx.sym.pick(fc7, gt_label, axis=1)cos_t = zy/st = mx.sym.arccos(cos_t)if args.margin_verbose>0:margin_symbols.append(mx.symbol.mean(t))var_m = mx.sym.random.uniform(low=args.margin_a, high=args.margin_m, shape=(1,))t = mx.sym.broadcast_add(t,var_m)body = mx.sym.cos(t)new_zy = body*sif args.margin_verbose>0:#new_cos_t = new_zy/s#margin_symbols.append(mx.symbol.mean(new_cos_t))margin_symbols.append(mx.symbol.mean(t))diff = new_zy - zydiff = mx.sym.expand_dims(diff, 1)gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)body = mx.sym.broadcast_mul(gt_one_hot, diff)fc7 = fc7+bodyelif args.loss_type==10: #marginal lossnembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')params = [1.2, 0.3, 1.0]n1 = mx.sym.expand_dims(nembedding, axis=1) #N,1,Cn2 = mx.sym.expand_dims(nembedding, axis=0) #1,N,Cbody = mx.sym.broadcast_sub(n1, n2) #N,N,Cbody = body * bodybody = mx.sym.sum(body, axis=2) # N,N#body = mx.sym.sqrt(body)body = body - params[0]mask = mx.sym.Variable('extra')body = body*maskbody = body+params[1]#body = mx.sym.maximum(body, 0.0)body = mx.symbol.Activation(data=body, act_type='relu')body = mx.sym.sum(body)body = body/(args.per_batch_size*args.per_batch_size-args.per_batch_size)extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2])elif args.loss_type==11: #npair lossparams = [0.9, 0.2]nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')nembedding = mx.sym.transpose(nembedding)nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.per_identities, args.images_per_identity))nembedding = mx.sym.transpose(nembedding, axes=(2,1,0)) #2*id*512#nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.images_per_identity, args.per_identities))#nembedding = mx.sym.transpose(nembedding, axes=(1,2,0)) #2*id*512n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=1)n2 = mx.symbol.slice_axis(nembedding, axis=0, 
begin=1, end=2)n1 = mx.symbol.reshape(n1, (args.per_identities, args.emb_size))n2 = mx.symbol.reshape(n2, (args.per_identities, args.emb_size))cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2, transpose_b = True) #id*id, id=N of N-pairdata_extra = mx.sym.Variable('extra')data_extra = mx.sym.slice_axis(data_extra, axis=0, begin=0, end=args.per_identities)mask = cosine_matrix * data_extra#body = mx.sym.mean(mask)fii = mx.sym.sum_axis(mask, axis=1)fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii)fij_fii = mx.sym.exp(fij_fii)row = mx.sym.sum_axis(fij_fii, axis=1)row = mx.sym.log(row)body = mx.sym.mean(row)extra_loss = mx.sym.MakeLoss(body)elif args.loss_type==12: #triplet lossnembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)ap = anchor - positivean = anchor - negativeap = ap*apan = an*anap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)triplet_loss = mx.symbol.Activation(data = (ap-an+args.triplet_alpha), act_type='relu')triplet_loss = mx.symbol.mean(triplet_loss)#triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)extra_loss = mx.symbol.MakeLoss(triplet_loss)elif args.loss_type==13: #triplet loss with angular marginm = args.margin_msin_m = math.sin(m)cos_m = math.cos(m)nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)ap = anchor * positivean = anchor * 
negativeap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)ap = mx.symbol.arccos(ap)an = mx.symbol.arccos(an)triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu')#body = ap*ap#body = 1.0-body#body = mx.symbol.sqrt(body)#body = body*sin_m#ap = ap*cos_m#ap = ap-body#triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu')triplet_loss = mx.symbol.mean(triplet_loss)extra_loss = mx.symbol.MakeLoss(triplet_loss)elif args.loss_type==9: #coco losscentroids = []for i in range(args.per_identities):xs = mx.symbol.slice_axis(embedding, axis=0, begin=i*args.images_per_identity, end=(i+1)*args.images_per_identity)mean = mx.symbol.mean(xs, axis=0, keepdims=True)mean = mx.symbol.L2Normalization(mean, mode='instance')centroids.append(mean)centroids = mx.symbol.concat(*centroids, dim=0)nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*args.coco_scalefc7 = mx.symbol.dot(nembedding, centroids, transpose_b = True) #(batchsize, per_identities)#extra_loss = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size#extra_loss = mx.symbol.BlockGrad(extra_loss)else:#embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type)embedding = embedding * 5_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)_weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,weight = _weight,beta=args.beta, margin=args.margin, scale=args.scale,beta_min=args.beta_min, verbose=100, name='fc7')#fc7 = mx.sym.Custom(data=embedding, label=gt_label, weight=_weight, num_hidden=args.num_classes,#                       beta=args.beta, margin=args.margin, scale=args.scale,#                       op_type='ASoftmax', name='fc7')if args.loss_type<=1 and args.incay>0.0:params = [1.e-10]sel = 
mx.symbol.argmax(data = fc7, axis=1)sel = (sel==gt_label)norm = embedding*embeddingnorm = mx.symbol.sum(norm, axis=1)norm = norm+params[0]feature_incay = sel/normfeature_incay = mx.symbol.mean(feature_incay) * args.incayextra_loss = mx.symbol.MakeLoss(feature_incay)#out = softmax#l2_embedding = mx.symbol.L2Normalization(embedding)#ce = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size#out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, mx.symbol.BlockGrad(ce)])out_list = [mx.symbol.BlockGrad(embedding)]softmax = Noneif args.loss_type<10:softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid')out_list.append(softmax)if args.logits_verbose>0:logits = mx.symbol.softmax(data = fc7)logits = mx.sym.pick(logits, gt_label, axis=1)margin_symbols.append(logits)#logit_max = mx.sym.max(logits)#logit_min = mx.sym.min(logits)#margin_symbols.append(logit_max)#margin_symbols.append(logit_min)if softmax is None:out_list.append(mx.sym.BlockGrad(gt_label))if extra_loss is not None:out_list.append(extra_loss)for _sym in margin_symbols:_sym = mx.sym.BlockGrad(_sym)out_list.append(_sym)out = mx.symbol.Group(out_list)return (out, arg_params, aux_params,fixed_param_names)def train_net():ctx = []cvd ="0,1"# os.environ['CUDA_VISIBLE_DEVICES'].strip()if len(cvd)>0:for i in cvd.split(','):ctx.append(mx.gpu(int(i)))if len(ctx)==0:ctx = [mx.cpu()]print('use cpu')else:print('gpu num:', len(ctx))prefix = args.prefixprefix_dir = os.path.dirname(prefix)if not os.path.exists(prefix_dir):os.makedirs(prefix_dir)end_epoch = args.end_epochargs.ctx_num = len(ctx)args.num_layers =26# int(args.network[1:])print('num_layers', args.num_layers)if args.per_batch_size==0:args.per_batch_size = 128if args.loss_type==10:args.per_batch_size = 256args.batch_size = args.per_batch_size*args.ctx_numargs.rescale_threshold = 0args.image_channel = 3ppatch = [int(x) for x in args.patch.split('_')]assert 
len(ppatch)==5os.environ['BETA'] = str(args.beta)data_dir_list = args.data_dir.split(',')if args.loss_type!=12 and args.loss_type!=13:assert len(data_dir_list)==1data_dir = data_dir_list[0]args.use_val = Falsepath_imgrec = Nonepath_imglist = Noneval_rec = Noneprop = face_image.load_property(data_dir)args.num_classes = prop.num_classesimage_size = prop.image_sizeargs.image_h = image_size[0]args.image_w = image_size[1]print('image_size', image_size)assert(args.num_classes>0)print('num_classes', args.num_classes)args.coco_scale = 0.5*math.log(float(args.num_classes-1))+3#path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"path_imgrec = os.path.join(data_dir, "train.rec")val_rec = os.path.join(data_dir, "val.rec")if os.path.exists(val_rec) and args.loss_type<10:args.use_val = Trueelse:val_rec = None#args.use_val = Falseif args.loss_type==1 and args.num_classes>20000:args.beta_freeze = 5000args.gamma = 0.06if args.loss_type<9:assert args.images_per_identity==0else:if args.images_per_identity==0:if args.loss_type==11:args.images_per_identity = 2elif args.loss_type==10 or args.loss_type==9:args.images_per_identity = 16elif args.loss_type==12 or args.loss_type==13:args.images_per_identity = 5assert args.per_batch_size%3==0assert args.images_per_identity>=2args.per_identities = int(args.per_batch_size/args.images_per_identity)print('Called with argument:', args)data_shape = (args.image_channel,image_size[0],image_size[1])mean = Nonebegin_epoch = 0base_lr = args.lrbase_wd = args.wdbase_mom = args.momif len(args.pretrained)==0:arg_params = Noneaux_params = Nonesym, arg_params, aux_params,fixed_param_names = get_symbol(args, arg_params, aux_params)else:vec = args.pretrained.split(',')print('loading', vec)_, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))# _, arg_params, aux_params = mx.model.load_checkpoint(r"/data/sharedata/lbg/reid/MobileFaceNet/eval/imdb_feature/mv2/model", 0)sym, arg_params, aux_params,fixed_param_names = get_symbol(args, 
arg_params, aux_params)if args.network[0]=='s':data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}spherenet.init_weights(sym, data_shape_dict, args.num_layers)data_extra = Nonehard_mining = Falsetriplet_params = Nonecoco_mode = Falseif args.loss_type==10:hard_mining = True_shape = (args.batch_size, args.per_batch_size)data_extra = np.full(_shape, -1.0, dtype=np.float32)c = 0while c<args.batch_size:a = 0while a<args.per_batch_size:b = a+args.images_per_identitydata_extra[(c+a):(c+b),a:b] = 1.0#print(c+a, c+b, a, b)a = bc += args.per_batch_sizeelif args.loss_type==11:data_extra = np.zeros( (args.batch_size, args.per_identities), dtype=np.float32)c = 0while c<args.batch_size:for i in range(args.per_identities):data_extra[c+i][i] = 1.0c+=args.per_batch_sizeelif args.loss_type==12 or args.loss_type==13:triplet_params = [args.triplet_bag_size, args.triplet_alpha, args.triplet_max_ap]elif args.loss_type==9:coco_mode = Truelabel_name = 'softmax_label'label_shape = (args.batch_size,)if args.output_c2c:label_shape = (args.batch_size,2)if data_extra is None:print("fixed_param_names",len(fixed_param_names))model = mx.mod.Module(context       = ctx,symbol        = sym,fixed_param_names=fixed_param_names)else:data_names = ('data', 'extra')#label_name = ''model = mx.mod.Module(context       = ctx,symbol        = sym,data_names    = data_names,label_names   = (label_name,),)if args.use_val:val_dataiter = FaceImageIter(batch_size           = args.batch_size,data_shape           = data_shape,path_imgrec          = val_rec,#path_imglist         = val_path,shuffle              = False,rand_mirror          = False,mean                 = mean,ctx_num              = args.ctx_num,data_extra           = data_extra,)else:val_dataiter = Noneif len(data_dir_list)==1 and args.loss_type!=12 and args.loss_type!=13:train_dataiter = FaceImageIter(batch_size           = args.batch_size,data_shape           = data_shape,path_imgrec          = path_imgrec,shuffle              = 
        shuffle              = True,
        rand_mirror          = args.rand_mirror,
        mean                 = mean,
        cutoff               = args.cutoff,
        c2c_threshold        = args.c2c_threshold,
        output_c2c           = args.output_c2c,
        c2c_mode             = args.c2c_mode,
        limit                = args.train_limit,
        ctx_num              = args.ctx_num,
        images_per_identity  = args.images_per_identity,
        data_extra           = data_extra,
        hard_mining          = hard_mining,
        triplet_params       = triplet_params,
        coco_mode            = coco_mode,
        mx_model             = model,
        label_name           = label_name,
    )
  else:
    iter_list = []
    for _data_dir in data_dir_list:
      _path_imgrec = os.path.join(_data_dir, "train.rec")
      _dataiter = FaceImageIter(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = _path_imgrec,
          shuffle              = True,
          rand_mirror          = args.rand_mirror,
          mean                 = mean,
          cutoff               = args.cutoff,
          c2c_threshold        = args.c2c_threshold,
          output_c2c           = args.output_c2c,
          c2c_mode             = args.c2c_mode,
          limit                = args.train_limit,
          ctx_num              = args.ctx_num,
          images_per_identity  = args.images_per_identity,
          data_extra           = data_extra,
          hard_mining          = hard_mining,
          triplet_params       = triplet_params,
          coco_mode            = coco_mode,
          mx_model             = model,
          label_name           = label_name,
      )
      iter_list.append(_dataiter)
    iter_list.append(_dataiter)
    train_dataiter = FaceImageIterList(iter_list)

  if args.loss_type < 10:
    _metric = AccMetric()
  else:
    _metric = LossValueMetric()
  eval_metrics = [mx.metric.create(_metric), mx.metric.create(LossValueMetric())]

  if args.network[0] == 'r':
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2)  # resnet style
  elif args.network[0] == 'i' or args.network[0] == 'x':
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)  # inception
  else:
    initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
  _rescale = 1.0 / args.ctx_num
  if args.noise_sgd > 0.0:
    print('use noise sgd')
    opt = NoiseSGD(scale=args.noise_sgd, learning_rate=base_lr, momentum=base_mom,
                   wd=base_wd, rescale_grad=_rescale)
  else:
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom,
                        wd=base_wd, rescale_grad=_rescale)
  som = 20
  if args.loss_type == 12 or args.loss_type == 13:
    som = 2
  _cb = mx.callback.Speedometer(args.batch_size, som)

  ver_list = []
  ver_name_list = []
  for name in args.target.split(','):
    path = os.path.join(data_dir, name + ".bin")
    if os.path.exists(path):
      data_set = verification.load_bin(path, image_size)
      ver_list.append(data_set)
      ver_name_list.append(name)
      print('ver', name)

  def ver_test(nbatch):
    results = []
    for i in range(len(ver_list)):
      acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
          ver_list[i], model, args.batch_size, 10, data_extra, label_shape)
      print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
      #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
      print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
      results.append(acc2)
    return results

  def val_test():
    acc = AccMetric()
    val_metric = mx.metric.create(acc)
    val_metric.reset()
    val_dataiter.reset()
    for i, eval_batch in enumerate(val_dataiter):
      model.forward(eval_batch, is_train=False)
      model.update_metric(val_metric, eval_batch.label)
    acc_value = val_metric.get_name_value()[0][1]
    print('VACC: %f' % (acc_value))

  highest_acc = [0.0, 0.0]  # lfw and target
  #for i in xrange(len(ver_list)):
  #  highest_acc.append(0.0)
  global_step = [0]
  save_step = [0]
  if len(args.lr_steps) == 0:
    lr_steps = [40000, 60000, 80000]
    if args.loss_type >= 1 and args.loss_type <= 7:
      lr_steps = [100000, 140000, 160000]
    p = 512.0 / args.batch_size
    for l in range(len(lr_steps)):
      lr_steps[l] = int(lr_steps[l] * p)
  else:
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
  print('lr_steps', lr_steps)

  def _batch_callback(param):
    #global global_step
    global_step[0] += 1
    mbatch = global_step[0]
    for _lr in lr_steps:
      if mbatch == args.beta_freeze + _lr:
        opt.lr *= 0.1
        print('lr change to', opt.lr)
        break
    _cb(param)
    if mbatch % 1000 == 0:
      print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)
    if mbatch >= 0 and mbatch % args.verbose == 0:
      acc_list = ver_test(mbatch)
      save_step[0] += 1
      msave = save_step[0]
      do_save = False
      if len(acc_list) > 0:
        lfw_score = acc_list[0]
        if lfw_score > highest_acc[0]:
          highest_acc[0] = lfw_score
          if lfw_score >= 0.9:
            do_save = True
        if acc_list[-1] >= highest_acc[-1]:
          highest_acc[-1] = acc_list[-1]
          if lfw_score >= 0.9:
            do_save = True
      if args.ckpt == 0:
        do_save = False
      elif args.ckpt > 1:
        do_save = True
      #for i in xrange(len(acc_list)):
      #  acc = acc_list[i]
      #  if acc >= highest_acc[i]:
      #    highest_acc[i] = acc
      #    if lfw_score >= 0.99:
      #      do_save = True
      #if args.loss_type == 1 and mbatch > lr_steps[-1] and mbatch % 10000 == 0:
      #  do_save = True
      if do_save:
        print('saving', msave)
        if val_dataiter is not None:
          val_test()
        arg, aux = model.get_params()
        mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
        #if acc >= highest_acc[0]:
        #  lfw_npy = "%s-lfw-%04d" % (prefix, msave)
        #  X = np.concatenate(embeddings_list, axis=0)
        #  print('saving lfw npy', X.shape)
        #  np.save(lfw_npy, X)
      print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
    if mbatch <= args.beta_freeze:
      _beta = args.beta
    else:
      move = max(0, mbatch - args.beta_freeze)
      _beta = max(args.beta_min, args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
    #print('beta', _beta)
    os.environ['BETA'] = str(_beta)
    if args.max_steps > 0 and mbatch > args.max_steps:
      sys.exit(0)

  #epoch_cb = mx.callback.do_checkpoint(prefix, 1)
  epoch_cb = None
  #def _epoch_callback(epoch, sym, arg_params, aux_params):
  #  print('epoch-end', epoch)
  # print('initializer', initializer)
  print('_batch_callback', _batch_callback)
  # print('aux_params', aux_params)
  model.fit(train_dataiter,
      begin_epoch        = begin_epoch,
      num_epoch          = end_epoch,
      eval_data          = val_dataiter,
      eval_metric        = eval_metrics,
      kvstore            = 'device',
      optimizer          = opt,
      #optimizer_params   = optimizer_params,
      initializer        = initializer,
      arg_params         = arg_params,
      aux_params         = aux_params,
      allow_missing      = True,
      batch_end_callback = _batch_callback,
      epoch_end_callback = epoch_cb)

if __name__ == '__main__':
  args = parse_args()
  train_net()
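Two details in the script are easy to miss. First, when `--lr-steps` is not given, the default decay points are scaled by `512 / batch_size`, so the learning rate drops after roughly the same number of seen images regardless of batch size. A minimal standalone sketch of that logic (the function name is mine, not from the script; the step defaults are the ones visible above):

```python
def scaled_lr_steps(batch_size, loss_type):
    # Defaults from the script: plain softmax uses the shorter schedule,
    # margin-based losses (loss_type 1..7) the longer one.
    steps = [40000, 60000, 80000]
    if 1 <= loss_type <= 7:
        steps = [100000, 140000, 160000]
    # Scale so the decay happens after the same number of seen images
    # as with the reference batch size of 512.
    p = 512.0 / batch_size
    return [int(s * p) for s in steps]

print(scaled_lr_steps(512, 4))  # [100000, 140000, 160000]
print(scaled_lr_steps(256, 0))  # [80000, 120000, 160000]
```

So halving the batch size doubles the step counts, keeping the schedule aligned in terms of data consumed.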
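Second, `_batch_callback` anneals the SphereFace-style margin weight `beta` with an inverse-polynomial decay and hands it to the loss layer via the `BETA` environment variable. A standalone sketch of that schedule (the default values for `beta`, `beta_min`, `gamma`, and `power` are assumptions based on common insightface defaults, not necessarily this fork's):

```python
import math

def beta_schedule(mbatch, beta_freeze=0, beta=1000.0, beta_min=5.0,
                  gamma=0.12, power=1.0):
    """Inverse-polynomial decay of the margin-loss mixing weight,
    mirroring the _batch_callback logic in the script above."""
    if mbatch <= beta_freeze:
        return beta  # held fixed while frozen
    move = max(0, mbatch - beta_freeze)
    return max(beta_min, beta * math.pow(1 + gamma * move, -1.0 * power))

print(beta_schedule(0))      # 1000.0 (still frozen)
print(beta_schedule(100))    # decayed toward beta_min
print(beta_schedule(10**7))  # 5.0 (clamped at beta_min)
```

A large `beta` makes the loss behave like plain softmax early in training; as it decays toward `beta_min`, the angular-margin term dominates, which stabilizes the otherwise hard-to-train margin losses.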
