华为宣传说mindspore比pytorch快,说是加了自动微风,确实在mindspore中训练不需要自己写优化的过程,不过空说无凭,试验了一下,真的快一些

这里拿mnist分类的例子做实验

epoch选取了10和50

mindspore:

# -*- coding: utf-8 -*-
import os
import time
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import Tensor, Model,export,load_checkpoint
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype
import numpy as np
import mindspore.dataset as ds
from mindspore.train.callback import Callback# https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/MNIST_Data.zipa
train_data_path = "./datasets/MNIST_Data/train"
test_data_path = "./datasets/MNIST_Data/test"
mnist_path = "./datasets/MNIST_Data"
model_path = "./models/ckpt/"#定义数据集
def create_dataset(data_path, batch_size=128, repeat_size=1,num_parallel_workers=1):""" create dataset for train or testArgs:data_path (str): Data pathbatch_size (int): The number of data records in each grouprepeat_size (int): The number of replicated data recordsnum_parallel_workers (int): The number of parallel workers"""# define datasetmnist_ds = ds.MnistDataset(data_path)# define some parameters needed for data enhancement and rough justificationresize_height, resize_width = 32, 32rescale = 1.0 / 255.0shift = 0.0rescale_nml = 1 / 0.3081shift_nml = -1 * 0.1307 / 0.3081# according to the parameters, generate the corresponding data enhancement methodresize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)rescale_op = CV.Rescale(rescale, shift)hwc2chw_op = CV.HWC2CHW()type_cast_op = C.TypeCast(mstype.int32)# using map to apply operations to a datasetmnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)# process the generated datasetbuffer_size = 10000mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)mnist_ds = mnist_ds.repeat(repeat_size)return mnist_dsstart=time.time()#定义网络
class  mnist(nn.Cell):def __init__(self, num_class=10):super(mnist, self).__init__()self.conv1 = nn.Conv2d(1, 8, 5, pad_mode='valid')self.conv2 = nn.Conv2d(8, 12, 5, pad_mode='valid')self.fc1 = nn.Dense(300 , 120, weight_init=Normal(0.02))self.fc2 = nn.Dense(120, 60, weight_init=Normal(0.02))self.fc3 = nn.Dense(60, num_class, weight_init=Normal(0.02))self.relu = nn.ReLU()self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)self.flatten = nn.Flatten()def construct(self, x):x = self.max_pool2d(self.relu(self.conv1(x)))x = self.max_pool2d(self.relu(self.conv2(x)))        x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x) return x    network = mnist()
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')#定义模型
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()} )repeat_size = 1
ds_train = create_dataset(os.path.join(mnist_path, "train"), 128, repeat_size)
model.train(50, ds_train, dataset_sink_mode=False)
print(time.time()-start)

pytorch:

# -*- coding: utf-8 -*-
import time
import torch.nn as nn
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim import lr_scheduler
# transformstransform = transforms.Compose([transforms.Resize(32),transforms.ToTensor(),                                transforms.Normalize((0.5,), (0.5,))])# datasets
trainset = torchvision.datasets.MNIST('data',download=True,train=True, transform=transform)
testset = torchvision.datasets.MNIST('data',download=True,train=False,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=0)start=time.time()class mnist(nn.Module):def __init__(self,num_classes=10):super(mnist, self).__init__()self.conv1=nn.Conv2d(1,8,5,padding_mode='reflect')self.conv2=nn.Conv2d(8,12,5,padding_mode='reflect')self.fc1=nn.Linear(300,120)self.fc2=nn.Linear(120,60)self.fc3=nn.Linear(60,num_classes)self.relu=nn.ReLU()self.max_pool2d=nn.MaxPool2d(2,2)self.flatten = nn.Flatten()def forward(self,x):x = self.max_pool2d(self.relu(self.conv1(x)))x = self.max_pool2d(self.relu(self.conv2(x)))        x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x) return x model = mnist().to(device='cpu')
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)for epoch in range(50):for index,data in enumerate(trainloader):inputs,labels=datainputs = inputs.to('cpu')labels = labels.to('cpu')optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()if index%100==0:print(index,loss)print(time.time()-start)
#torch.save(model, 'FashionMNIST.pth')

运行结果:

epoch mindspore pytorch
10 147.7769854068756 199.70614314079285
50 747.4645166397095 985.8706977367401

可以看到确实pytorch速度是快不少, mindspore在微分方面的优化效果显著。

mindspore比pytorch快?是的相关推荐

  1. 【深度学习】基于MindSpore和pytorch的Softmax回归及前馈神经网络

    1 实验内容简介 1.1 实验目的 (1)熟练掌握tensor相关各种操作: (2)掌握广义线性回归模型(logistic模型.sofmax模型).前馈神经网络模型的原理: (3)熟练掌握基于mind ...

  2. 清华「计图」现在支持国产芯片了!动态图推理比PyTorch快了270倍

    明敏 发自 凹非寺 量子位 报道 | 公众号 QbitAI 清华自研的深度学习框架计图(Jittor)在动态图推理速度上又一次完胜PyTorch. 最近,计图团队完成了在寒武纪芯片MLU270上的移植 ...

  3. 英伟达DALI加速技巧:让数据预处理速度比原生PyTorch快4倍

    点击我爱计算机视觉标星,更快获取CVML新技术 本文转载自机器之心. 选自towardsdatascience 作者:Pieterluitjens 机器之心编译 参与:一鸣.嘉明.思 你的数据处理影响 ...

  4. AI求解偏微分方程新基准登NeurIPS,发现JAX计算速度比PyTorch快6倍,LeCun转发:这领域确实很火...

    萧箫 发自 凹非寺 量子位 | 公众号 QbitAI 用AI求解偏微分方程,这段时间确实有点火. 但究竟什么样的AI求解效果最好,却始终没有一个统一的定论. 现在,终于有人为这个领域制作了一个名叫PD ...

  5. AI求解偏微分方程新基准登NeurIPS,发现JAX计算速度比PyTorch快6倍,LeCun转发:这领域确实很火

    萧箫 发自 凹非寺 量子位 | 公众号 QbitAI 用AI求解偏微分方程,这段时间确实有点火. 但究竟什么样的AI求解效果最好,却始终没有一个统一的定论. 现在,终于有人为这个领域制作了一个名叫PD ...

  6. 【mindspore】mindspore实现手写数字识别

    mindspore实现手写数字识别 具体流程参考教程:MindSpore快速入门 MindSpore 接口文档 注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用 ...

  7. MindSpore实现手写数字识别

    具体流程参考教程:MindSpore快速入门 MindSpore 接口文档 注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用. 数据的流水线处理 defdata ...

  8. 技术干货|昇思MindSpore NLP模型迁移之Roberta ——情感分析任务

    熟悉 BERT 模型的小伙伴对于 Roberta 模型肯定不陌生了.Roberta 模型在 BERT 模型的基础上进行了一定的改进,主要改进点有以下几个部分: 1. 训练语料:BERT只使用 16 G ...

  9. 对标TensorFlow、PyTorch,中国自主的AI框架砸向开源生态

    本文授权转载自电子发烧友网,作者黄晶晶 目前比较主流的AI深度学习框架主要由国际巨头领衔,比如谷歌的TensorFlow和Facebook的PyTorch等.2017年中国的一家初创团队悄悄成立,他们 ...

  10. jittor 和pytorch gpu 使用效率对比(惊人jittor的算力利用率是pytorch 4-5倍)

    之前使用的是cpu对比 pytorch 好像更胜一筹(本人觉得是当时可能环境不对这次配置好了完美环境使用lsgan代码进行对比果然如jittor官网所说比pytorch快,但是本人还是有一个惊奇的发现 ...

最新文章

  1. 图像检索:几类基于内容的图像分类技术
  2. 添加nginx为系统服务(service nginx start/stop/restart)
  3. 【C语言】数字在排序数组中出现的次数(改动)
  4. Python基础教程:获取list中指定元素的索引
  5. 2.7万字还原行业面貌,《2019 AI金融风控行业研究报告》正式上线!...
  6. 【C语言进阶深度学习记录】十六 静态库与动态库的创建与使用
  7. 面向对象的静态、抽象和加载
  8. The pricess diaries
  9. MapStruct使用指南
  10. 支持USB Video Class的摄像头
  11. 产品经理笔试题分析(一)
  12. 【王者荣耀】入门战斗经验
  13. vmware workstation 16 安装centos7 全记录(文字版)
  14. 计算机msvcp110.dll丢失,msvcp110.dll丢失怎样修复
  15. Problem : 救公主续
  16. 卷积神经网络超详细介绍
  17. 小红书最新的内容趋势是什么?
  18. mysql在手游中的作用_手游业务MySQL数据库虚拟化漫谈 | By 肖力
  19. mahout之lda(cvb)运用
  20. QSettings遇到神坑

热门文章

  1. docker stop all containers
  2. Halcon 3D create_pose
  3. [句型] 二十四、特殊疑问句 [ where ] [ what ] [ why ]
  4. 建议收藏|一文带你读懂 Prisma 的使用
  5. 每日哲学与编程练习3——无重复数字(Python实现)
  6. 抖音开展大规模打击刷粉、刷量,账号广告导流行动
  7. Minimum supported Gradle version is 6.1.1. Current version is 5.4.1.
  8. 安卓手机怎么运行java?如何在Android手机上运行jAVA程序?
  9. 香港马市、田忌赛马?这款游戏 IP 碉堡了
  10. linux查看网络端口状态命令行,Linux下用netstat查看网络状态、端口状态