我们将在Google Colab中实现执行,因为它提供免费的云TPU(张量处理单元)。

在继续下一步之前,在Colab笔记本中,转到“编辑”,然后选择“设置”,从下面屏幕截图中的列表中选择“TPU”作为“硬件加速器”。

验证TPU下面的代码是否正常运行。

import os
assert os.environ['COLAB_TPU_ADDR']

如果启用了TPU,它将成功执行,否则它将返回‘KeyError: ‘COLAB_TPU_ADDR’’。你也可以通过打印TPU地址来检查TPU

TPU_Path = 'grpc://'+os.environ['COLAB_TPU_ADDR']
print('TPU Address:', TPU_Path)

启用TPU后,我们将安装兼容的控制盘和依赖项,以使用以下代码设置XLA环境。

VERSION = "20200516"
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

一旦安装成功,我们将继续定义加载数据集、初始化CNN模型、训练和测试的方法。首先,我们将导入所需的库。

import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
from torchvision import datasets, transforms

之后,我们将进一步定义需要的超参数。

# 定义参数
FLAGS = {}
FLAGS['datadir'] = "/tmp/mnist"
FLAGS['batch_size'] = 128
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 0.01
FLAGS['momentum'] = 0.5
FLAGS['num_epochs'] = 50
FLAGS['num_cores'] = 8
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = False

下面的代码片段将把CNN模型定义为PyTorch实例,以及用于加载数据、训练模型和测试模型的函数。

SERIAL_EXEC = xmp.MpSerialExecutor()class FashionMNIST(nn.Module):def __init__(self):super(FashionMNIST, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.bn1 = nn.BatchNorm2d(10)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.bn2 = nn.BatchNorm2d(20)self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = self.bn1(x)x = F.relu(F.max_pool2d(self.conv2(x), 2))x = self.bn2(x)x = torch.flatten(x, 1)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)# 只在内存中实例化一次模型权重。
WRAPPED_MODEL = xmp.MpModelWrapper(FashionMNIST())def train_mnist():torch.manual_seed(1)def get_dataset():norm = transforms.Normalize((0.1307,), (0.3081,))train_dataset = datasets.FashionMNIST(FLAGS['datadir'],train=True,download=True,transform=transforms.Compose([transforms.ToTensor(), norm]))test_dataset = datasets.FashionMNIST(FLAGS['datadir'],train=False,download=True,transform=transforms.Compose([transforms.ToTensor(), norm]))return train_dataset, test_dataset#使用串行执行器可以避免多个进程下载相同的数据train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,num_replicas=xm.xrt_world_size(),rank=xm.get_ordinal(),shuffle=True)train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=FLAGS['batch_size'],sampler=train_sampler,num_workers=FLAGS['num_workers'],drop_last=True)test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=FLAGS['batch_size'],shuffle=False,num_workers=FLAGS['num_workers'],drop_last=True)# 调整学习率lr = FLAGS['learning_rate'] * xm.xrt_world_size()# 获取损失函数、优化器和模型device = xm.xla_device()model = WRAPPED_MODEL.to(device)optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum'])loss_fn = nn.NLLLoss()def train_fun(loader):tracker = xm.RateTracker()model.train()for x, (data, target) in enumerate(loader):optimizer.zero_grad()output = model(data)loss = loss_fn(output, target)loss.backward()xm.optimizer_step(optimizer)tracker.add(FLAGS['batch_size'])if x % FLAGS['log_steps'] == 0:print('[xla:{}]({}) Loss={:.5f}'.format(xm.get_ordinal(), x, loss.item(), time.asctime()), flush=True)def test_fun(loader):total_samples = 0correct = 0model.eval()data, pred, target = None, None, Nonefor data, target in loader:output = model(data)pred = output.max(1, keepdim=True)[1]correct += pred.eq(target.view_as(pred)).sum().item()total_samples += data.size()[0]accuracy = 100.0 * correct / total_samplesprint('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(), accuracy), flush=True)return accuracy, data, pred, target# 训练和评估循环accuracy = 0.0data, pred, target = None, None, Nonefor epoch in range(1, FLAGS['num_epochs'] + 1):para_loader = pl.ParallelLoader(train_loader, [device])train_fun(para_loader.per_device_loader(device))xm.master_print("Finished training epoch {}".format(epoch))para_loader = pl.ParallelLoader(test_loader, [device])accuracy, data, pred, target  = test_fun(para_loader.per_device_loader(device))if FLAGS['metrics_debug']:xm.master_print(met.metrics_report(), flush=True)return accuracy, data, pred, target

现在,要将结果绘制为测试图像的预测标签和实际标签,将使用以下功能模块。

# 结果可视化
import math
from matplotlib import pyplot as pltM, N = 5, 5
RESULT_IMG_PATH = '/tmp/test_result.png'def plot_results(images, labels, preds):images, labels, preds = images[:M*N], labels[:M*N], preds[:M*N]inv_norm = transforms.Normalize((-0.1307/0.3081,), (1/0.3081,))num_images = images.shape[0]fig, axes = plt.subplots(M, N, figsize=(12, 12))fig.suptitle('Predicted Lables')for i, ax in enumerate(fig.axes):ax.axis('off')if i >= num_images:continueimg, label, prediction = images[i], labels[i], preds[i]img = inv_norm(img)img = img.squeeze() # [1,Y,X] -> [Y,X]label, prediction = label.item(), prediction.item()if label == prediction:ax.set_title(u'Actual {}/ Predicted {}'.format(label, prediction), color='blue')else:ax.set_title('Actual {}/ Predicted {}'.format(label, prediction), color='red')ax.imshow(img)plt.savefig(RESULT_IMG_PATH, transparent=True)

现在,我们都准备好在MNIST数据集上训练模型。训练开始前,我们将记录开始时间,训练结束后,我们将记录结束时间并打印50个epoch的总训练时间。

# 启动训练流程
def train_cnn(rank, flags):global FLAGSFLAGS = flagstorch.set_default_tensor_type('torch.FloatTensor')accuracy, data, pred, target = train_mnist()if rank == 0:# 检索TPU核心0上的张量并绘制。plot_results(data.cpu(), pred.cpu(), target.cpu())xmp.spawn(train_cnn, args=(FLAGS,), nprocs=FLAGS['num_cores'],start_method='fork')

一旦训练成功结束,我们将打印训练所用的总时间。

# 如果想查看时间也可以打印下,不过前面代码并未添加strat_time(),记得添加一下~
# end_time = time.time()
# print('Total Training time = ',end_time-start_time )

利用TPU这种方法花费了大约4.5分钟,这意味着50个epoch训练PyTorch模型不到5分钟。最后,我们将通过训练的模型来可视化预测。

from google.colab.patches import cv2_imshow
import cv2
img = cv2.imread(RESULT_IMG_PATH, cv2.IMREAD_UNCHANGED)
cv2_imshow(img)

因此,我们可以得出这样的结论:使用TPU实现深度学习模型可以实现快速的训练,正如我们前面所看到的那样。

在不到5分钟的时间内,对50个epoch的40000张训练图像进行了CNN模型的训练。我们在训练中也获得了89%以上的准确率。

因此,在TPU上训练深度学习模型在时间和准确性方面总是有好处的。

PyTorch实现TPU版本CNN模型相关推荐

  1. 第16课:项目实战——利用 PyTorch 构建 CNN 模型

    上一篇,我们主要介绍了 CNN 的基本概念和模型结构.本文将带领大家使用 PyTorch 一步步搭建 CNN 模型,进行数字图片识别.本案例中,我们选用的是 MNIST 数据集. 总的来说,我们构建分 ...

  2. 【深度学习】Keras vs PyTorch vs Caffe:CNN实现对比

    作者 | PRUDHVI VARMA 编译 | VK 来源 | Analytics Indiamag 在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署.这些深度 ...

  3. Keras vs PyTorch vs Caffe:CNN实现对比

    作者|PRUDHVI VARMA 编译|VK 来源|Analytics Indiamag 在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署.这些深度学习框架提供 ...

  4. 《Pytorch - CNN模型》

    2020年10月5号,依然在家学习. 今天是我写的第四个 Pytorch程序, 这一次我想把之前基于PyTorch实现的简易的传统的BP全连接神经网络改写成CNN网络,想看看对比和效果差异. 这一次我 ...

  5. 从LeNet-5 CNN模型入门PyTorch

    从LeNet-5 CNN模型入门PyTorch 1. PyTorch 准备 1.1 PyTorch特点 1.2 PyTorch安装测试 2. 完整代码 2.1 LeNet模型 2.2 训练 2.2 测 ...

  6. Pytorch版本YOLOv3模型转Darknet weights模型然后转caffemodel再转wk模型在nnie上面推理

    Pytorch版本YOLOv3模型转darknet weights模型然后转caffemodel再转wk模型在nnie上面推理 文章目录 Pytorch版本YOLOv3模型转darknet weigh ...

  7. AI 图片截取、ffmpeg使用及安装, anaconda环境,图片标注(labelme),模型训练(yolov5),CUDA+Pytorch安装及版本相关问题

    AI 图片截取(ffmpeg), anaconda环境,图片标注(labelme),模型训练(yolov5),CUDA+Pytorch安装及版本相关问题 一.截取有效图片 录制RTSP视频脚本 #!/ ...

  8. Pytorch实现CNN模型的迁移学习——蜜蜂和蚂蚁图片分类项目

    很多时候当训练一个新的图像分类任务时,一般不会完全从一个随机的模型开始训练,而是利用预训练的模型来加速训练的过程.经常使用在ImageNet上的预训练模型. 这是一种transfer learning ...

  9. Pytorch训练Bilinear CNN模型笔记

    Pytorch训练Bilinear CNN模型笔记 注:一个项目需要用到机器学习,而本人又是一个python小白,根据老师的推荐,然后在网上查找了一些资料,终于实现了目的. 参考文献: Caltech ...

  10. PyTorch入门(五)使用CNN模型进行中文文本分类

      本文将会介绍如何在PyTorch中使用CNN模型进行中文文本分类.   使用CNN实现中文文本分类的基本思路: 文本预处理 将字(或token)进行汇总,形成字典文件,可保留前n个字 文字转数字, ...

最新文章

  1. Forbidden Attack:7万台web服务器陷入被攻击的险境
  2. xend: No such file or directory. Is xend running? 问题
  3. python语言需要英语非常好吗-Python用不好英语水平不够?这里有官方中文文档你看不看...
  4. 非洲的风能和太阳能真是企业家无与伦比的商机?
  5. Linux 命令之 chown -- 用来变更文件或目录的拥有者或所属群组
  6. 来了!苹果二代AirPods 3月发布 全黑配色加入
  7. 漫画:什么是红黑树?(下篇)
  8. gpg加密命令 linux_Ubuntu下加密命令GPG和KEY
  9. 单结晶体管的导电特性_【硬见小百科】二极管基础知识分类,应用,特性,原理,参数(二)...
  10. 人人都是产品经理总结 第一章
  11. Verge3D 2.12 for 3ds Max发布
  12. B站视频下载方法之--手机下载后再转移至电脑
  13. Spring控制反转(IOC)之注解配置
  14. Hadoop安装snappy(编译源码)
  15. Python量化交易平台:JQData | API使用文档(转)
  16. 学会聆听别人,聆听也是一门艺术。
  17. 二值神经网络(Binary Neural Network,BNN)
  18. 【仿真建模】第一课:AnyLogic入门基础教程 - 行人库入门讲解
  19. 韩国NF数字功放芯片在家庭影院领域中的应用
  20. 面试笔记(51信用卡-Java开发实习)

热门文章

  1. PM42L-048 步进电机
  2. CAD​自定义快捷键命令
  3. 什么是计算机科学中的“本体论”
  4. 业务流程图(TFD)实例
  5. JavaScript实现 网页倒计时
  6. php实现秒数倒计时,jQuery网页倒计时代码 显示天、小时、分钟与秒数
  7. 打印机无法访问计算机,WIN7无法访问共享打印机及文件的解决办法
  8. 什么是IT行业? IT行业都有哪些职位?
  9. 计算机等级考试数据库三级模拟题7
  10. 不同tric 改进的理由