mindspore比pytorch快?是的
华为宣传说mindspore比pytorch快,说是加了自动微风,确实在mindspore中训练不需要自己写优化的过程,不过空说无凭,试验了一下,真的快一些
这里拿mnist分类的例子做实验
epoch选取了10和50
mindspore:
# -*- coding: utf-8 -*-
import os
import time
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import Tensor, Model,export,load_checkpoint
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype
import numpy as np
import mindspore.dataset as ds
from mindspore.train.callback import Callback# https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/MNIST_Data.zipa
train_data_path = "./datasets/MNIST_Data/train"
test_data_path = "./datasets/MNIST_Data/test"
mnist_path = "./datasets/MNIST_Data"
model_path = "./models/ckpt/"#定义数据集
def create_dataset(data_path, batch_size=128, repeat_size=1,num_parallel_workers=1):""" create dataset for train or testArgs:data_path (str): Data pathbatch_size (int): The number of data records in each grouprepeat_size (int): The number of replicated data recordsnum_parallel_workers (int): The number of parallel workers"""# define datasetmnist_ds = ds.MnistDataset(data_path)# define some parameters needed for data enhancement and rough justificationresize_height, resize_width = 32, 32rescale = 1.0 / 255.0shift = 0.0rescale_nml = 1 / 0.3081shift_nml = -1 * 0.1307 / 0.3081# according to the parameters, generate the corresponding data enhancement methodresize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)rescale_op = CV.Rescale(rescale, shift)hwc2chw_op = CV.HWC2CHW()type_cast_op = C.TypeCast(mstype.int32)# using map to apply operations to a datasetmnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)# process the generated datasetbuffer_size = 10000mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)mnist_ds = mnist_ds.repeat(repeat_size)return mnist_dsstart=time.time()#定义网络
class mnist(nn.Cell):def __init__(self, num_class=10):super(mnist, self).__init__()self.conv1 = nn.Conv2d(1, 8, 5, pad_mode='valid')self.conv2 = nn.Conv2d(8, 12, 5, pad_mode='valid')self.fc1 = nn.Dense(300 , 120, weight_init=Normal(0.02))self.fc2 = nn.Dense(120, 60, weight_init=Normal(0.02))self.fc3 = nn.Dense(60, num_class, weight_init=Normal(0.02))self.relu = nn.ReLU()self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)self.flatten = nn.Flatten()def construct(self, x):x = self.max_pool2d(self.relu(self.conv1(x)))x = self.max_pool2d(self.relu(self.conv2(x))) x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x) return x network = mnist()
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')#定义模型
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()} )repeat_size = 1
ds_train = create_dataset(os.path.join(mnist_path, "train"), 128, repeat_size)
model.train(50, ds_train, dataset_sink_mode=False)
print(time.time()-start)
pytorch:
# -*- coding: utf-8 -*-
import time
import torch.nn as nn
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim import lr_scheduler
# transformstransform = transforms.Compose([transforms.Resize(32),transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# datasets
trainset = torchvision.datasets.MNIST('data',download=True,train=True, transform=transform)
testset = torchvision.datasets.MNIST('data',download=True,train=False,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=0)start=time.time()class mnist(nn.Module):def __init__(self,num_classes=10):super(mnist, self).__init__()self.conv1=nn.Conv2d(1,8,5,padding_mode='reflect')self.conv2=nn.Conv2d(8,12,5,padding_mode='reflect')self.fc1=nn.Linear(300,120)self.fc2=nn.Linear(120,60)self.fc3=nn.Linear(60,num_classes)self.relu=nn.ReLU()self.max_pool2d=nn.MaxPool2d(2,2)self.flatten = nn.Flatten()def forward(self,x):x = self.max_pool2d(self.relu(self.conv1(x)))x = self.max_pool2d(self.relu(self.conv2(x))) x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x) return x model = mnist().to(device='cpu')
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)for epoch in range(50):for index,data in enumerate(trainloader):inputs,labels=datainputs = inputs.to('cpu')labels = labels.to('cpu')optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()if index%100==0:print(index,loss)print(time.time()-start)
#torch.save(model, 'FashionMNIST.pth')
运行结果:
epoch | mindspore | pytorch |
10 | 147.7769854068756 | 199.70614314079285 |
50 | 747.4645166397095 | 985.8706977367401 |
可以看到确实pytorch速度是快不少, mindspore在微分方面的优化效果显著。
mindspore比pytorch快?是的相关推荐
- 【深度学习】基于MindSpore和pytorch的Softmax回归及前馈神经网络
1 实验内容简介 1.1 实验目的 (1)熟练掌握tensor相关各种操作: (2)掌握广义线性回归模型(logistic模型.sofmax模型).前馈神经网络模型的原理: (3)熟练掌握基于mind ...
- 清华「计图」现在支持国产芯片了!动态图推理比PyTorch快了270倍
明敏 发自 凹非寺 量子位 报道 | 公众号 QbitAI 清华自研的深度学习框架计图(Jittor)在动态图推理速度上又一次完胜PyTorch. 最近,计图团队完成了在寒武纪芯片MLU270上的移植 ...
- 英伟达DALI加速技巧:让数据预处理速度比原生PyTorch快4倍
点击我爱计算机视觉标星,更快获取CVML新技术 本文转载自机器之心. 选自towardsdatascience 作者:Pieterluitjens 机器之心编译 参与:一鸣.嘉明.思 你的数据处理影响 ...
- AI求解偏微分方程新基准登NeurIPS,发现JAX计算速度比PyTorch快6倍,LeCun转发:这领域确实很火...
萧箫 发自 凹非寺 量子位 | 公众号 QbitAI 用AI求解偏微分方程,这段时间确实有点火. 但究竟什么样的AI求解效果最好,却始终没有一个统一的定论. 现在,终于有人为这个领域制作了一个名叫PD ...
- AI求解偏微分方程新基准登NeurIPS,发现JAX计算速度比PyTorch快6倍,LeCun转发:这领域确实很火
萧箫 发自 凹非寺 量子位 | 公众号 QbitAI 用AI求解偏微分方程,这段时间确实有点火. 但究竟什么样的AI求解效果最好,却始终没有一个统一的定论. 现在,终于有人为这个领域制作了一个名叫PD ...
- 【mindspore】mindspore实现手写数字识别
mindspore实现手写数字识别 具体流程参考教程:MindSpore快速入门 MindSpore 接口文档 注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用 ...
- MindSpore实现手写数字识别
具体流程参考教程:MindSpore快速入门 MindSpore 接口文档 注:本文章记录的是我在开发过程中的学习笔记,仅供参考学习,欢迎讨论,但不作为开发教程使用. 数据的流水线处理 defdata ...
- 技术干货|昇思MindSpore NLP模型迁移之Roberta ——情感分析任务
熟悉 BERT 模型的小伙伴对于 Roberta 模型肯定不陌生了.Roberta 模型在 BERT 模型的基础上进行了一定的改进,主要改进点有以下几个部分: 1. 训练语料:BERT只使用 16 G ...
- 对标TensorFlow、PyTorch,中国自主的AI框架砸向开源生态
本文授权转载自电子发烧友网,作者黄晶晶 目前比较主流的AI深度学习框架主要由国际巨头领衔,比如谷歌的TensorFlow和Facebook的PyTorch等.2017年中国的一家初创团队悄悄成立,他们 ...
- jittor 和pytorch gpu 使用效率对比(惊人jittor的算力利用率是pytorch 4-5倍)
之前使用的是cpu对比 pytorch 好像更胜一筹(本人觉得是当时可能环境不对这次配置好了完美环境使用lsgan代码进行对比果然如jittor官网所说比pytorch快,但是本人还是有一个惊奇的发现 ...
最新文章
- 图像检索:几类基于内容的图像分类技术
- 添加nginx为系统服务(service nginx start/stop/restart)
- 【C语言】数字在排序数组中出现的次数(改动)
- Python基础教程:获取list中指定元素的索引
- 2.7万字还原行业面貌,《2019 AI金融风控行业研究报告》正式上线!...
- 【C语言进阶深度学习记录】十六 静态库与动态库的创建与使用
- 面向对象的静态、抽象和加载
- The pricess diaries
- MapStruct使用指南
- 支持USB Video Class的摄像头
- 产品经理笔试题分析(一)
- 【王者荣耀】入门战斗经验
- vmware workstation 16 安装centos7 全记录(文字版)
- 计算机msvcp110.dll丢失,msvcp110.dll丢失怎样修复
- Problem : 救公主续
- 卷积神经网络超详细介绍
- 小红书最新的内容趋势是什么?
- mysql在手游中的作用_手游业务MySQL数据库虚拟化漫谈 | By 肖力
- mahout之lda(cvb)运用
- QSettings遇到神坑
热门文章
- docker stop all containers
- Halcon 3D create_pose
- [句型] 二十四、特殊疑问句 [ where ] [ what ] [ why ]
- 建议收藏|一文带你读懂 Prisma 的使用
- 每日哲学与编程练习3——无重复数字(Python实现)
- 抖音开展大规模打击刷粉、刷量,账号广告导流行动
- Minimum supported Gradle version is 6.1.1. Current version is 5.4.1.
- 安卓手机怎么运行java?如何在Android手机上运行jAVA程序?
- 香港马市、田忌赛马?这款游戏 IP 碉堡了
- linux查看网络端口状态命令行,Linux下用netstat查看网络状态、端口状态