PyTorch代码学习-ImageNET训练

文章说明:本人学习pytorch/examples/ImageNET/main()理解(待续)

# -*- coding: utf-8 -*-
import argparse  # 命令行解释器相关程序,命令行解释器
import os        # 操作系统文件相关
import shutil    # 文件高级操作
import time      # 调用时间模块import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn        # gpu 使用
import torch.distributed as dist            # 分布式(pytorch 0.2)
import torch.optim                          # 优化器
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models# name中若为小写且不以‘——’开头,则对其进行升序排列
model_names = sorted(name for name in models.__dict__if name.islower() and not name.startswith("__")and callable(models.__dict__[name]))                # callable功能为判断返回对象是否可调用(即某种功能)。# 创建argparse.ArgumentParser对象
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# 添加命令行元素
parser.add_argument('data', metavar='DIR',help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',choices=model_names,help='model architecture: ' +' | '.join(model_names) +' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',help='use pre-trained model')
parser.add_argument('--world-size', default=1, type=int,help='number of distributed processes')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='gloo', type=str,help='distributed backend')# 定义参数
best_prec1 = 0# 定义主函数main()
def main():global args, best_prec1# 使用函数parse_args()进行参数解析,输入默认是sys.argv[1:],# 返回值是一个包含命令参数的Namespace,所有参数以属性的形式存在,比如args.myoption。args = parser.parse_args()########## 使用多播地址进行初始化args.distributed = args.world_size > 1if args.distributed:dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,world_size=args.world_size)##### step1: create model and set GPU # 导入pretrained model 或者创建modelif args.pretrained:# format 格式化表达字符串,上述默认arch为resnet18print("=> using pre-trained model '{}'".format(args.arch))      model = models.__dict__[args.arch](pretrained=True)else:print("=> creating model '{}'".format(args.arch))model = models.__dict__[args.arch]()# 分布式运行,可实现在多块GPU上运行if not args.distributed:if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):# 批处理,多GPU默认用dataparallel使用在多块gpu上model.features = torch.nn.DataParallel(model.features)           model.cuda()else:model = torch.nn.DataParallel(model).cuda()else:# Wrap model in DistributedDataParallel (CUDA only for the moment)model.cuda()model = torch.nn.parallel.DistributedDataParallel(model)##### step2: define loss function (criterion) and optimizer# 使用交叉熵损失函数criterion = nn.CrossEntropyLoss().cuda()                            # optimizer 使用 SGD + momentum# 动量,默认设置为0.9optimizer = torch.optim.SGD(model.parameters(), args.lr,momentum=args.momentum,# 权值衰减,默认为1e-4                 weight_decay=args.weight_decay)         # 恢复模型(详见模型存取与恢复)
####step3:optionally resume from a checkpointif args.resume:if os.path.isfile(args.resume):                                 # 判断返回的是不是文件print("=> loading checkpoint '{}'".format(args.resume))checkpoint = torch.load(args.resume)                        # load 一个save的对象args.start_epoch = checkpoint['epoch']                      # default = 90best_prec1 = checkpoint['best_prec1']                       # best_prec1 = 0model.load_state_dict(checkpoint['state_dict'])optimizer.load_state_dict(checkpoint['optimizer'])          # load_state_dict:恢复模型print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))else:print("=> no checkpoint found at '{}'".format(args.resume))cudnn.benchmark = True##### step4: Data loading code base of dataset(have downloaded) and normalize# 从 train、val文件中导入数据traindir = os.path.join(args.data, 'train')valdir = os.path.join(args.data, 'val')# 数据预处理:normalize: - mean / stdnormalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],       std=[0.229, 0.224, 0.225])# ImageFolder 一个通用的数据加载器train_dataset = datasets.ImageFolder(traindir,# 对数据进行预处理transforms.Compose([                      # 将几个transforms 组合在一起transforms.RandomSizedCrop(224),      # 随机切再resize成给定的size大小transforms.RandomHorizontalFlip(),    # 概率为0.5,随机水平翻转。transforms.ToTensor(),                # 把一个取值范围是[0,255]或者shape为(H,W,C)的numpy.ndarray,# 转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloadTensornormalize,]))#######if args.distributed:# Use a DistributedSampler to restrict each process to a distinct subset of the dataset.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)else:train_sampler = None
####### train 数据下载及预处理train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),num_workers=args.workers, pin_memory=True, sampler=train_sampler)val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(valdir, transforms.Compose([ # 重新改变大小为`size`,若:height>width`,则:(size*height/width, size)transforms.Scale(256),# 将给定的数据进行中心切割,得到给定的size。transforms.CenterCrop(224),transforms.ToTensor(),normalize,])),batch_size=args.batch_size, shuffle=False,num_workers=args.workers, pin_memory=True)         # default workers = 4##### step5: 验证函数if args.evaluate:validate(val_loader, model, criterion)             # 自定义的validate函数,见下return##### step6:开始训练模型for epoch in range(args.start_epoch, args.epochs):# Use .set_epoch() method to reshuffle the dataset partition at every iterationif args.distributed:train_sampler.set_epoch(epoch)adjust_learning_rate(optimizer, epoch)      # adjust_learning_rate 自定义的函数,见下# train for one epochtrain(train_loader, model, criterion, optimizer, epoch)# evaluate on validation setprec1 = validate(val_loader, model, criterion)# remember best prec@1 and save checkpointis_best = prec1 > best_prec1best_prec1 = max(prec1, best_prec1)save_checkpoint({'epoch': epoch + 1,'arch': args.arch,'state_dict': model.state_dict(),'best_prec1': best_prec1,'optimizer' : optimizer.state_dict(),}, is_best)# 定义相关函数
# def train 函数
def train(train_loader, model, criterion, optimizer, epoch):batch_time = AverageMeter()data_time = AverageMeter()losses = AverageMeter()top1 = AverageMeter()top5 = AverageMeter()# switch to train modemodel.train()end = time.time()for i, (input, target) in enumerate(train_loader):# measure data loading timedata_time.update(time.time() - end)target = target.cuda(async=True)input_var = torch.autograd.Variable(input)target_var = torch.autograd.Variable(target)# compute outputoutput = model(input_var)# criterion 为定义过的损失函数loss = criterion(output, target_var)        # measure accuracy and record lossprec1, prec5 = accuracy(output.data, target, topk=(1, 5))losses.update(loss.data[0], input.size(0))top1.update(prec1[0], input.size(0))top5.update(prec5[0], input.size(0))# compute gradient and do SGD stepoptimizer.zero_grad()loss.backward()optimizer.step()# measure elapsed timebatch_time.update(time.time() - end)end = time.time()# 每十步输出一次if i % args.print_freq == 0:     # default=10print('Epoch: [{0}][{1}/{2}]\t''Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t''Data {data_time.val:.3f} ({data_time.avg:.3f})\t''Loss {loss.val:.4f} ({loss.avg:.4f})\t''Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t''Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader), batch_time=batch_time,data_time=data_time, loss=losses, top1=top1, top5=top5))def validate(val_loader, model, criterion):batch_time = AverageMeter()losses = AverageMeter()top1 = AverageMeter()top5 = AverageMeter()# switch to evaluate modemodel.eval()end = time.time()for i, (input, target) in enumerate(val_loader):target = target.cuda(async=True)# 这是一种用来包裹张量并记录应用的操作"""Attributes:data: 任意类型的封装好的张量。grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值.grad_fn: Gradient function graph trace.Parameters:data (any tensor class): 要包装的张量.requires_grad (bool): bool型的标记值. **Keyword only.**volatile (bool): bool型的标记值. **Keyword only.**"""input_var = torch.autograd.Variable(input, volatile=True)target_var = torch.autograd.Variable(target, volatile=True)# compute outputoutput = model(input_var)loss = criterion(output, target_var)# measure accuracy and record lossprec1, prec5 = accuracy(output.data, target, topk=(1, 5))losses.update(loss.data[0], input.size(0))top1.update(prec1[0], input.size(0))top5.update(prec5[0], input.size(0))# measure elapsed timebatch_time.update(time.time() - end)end = time.time()if i % args.print_freq == 0:print('Test: [{0}/{1}]\t''Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t''Loss {loss.val:.4f} ({loss.avg:.4f})\t''Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t''Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(i, len(val_loader), batch_time=batch_time, loss=losses,top1=top1, top5=top5))print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))return top1.avg# 保存当前节点
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):torch.save(state, filename)if is_best:shutil.copyfile(filename, 'model_best.pth.tar')# 计算并存储参数当前值或平均值
class AverageMeter(object):# Computes and stores the average and current value"""batch_time = AverageMeter()即 self = batch_time则 batch_time 具有__init__,reset,update三个属性,直接使用batch_time.update()调用功能为:batch_time.update(time.time() - end)仅一个参数,则直接保存参数值对应定义:def update(self, val, n=1)losses.update(loss.data[0], input.size(0))top1.update(prec1[0], input.size(0))top5.update(prec5[0], input.size(0))这些有两个参数则求参数val的均值,保存在avg中##不确定##"""def __init__(self):self.reset()       # __init__():reset parametersdef reset(self):self.val = 0self.avg = 0self.sum = 0self.count = 0def update(self, val, n=1):self.val = valself.sum += val * nself.count += nself.avg = self.sum / self.count# 更新 learning_rate :每30步,学习率降至前的10分之1
def adjust_learning_rate(optimizer, epoch):"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""lr = args.lr * (0.1 ** (epoch // 30))            # args.lr = 0.1 , 即每30步,lr = lr /10for param_group in optimizer.param_groups:       # 将更新的lr 送入优化器 optimizer 中,进行下一次优化param_group['lr'] = lr# 计算准确度
def accuracy(output, target, topk=(1,)):"""Computes the precision@k for the specified values of kprec1, prec5 = accuracy(output.data, target, topk=(1, 5))"""maxk = max(topk)# size函数:总元素的个数batch_size = target.size(0)# topk函数选取output前k大个数_, pred = output.topk(maxk, 1, True, True)##########不了解t()pred = pred.t()correct = pred.eq(target.view(1, -1).expand_as(pred))res = []for k in topk:correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)res.append(correct_k.mul_(100.0 / batch_size))return resif __name__ == '__main__':main()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389

文章目录 [隐藏]

  • 1 方法一(推荐):

    • 1.1 保存
    • 1.2 恢复
  • 2 方法二:
    • 2.1 保存
    • 2.2 恢复
  • 3 一个相对完整的例子
  • 4 获取模型中某些层的参数

在模型完成训练后,我们需要将训练好的模型保存为一个文件供测试使用,或者因为一些原因我们需要继续之前的状态训练之前保存的模型,那么如何在PyTorch中保存和恢复模型呢?

参考PyTorch官方的这份repo,我们知道有两种方法可以实现我们想要的效果。

方法一(推荐):

第一种方法也是官方推荐的方法,只保存和恢复模型中的参数。

保存

torch.save(the_model.state_dict(), PATH)
1
torch.save(the_model.state_dict(), PATH)

恢复

the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH))
1
2

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

使用这种方法,我们需要自己导入模型的结构信息。

方法二:

使用这种方法,将会保存模型的参数和结构信息。

保存

torch.save(the_model, PATH)
1
torch.save(the_model, PATH)

恢复

the_model = torch.load(PATH)
1
the_model = torch.load(PATH)

一个相对完整的例子

saving

torch.save({ ‘epoch’: epoch + 1, ‘arch’: args.arch, ‘state_dict’: model.state_dict(), ‘best_prec1’: best_prec1, }, ‘checkpoint.tar’ )
1
2
3
4
5
6

torch.save({
            ‘epoch’: epoch + 1,
            ‘arch’: args.arch,
            ‘state_dict’: model.state_dict(),
            ‘best_prec1’: best_prec1,
        }, ‘checkpoint.tar’ )

loading

if args.resume: if os.path.isfile(args.resume): print(“=> loading checkpoint ‘{}’”.format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint[‘epoch’] best_prec1 = checkpoint[‘best_prec1’] model.load_state_dict(checkpoint[‘state_dict’]) print(“=> loaded checkpoint ‘{}’ (epoch {})” .format(args.evaluate, checkpoint[‘epoch’]))
1
2
3
4
5
6
7
8
9

if args.resume:
        if os.path.isfile(args.resume):
            print(“=> loading checkpoint ‘{}’”.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint[‘epoch’]
            best_prec1 = checkpoint[‘best_prec1’]
            model.load_state_dict(checkpoint[‘state_dict’])
            print(“=> loaded checkpoint ‘{}’ (epoch {})”
                  .format(args.evaluate, checkpoint[‘epoch’]))

获取模型中某些层的参数

对于恢复的模型,如果我们想查看某些层的参数,可以:

# 定义一个网络 from collections import OrderedDict model = nn.Sequential(OrderedDict([ (‘conv1’, nn.Conv2d(1,20,5)), (‘relu1’, nn.ReLU()), (‘conv2’, nn.Conv2d(20,64,5)), (‘relu2’, nn.ReLU()) ]))

http://www.taodudu.cc/news/show-2420930.html

相关文章:

  • 完美收官!Fortinet Accelerate 2022中国站在北京落幕
  • python全栈测试开发工程师_Python测试开发全栈核心课程 互联网测试工程师必修课...
  • 数据交易,距离生产要素市场化还有多远? | 2022全球数字价值峰会
  • 中国智能经济觉醒,云智一体打造产业智能化加速器
  • “碳壁垒”悄然而起,碳足迹如何算清楚、减明白?|双碳科普
  • 为什么说服务逻辑,才是SaaS的底层逻辑
  • 赛道和资本的玩儿法已经过气,SaaS公司活下去还能靠什么?
  • 挖掘数字资产,生意增长是本质,但数据创新仍有难题待解 | 2022全球数字价值峰会...
  • 数字时代的保险创新与升级 | 创新场景50
  • 消费品牌数字营销“终局九问” | 2022全球数字价值峰会
  • 基础软件皇冠上的明珠,数据库创新 | 创新场景50
  • 千行百业如何正确上BI?不仅要数据,更要生态 | 创新场景50
  • 傅一平:业务流程的数字化到底是什么?
  • 清华大学林常乐:数据要素定价的思考与实践 | 数字思考者50人
  • 真正的数字化,是CEO决策的底层逻辑要变了
  • 【产业互联网周报】罗永浩AR创业公司获美团领投;英特尔自动驾驶子公司Mobileye敲定IPO条款;星环科技登陆科创板...
  • 宝洁中国CIO沈锋:全球日化巨头是如何做数字化的|数字思考者50人
  • 【北交所周报】IPO上会5过5;四成个股实现上涨,硅烷科技涨56%,成单周涨幅最大个股;...
  • 算力和想象力
  • poppin_xpower_ 常城
  • Exp3 免杀原理与实践 20154328 常城
  • Exp5 CAL_MSF基础运用 20154328 常城
  • EXP6 信息搜集与漏洞扫描 20154328 常城
  • Exp2 后门原理与实践 20154328 常城
  • Exp9 Web安全实践基础 20154328 常城
  • Exp4 恶意代码分析 20154328 常城
  • Exp8 Web基础 20154328 常城
  • ITextSharp获取pdf文件指定关键字的坐标信息,用于签名。
  • 如何在PDF文件中快速查找关键字?
  • 使用iText对pdf中查找关键字坐标进行填充

PyTorch代码学习-ImageNET训练相关推荐

  1. 收藏 | PyTorch深度学习模型训练加速指南2021

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者:LORENZ KUHN 编译:ronghuaiyang ...

  2. TSN算法的PyTorch代码解读(训练部分)

    这篇博客来读一读TSN算法的PyTorch代码,总体而言代码风格还是不错的,多读读优秀的代码对自身的提升还是有帮助的,另外因为代码内容较多,所以分训练和测试两篇介绍,这篇介绍训练代码,介绍顺序为代码运 ...

  3. 《PHASEN:A Phase and Harmonics-Aware Speech Enhancement Network》Pytorch代码学习Ⅱ

    数据预处理 本文的实验采用的是Voice Bank的数据集,其中训练集大约包含11000条语音.上一篇文章中提到模型的输入是语音数据的短时傅里叶变换(幅值.相位),包含四个维度,分别是[batch, ...

  4. 《PHASEN:A Phase and Harmonics-Aware Speech Enhancement Network》Pytorch代码学习

    PHASEN结构 源码地址:https://github.com/huyanxin/phasen PHASEN是一个双流网络,其中幅值流和相位流分别专门用于幅值和相位预测.幅值流主要由卷积操作,频域变 ...

  5. 【深度学习】翻译:60分钟入门PyTorch(四)——训练一个分类器

    前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...

  6. 深度学习100+经典模型TensorFlow与Pytorch代码实现大合集

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! [导读]深度学习在过去十年获得了极大进展,出现很多新的模型,并且伴随TensorF ...

  7. 【深度学习】PyTorch深度学习训练可视化工具visdom

    PyTorch Author:louwill Machine Learning Lab 在进行深度学习实验时,能够可视化地对训练过程和结果进行展示是非常有必要的.除了Torch版本的TensorBoa ...

  8. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  9. 实战例子_Pytorch官方力荐新书《Pytorch深度学习实战指南》pdf及代码分享

    PyTorch是目前非常流行的机器学习.深度学习算法运算框架.它可以充分利用GPU进行加速,可以快速的处理复杂的深度学习模型,并且具有很好的扩展性,可以轻松扩展到分布式系统.PyTorch与Pytho ...

  10. CornerNet代码学习之pytorch多线程

    Cornernet代码之pytorch多线程学习 源码剖析 main() train() 页锁定内存 守护线程 init_parallel_jobs().pin_memory() 信号量 附录-源码内 ...

最新文章

  1. 美团高德并不是解决快车问题的灵药,烧完钱之后只会产生新的滴滴
  2. 游戏里的角色都什么格式图片_二十年前是怎样开发游戏的?
  3. cocos2d-x游戏实例(18)-纵版射击游戏(5)
  4. java中的反射(二)
  5. 谷歌浏览器如何正确离线网页
  6. 【小程序项目分享】多功能抽签分组系统
  7. 关于嵌入式软件系统测试策略和方案设计详解
  8. c语言程序算一元二次方程,以实例跟我学C语言:如何求解一元二次方程的根
  9. CTO(技术总监)平时都在做些什么?
  10. canvas--放大镜效果
  11. 11.2 申请API KEY
  12. Gvim计数器模板经典练习
  13. matlab 离散阶跃函数,matlab阶跃函数
  14. [Error]cannot convert 'float'tot float for argument 1to floa
  15. Lambda expression are not supported at language level '5'
  16. 山寨王被山寨 腾讯九城恶性竞争害产业
  17. 读书忘却时间——灵魂的沉淀
  18. 电子商务行业物流现状分析
  19. 华为机试java_华为java机试面试题目大全
  20. java计算机毕业设计绿叶有限公司工资管理信息系统源码+系统+mysql数据库+lw文档

热门文章

  1. 手机工商银行怎么转账_工行手机银行可以转账吗
  2. 工商银行近20年实时大数据平台建设历程
  3. 密歇根州立大学被黑 个人信息和社保号码被盗
  4. flutter数据解析出现type ‘String‘ is not a subtype of type ‘int‘错误
  5. 卷毛机器人抢大龙_LOL:机器人史诗级加强,如果他还没退役,SKT都不敢放机器人...
  6. 对封装的ajax的应用-查询商铺
  7. react-native系列(11)组件篇:Image图片加载和ImageEditor图片剪切
  8. 三向振动台的计算机辅助测试实验分析原因,振动试验原理及试验考虑的条件
  9. java barchart类,JavaFX BarChart条形图颜色
  10. 计算机网络什么属于广域网,以什么将网络划分为广域网和局域网