先吐槽一下,深度学习发展速度真的很快,深度学习框架也在逐步迭代,真的是苦了俺这搞深度学习程序员。本人从三年前开始学深度学习开始,这些深度学习框架也是一个换过一个,从keras、theano、caffe、darknet、tensorflow,最后到现在要开始使用pytorch。

一、变量、求导torch.autograd模块

默认的variable定义的时候,requires_grad是false,变量是不可导的,如果设置为true,表示变量可导。

#coding=utf-8
#requires_grad默认为false
# 如果调用backward的时候,所有的变量都是不可导的,那么最后会报出没有可到变量的错误
import torch
from torch import  autograd
input=torch.FloatTensor([1,2,3])
input_v=autograd.Variable(input,requires_grad=True)
loss=torch.mean(input_v)print loss.requires_grad
loss.backward()
print input_v
print input_v.grad

二、数据层及其变换

from PIL import Image
import torchvision
import matplotlib.pyplot as plt
import numpy as np
#数据变换
data_transform_train = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(30),torchvision.transforms.RandomCrop((32,32)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
data_transform_eval=torchvision.transforms.Compose([torchvision.transforms.CenterCrop((32,32)),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#对于自定义的数据,要重载下面三个函数getitem、len、init
class mydata(Dataset):def __init__(self,label_file,image_root,is_train=True):self.imagepaths=[]self.labels=[]self.is_train=is_trainif is_train:self.transforms=data_transform_trainelse:self.transforms=data_transform_evalwith open(label_file,'r') as f:for line in f.readlines():#读取label文件self.imagepaths.append(os.path.join(image_root,line.split()[0]))self.labels.append(int(line.split()[1]))def __getitem__(self, item):x=Image.open(self.imagepaths[item]).resize((35,35))y=self.labels[item]if self.is_train:return [self.transforms(x),self.transforms(x)], yelse:return self.transforms(x),ydef __len__(self):return len(self.imagepaths)def make_weights_for_balanced_classes(labels, nclasses):count = {}for item in labels:if count.has_key(item):count[item] += 1else:count[item]=1weight_per_class ={}N = len(labels)for key,value in count.items():weight_per_class[key] = N/float(value)weight = [0] * len(labels)for idx, val in enumerate(labels):weight[idx] = weight_per_class[val]return weighttrain_data=mydata('data/train.txt','./',is_train=True)
weights = make_weights_for_balanced_classes(train_data.labels, 3)
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
train_dataloader_student=DataLoader(train_data,batch_size=6,sampler=sampler)
for x,y in train_dataloader_student:for xi in x:print ynpimg = torchvision.utils.make_grid(xi).numpy()#可视化显示plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()

三、网络架构

from moving_avarage_layer import conv2d_moving
import torch
from torch import  autograd,nn
from torch.utils.data import DataLoader, Dataset
from data_layer import mydata,make_weights_for_balanced_classes
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as function
import os
import timeclass MobileNet(nn.Module):def __init__(self):super(MobileNet, self).__init__()def conv_bn(inp, oup, stride):return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False),nn.BatchNorm2d(oup),nn.ReLU(inplace=True))def conv_dw(inp, oup, stride):return nn.Sequential(nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),nn.BatchNorm2d(inp),nn.ReLU(inplace=True),nn.Conv2d(inp, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),nn.ReLU(inplace=True),)self.model = nn.Sequential(conv_bn(  3,  32, 2),conv_dw( 32,  64, 1),conv_dw( 64, 96, 2),conv_dw(96, 96, 1),conv_dw(96, 128, 2),conv_dw(128, 128, 1),conv_dw(128, 256, 2),conv_dw(256, 256, 1),conv_dw(256, 512, 1),nn.AvgPool2d(2),)self.fc = nn.Linear(512, 4)def forward(self, x):x = self.model(x)#print x.shapex = x.view(-1, 512)x = self.fc(x)return x

四、优化求解

def update_ema_variables(model, ema_model,alpha):for ema_param, param in zip(ema_model.parameters(), model.parameters()):ema_param.data.mul_(alpha).add_(1 - alpha, param.data)
def softmax_mse_loss(input_logits, target_logits):assert input_logits.size() == target_logits.size()input_softmax = function.softmax(input_logits, dim=1)target_softmax = function.softmax(target_logits, dim=1)num_classes = input_logits.size()[1]return function.mse_loss(input_softmax, target_softmax, size_average=False) / num_classestorch.backends.cudnn.enabled = False
torch.manual_seed(7)net_student=MobileNet().cuda()
net_teacher=MobileNet().cuda()
for param in net_teacher.parameters():param.detach_()
if os.path.isfile('teacher.pt'):net_student.load_state_dict(torch.load('teacher.pt'))train_data=mydata('data/race/train.txt','./',is_train=True)
min_batch_size=32
weights = make_weights_for_balanced_classes(train_data.labels, 5)
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
train_dataloader=DataLoader(train_data,batch_size=min_batch_size,sampler=sampler,num_workers=8)valid_data=mydata('data/race/val.txt','./',is_train=False)
valid_dataloader=DataLoader(valid_data,batch_size=min_batch_size,shuffle=True,num_workers=8)classify_loss_function = torch.nn.CrossEntropyLoss(size_average=False,ignore_index=-1).cuda()
optimizer = torch.optim.SGD(net_student.parameters(),lr = 0.001, momentum=0.9)globals_step=0
for epoch in range(10000):globals_classify_loss=0globals_consistency_loss = 0net_student.train()start=time.time()end=0for index,(x,y) in enumerate(train_dataloader):optimizer.zero_grad()  #x_student=autograd.Variable(x[0]).cuda()y=autograd.Variable(y).cuda()predict_student=net_student(x_student)classify_loss=classify_loss_function(predict_student,y)/min_batch_sizesum_loss = classify_lossx_teacher= autograd.Variable(x[1],volatile=True).cuda()predict_teacher = net_teacher(x_teacher)ema_logit = autograd.Variable(predict_teacher.detach().data, requires_grad=False)consistency_loss =softmax_mse_loss(predict_student,ema_logit)/min_batch_sizeconsistency_weight=1sum_loss+=consistency_weight*consistency_lossglobals_consistency_loss += consistency_loss.data[0]sum_loss.backward()optimizer.step()alpha = min(1 - 1 / (globals_step + 1), 0.99)update_ema_variables(net_student, net_teacher, alpha)globals_classify_loss +=classify_loss.data[0]globals_step += 1if epoch%5!=0:continuenet_student.eval()correct = 0total = 0for images, labels in valid_dataloader:valid_input=autograd.Variable(images,volatile=True).cuda()outputs = net_student(valid_input)#print outputs.shape_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted.cpu() == labels).sum()print "epoch:%d"%epoch,"time:%d"%(time.time()-start),'accuracy %d' % (100 * correct / total),"consistency loss:%f"%globals_consistency_loss,'classify loss%f:'%globals_classify_losstorch.save(net_student.state_dict(),'teacher.pt')

深度学习(七十三)pytorch学习笔记相关推荐

  1. 深度学习入门之PyTorch学习笔记:卷积神经网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...

  2. 深度学习入门之PyTorch学习笔记:多层全连接网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...

  3. 深度学习入门之PyTorch学习笔记:深度学习框架

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 2.1 深度学习框架介绍 2.1.1 TensorFlow 2.1.2 Caffe 2.1.3 Theano 2.1.4 ...

  4. 深度学习入门之PyTorch学习笔记:深度学习介绍

    深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...

  5. 深度学习入门之PyTorch学习笔记

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 5 循环神经网络 6 生成对抗网络 7 深度学习实战 参考资料 绪论 深度学习如今 ...

  6. springboot学习(七十三) springboot中使用springdoc替换swagger(springfox)

    文章目录 前言 一.springdoc介绍 二.使用步骤 1.引入库 2. 创建一个spring配置类,添加springdoc的配置 3. 常用的swagger注解和springdoc的对应关系 4. ...

  7. 纽约大学深度学习PyTorch课程笔记(自用)Week3

    纽约大学深度学习PyTorch课程笔记Week3 Week 3 3.1 神经网络参数变换可视化及卷积的基本概念 3.1.1 神经网络的可视化 3.1.2 参数变换 一个简单的参数变换:权重共享 超网络 ...

  8. PyTorch学习笔记(七):PyTorch可视化

    PyTorch可视化 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一) ...

  9. 纽约大学深度学习PyTorch课程笔记(自用)Week6

    纽约大学深度学习PyTorch课程笔记Week6 Week 6 6.1 卷积网络的应用 6.1.1 邮政编码识别器 使用CNN进行识别 6.1.2 人脸检测 一个多尺度人脸检测系统 6.1.3 语义分 ...

最新文章

  1. SQL Server技术问题之索引优缺点
  2. ubuntu系统在vmware中无法联网问题解决
  3. kali 切换root权限_Ubuntu 被曝严重漏洞:切换系统语言 + 输入几行命令,就能获取 root 权限...
  4. 我不是天生的飞鸽传书2011
  5. asa 防火墙基本配置管理
  6. Linux下多线程pthread内存泄露
  7. 【梳理】高等代数(北大) 第一章 线性方程组(docx)
  8. 手写字体研究-matlab
  9. 对报表.FRX文件的全面分析
  10. python灰色关联度分析_基于灰色关联度重庆万州区边坡稳定影响因素分析
  11. c语言列宽作用,c语言|格式化输入输出详解
  12. iDrac6 虚拟控制台 连接失败
  13. elasticsearch从入门到入门系列(二)---快速入门A
  14. matlab 矩阵的n次,用matlab的for循环产生N个矩阵,怎么取第N次的矩阵?
  15. WebDAV将会在公共领域取代FTP
  16. Numpy——np.diag()一文看懂
  17. 根域名服务器都在国外,中国安全吗?安全
  18. 腾讯游戏云以科技连接游戏未来,全力打造行业新生态
  19. Http SSL 即(HTTPS)证书的深入理解及证书管理方法
  20. 机器学习入门级实例——针对葡萄酒质量进行建模

热门文章

  1. sql批量插入数据mysql_MYSQL批量插入数据库实现语句性能分析
  2. 的优缺点_折叠门的优缺点
  3. 【Linux】Linux系统备份与还原
  4. 【debug】json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)
  5. linux中写如空格参数,Vim中Tab与空格缩进
  6. python 全中文匹配字符_Python教程:进程和线程amp;正则表达式
  7. c++中string插入一个字符_Java内存管理-探索Java中字符串String(十二)
  8. iis如何处理并发请求
  9. redhat7下对用户账户的管理
  10. nodejs nodemailer