作者 | Jordi TORRES.AI

deephub翻译组 | Alexander Zhao

来源 | DeepHub IMBA(ID:deephub-imba)

编程实现神经网络的最佳框架是什么?TensorFlow还是PyTorch?我的回答是:别担心,你从哪一个入门,你选择哪一个并不重要,重要的是自己动手实践!下面我们开始吧!

这两种框架都提供了编程神经网络常用的机器学习步骤:

  • 导入所需的库

  • 加载并预处理数据

  • 定义模型

  • 定义优化器和损失函数

  • 训练模型

  • 评估模型

这些步骤可以在任何一个框架中找到非常类似的实现(即使是像MindSpore这样的框架)。为此,在本文中,我们将构建一个神经网络模型,分别在PyTorch API与TensorFlow Keras API下进行手写数字分类任务的实现。

神经网络编程步骤

a)导入必要的库

在这两个框架中,我们需要首先导入一些Python库并定义一些我们将需要训练的超参数:

 import numpy as np  import matplotlib.pyplot as plt  epochs = 10  batch_size=64

对于TensorFlow,您仅需要额外导入以下库:

 import tensorflow as tf
而对于PyTorch,您还需要导入这两个库:
 import torch  import torchvision

b)导入并预处理数据

使用TensorFlow加载和准备数据可以使用以下两行代码:

 (x_trainTF_, y_trainTF_), _ = tf.keras.datasets.mnist.load_data()  x_trainTF = x_trainTF_.reshape(60000, 784).astype('float32')/255 y_trainTF = tf.keras.utils.to_categorical(y_trainTF_,             num_classes=10)

而在PyTorch则是这两行代码:

 xy_trainPT = torchvision.datasets.MNIST(root='./data', train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]))xy_trainPT_loader = torch.utils.data.DataLoader(xy_trainPT, batch_size=batch_size)

我们可以通过matplotlib.pyplot库验证这两个代码是否加载了相同的数据:


print("TensorFlow:") fig = plt.figure(figsize=(25, 4)) for idx in np.arange(20):    ax = fig.add_subplot(2, 20/2, idx+1, xticks=[], yticks=[])    ax.imshow(x_trainTF_[idx], cmap=plt.cm.binary)    ax.set_title(str(y_trainTF_[idx]))
print("PyTorch:") fig = plt.figure(figsize=(25, 4)) for idx in np.arange(20):    ax = fig.add_subplot(2, 20/2, idx+1, xticks=[], yticks=[])    ax.imshow(torch.squeeze(image, dim = 0).numpy(),              cmap=plt.cm.binary)    image, label = xy_trainPT [idx]    ax.set_title(str(label))

c)定义模型

在定义模型的时候,这两种框架都使用相当相似的语法来完成。对于TensorFlow,可以使用以下代码来完成:

 modelTF = tf.keras.Sequential([tf.keras.layers.Dense(10,activation='sigmoid',input_shape=(784,)),tf.keras.layers.Dense(10,activation='softmax') ])
在PyTorch下则这么完成:
 modelPT= torch.nn.Sequential( torch.nn.Linear(784,10),torch.nn.Sigmoid(),torch.nn.Linear(10,10), torch.nn.LogSoftmax(dim=1) )

d)定义优化器与损失函数

同样,指定优化器和loss函数的方法在两个框架下也是很相似的。在TensorFlow下,我们可以这样做:

 modelTF.compile(                loss="categorical_crossentropy",                    optimizer=tf.optimizers.SGD(lr=0.01),                metrics = ['accuracy']                )

在PyTorch下则是这样的:


criterion = torch.nn.NLLLoss() optimizer = torch.optim.SGD(modelPT.parameters(), lr=0.01)

e)训练模型

最大的不同在于训练。对于TensorFlow,我们只需要这一行代码:

_ = modelTF.fit(x_trainTF, y_trainTF, epochs=epochs,                 batch_size=batch_size, verbose = 0)

而在PyTorch下则更长,像这样:

for e in range(epochs):     for images, labels in xy_trainPT_loader:         images = images.view(images.shape[0], -1)         loss = criterion(modelPT(images), labels)         loss.backward()         optimizer.step()         optimizer.zero_grad()

PyTorch没有内置像在Keras或Scikit-learn中非常常见的fit()等训练方法,因此训练循环必须由程序员手动指定。嗯,这其实是在简单性和实用性之间进行一定的折衷,以便能够做更多自定义的事情。

f)评估模型

评估模型也是如此,在TensorFlow中,您只需对测试数据调用evaluate()方法:

 _, (x_testTF, y_testTF)= tf.keras.datasets.mnist.load_data() x_testTF = x_testTF.reshape(10000, 784).astype('float32')/255 y_testTF = tf.keras.utils.to_categorical(y_testTF, num_classes=10)  _ , test_accTF = modelTF.evaluate(x_testTF, y_testTF) print('\nAccuracy del model amb TensorFlow =', test_accTF)  TensorFlow model Accuracy = 0.8658999800682068

在PyTorch中,再次需要程序员手动定义评估循环:

 xy_testPT = torchvision.datasets.MNIST(root='./data', train=False, download=True,             transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]))  xy_test_loaderPT = torch.utils.data.DataLoader(xy_testPT)  correct_count, all_count = 0, 0 for images,labels in xy_test_loaderPT:   for i in range(len(labels)):     img = images[i].view(1, 784)      logps = modelPT(img)     ps = torch.exp(logps)     probab = list(ps.detach().numpy()[0])     pred_label = probab.index(max(probab))     true_label = labels.numpy()[i]     if(true_label == pred_label):       correct_count += 1     all_count += 1  print("\nAccuracy del model amb PyTorch =", (correct_count/all_count))  TensorFlow model Accuracy = 0.8657

更重要的是,它们在相互融合!

好了,如这个简单的示例所示,在TensorFlow和PyTorch中创建神经网络的方式并没有真正的区别,只是在一些细节方面,程序员必须实现训练和评估循环的方式,以及一些超参数,像epoch或batch_size是在不同的步骤中指定的。

实际上,在过去两年中,这两个框架一直在不断融合,相互学习并采用它们的长处。例如,在几周前发布的新版本TensorFlow 2.2中,训练步骤可以像PyTorch一样,现在程序员可以通过实现train_step()来指定循环主体的详细内容。因此,不必担心选择“错误的”框架,它们正在相互融合!最重要的是要学习背后的深度学习概念,您在其中一个框架中获得的所有知识在另一个框架下照样有用。

工业应用还是学术研究?

但是,很明显,神经网络的工业应用与学术研究是截然不同的。在这种情况下,决定选择哪一个很重要。

TensorFlow是一个非常强大且成熟的Python库,具有强大的可视化功能以及用于高性能模型开发的各种选项。它具有准备用于生产的部署共轭能,并自动支持Web和移动平台。

另一方面,PyTorch仍然是一个年轻的框架,但是拥有一个非常活跃的社区,尤其是在研究领域。门户网站The Gradient在附图中显示了主要的深度学习会议(CVPR,ICRL,ICML,NIPS,ACL,ICCV等)发表的研究论文中PyTorch的使用量,可以看到PyTorch在研究界的兴起和广泛采用。

从2018年的数据可以看出,Pythorch框架的使用还是少数,而相比之下,2019年的使用量对比TensorFlow是压倒性的。因此,如果你想创造与人工智能相关的产品,TensorFlow是一个不错的选择。如果你想做研究,我推荐PyTorch。

新手请选择Keras

如果你还是个萌新,对这一切都还很不了解,请从TensorFlow的Keras API开始。PyTorch的API具有更大的灵活性和控制力,但显然TensorFlow的Keras API可以更容易上手。而且,如果您正在阅读这篇文章,我假定您是深度学习领域的入门者。

顺便说一句,Keras计划在2020年推出几种新特性,它们都是为了“让事情变得更容易”。以下是最近添加的或即将发布的一些新功能的列表:

  • 预处理层API

到目前为止,我们已经使用NumPy和PIL(Python Imaging Library)编写的辅助工具完成了预处理。这种外部预处理使模型的可移植性降低,因为每次有人重用已经训练好的模型时,他们都必须重新实现整个预处理流程。因此,通过“预处理层”,预处理现在可以成为模型的一部分。这包括诸如文本标准化,标记化,向量化,图像标准化,数据增强等方面。也就是说,这将允许模型接受原始文本或原始图像作为输入。我个人认为这将非常有趣。

  • Keras Tuner

这是一个可让您在Keras中找到模型的最佳超参数的框架。当你开始进行一些深度学习工作时,您会发现超参数的调整将是整个工作中最为繁重的部分,这个框架旨在解决这一问题。

  • AutoKeras

该项目旨在用几行代码建立一个很好的机器学习模型,根据可能的模型空间自动搜索最佳模型,并使用Keras Tuner查找进行超参数调整。对于高级用户,AutoKeras还允许对搜索空间和过程的配置进行更高级别的控制。

  • Cloud Keras

我们的愿景是让程序员更容易地将本地代码(我们的笔记本电脑或Google Colab本地工作)移动到云端,使其能够在云端以最佳和分布式的方式执行此代码,而不必担心集群或Docker参数。

  • 与TensorFlow集成

与TFX(TensorFlow Extended,用于管理机器学习生产应用程序的平台)进行更多集成的工作正在进行中,并为将模型导出到TF Lite(用于移动和嵌入式设备的机器学习执行引擎)提供更好的支持。毫无疑问,改善对模型生产的支持对于Keras程序员的忠诚度至关重要。

小结

打个比方,你认为哪种语言是入门编程的最佳语言,C++还是Java?好吧…这取决于我们想用它做什么,最重要的是取决于我们能学到什么样的工具。我们可能无法达成一致,因为我们有一个先入为主的观点,我们很难改变对这个问题的回答(同样的情况也发生在PyTorch和TensorFlow的“粉丝”身上???? )。但我们都同意的一点是,最重要的是知道如何编程。事实上,无论我们从一种语言的编程中学到什么,当我们使用另一种语言时,它都会为我们服务,对吧?对于框架来也是如此,重要的是要了解深入学习,而不是框架的语法细节,然后我们将这些知识用于正在流行的框架或者我们想用的其他框架。

本文代码:

https://github.com/jorditorresBCN/PyTorch-vs-TensorFlow/blob/master/MNIST-with-PyTorch-and-TensorFlow.ipynb

colab google notebook:

https://colab.research.google.com/github/jorditorresBCN/PyTorch-vs-TensorFlow/blob/master/MNIST-with-PyTorch-and-TensorFlow.ipynb

新勋章,新奖品,高流量,还有更多福利等你来拿~

更多精彩推荐

360金融新任首席科学家:别指望AI Lab做成中台

☞搞懂微服务,从捕捉一头野猪说起

☞AI 图像智能修复老照片,效果惊艳到我了!| 附代码

调查了 10,975 位 Go 语言开发者,我们有了这些发现!

☞架构师前辈告诉你:代码该如何才能自己写得容易,别人看得也不痛苦

你点的每个“在看”,我都认真当成了喜欢

TensorFlow 还是 PyTorch?哪一个才更适合编写深度神经网络?相关推荐

  1. pytorch自带网络_使用PyTorch Lightning自动训练你的深度神经网络

    作者:Erfandi Maula Yusnu, Lalu 编译:ronghuaiyang 原文链接 使用PyTorch Lightning自动训练你的深度神经网络​mp.weixin.qq.com 导 ...

  2. Microsoft R 和 Open Source R,哪一个才最适合你?

    由于微信不允许外部链接,你需要点击文章尾部左下角的 "阅读原文",才能访问文中链接. R 是一个开源统计软件,在分析领域普及的非常快. 在过去几年中,无论业务规模如何,很多公司都采 ...

  3. 华硕笔记本学计算机,买平板电脑学习办公?也许平板、笔记本二合一的产品才更适合你...

    作为一个数码发烧友,给亲朋好友推荐数码产品已经是家常便饭了.这也让我注意到了一个很有趣的现象,那就是许多新生代对「电脑」已经不像我们这代人那么痴迷了.例如亲戚家一位表妹就表示想买iPad而不是笔记本电 ...

  4. Keras vs PyTorch,哪一个更适合做深度学习?

    选自Medium 作者:Karan Jakhar 机器之心编译 参与:小舟.魔王 如何选择工具对深度学习初学者是个难题.本文作者以 Keras 和 Pytorch 库为例,提供了解决该问题的思路. 当 ...

  5. 【干货】Keras vs PyTorch,哪一个更适合做深度学习?

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 如何选择工具对深度学习初学者是个难题.本文作者以 Keras 和 ...

  6. 比word更适合编写方案文档的工具Latex

    用了一小段时间的latex,现在来小结一下. Latex是啥 Latex是一种排版系统,也是一种文档语法.通过Latex编辑工具(比如TexStudio)可以将latex版的原文转换成pdf等格式的文 ...

  7. 使用PyTorch Lightning自动训练你的深度神经网络

    点击上方"AI公园",关注公众号,选择加"星标"或"置顶" 作者:Erfandi Maula Yusnu, Lalu 编译:ronghuai ...

  8. 深度学习调用TensorFlow、PyTorch等框架

    深度学习调用TensorFlow.PyTorch等框架 一.开发目标目标 提供统一接口的库,它可以从C++和Python中的多个框架中运行深度学习模型.欧米诺使研究人员能够在自己选择的框架内轻松建立模 ...

  9. 哪种ARM Cortex内核更适合我的应用:A系列、R系列、还是M系列?

    ARM Cortex内核系列提供非常广泛的具有可扩展性的性能选项,设计人员有机会在多种选项中选择最适合自身应用的内核,而非千篇一律的采用同一方案.Cortex系列组合大体上分为三种类别: ● Cort ...

最新文章

  1. vue2.x的小问题
  2. js新建list_制作一个基于node的todo-list工具,并发布至npm
  3. 获取cookie里面的值
  4. AI 的会议总结(by南大周志华)
  5. PPT幻灯片转换成word的软件
  6. 31、JAVA_WEB开发基础之servlet(2)
  7. “约见”面试官系列之常见面试题之第五十三篇之网站的资源优化(建议收藏)
  8. 技术交底软件_【干货分享】软件类产品如何进行专利挖掘与技术交底书撰写?...
  9. Asp.Net 之 枚举类型的下拉列表绑定
  10. SQL注入风险高,手写 SQL 须谨慎
  11. Oracle执行计划分析及实际生产案例解析
  12. ModifyStyle()
  13. 关系抽取论文阅读笔记
  14. centos 挂载 cos云存储服务器,centos文件目录挂载docker实现共享操作步骤
  15. STM32CUBEMX生成KEIL工程后使用AC6(V6)编译加快速度
  16. PVE 天龙八部TLBB服务端搭建(一)--linux环境搭建
  17. 计算机表示图形的几种方法。
  18. 纯JS+HTML+CSS实现表格增删改查翻页--模板文件管理
  19. 关于云开发数据库的使用经验和建议
  20. matlab怎么添加条纹噪声,基于频域的图像条纹噪声消除方法

热门文章

  1. 清北学堂模拟赛d6t2 刀塔
  2. httpClient创建对象、设置超时
  3. MPMoviePlayerController属性,方法,通知整理
  4. 【Linux开发】Ubuntu下几个软件的配置记录backup
  5. 链表(创建,插入,删除和打印输出
  6. re: Asp.net常用的51个代码(非常实用)(转)
  7. 贪心---区间覆盖问题(水题)
  8. 剑指Offer字符串加法问题
  9. java rdd hashmap_利用Spark Rdd生成Hfile直接导入到Hbase详解
  10. 安川机器人焊枪切换设定方法_【分享】焊接机器人的性能要求与系统构成