tags: Python DL

写在前面

前几天改了一份代码, 是关于深度学习中卷积神经网络的Python代码, 用于解决分类问题. 代码是用TensorFlow的Keras接口写的, 需求是转换成pytorch代码, 鉴于两者的api相近, 盖起来也不会太难, 就是一些细节需要注意, 在这里记录一下, 方便大家参考.

关于库函数导入

首先来看看在库函数的导入方面这两个流行的深度学习框架有什么区别, 这就需要简单了解一下二者的主要结构了. 为方便叙述, 下面提到的TF都是指TensorFlow2.X with Keras, Torch都是指PyTorch.

模型构建

首先来看模型的构建, 对于TF, 模型的构建可以方便地通过sequential方法得到, 这就需要引入该方法:

from tensorflow.keras.models import Sequential

在Torch中, 当然也可以通过sequential进行模型的构建, (不过官方还是更推荐采用面向对象的方式)

这里需要引入:

from torch.nn import Sequential

说到模型构建, 就不得不提在卷积神经网络里面十分常用的几个层: conv层, maxpool层和全连接层(softmax), 这些在两个框架中都有现成的, 下面来看看如何调用这些方法:

在TF中:

from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense

而在Torch中:

from torch.nn import Conv2d, MaxPool2d
from torch.nn import Flatten, Linear, CrossEntropyLoss
from torch.optim import SGD

可见二者只有些微的不同, TF中将一些激活函数的调用放在了参数里面, 而Torch都是以库函数的形式给出的.

数据读入

最后来看看数据的导入部分, 在TF中可以很方便地使用下面的方法进行数据(图片)的处理和读取:

from tensorflow.keras import backend
from tensorflow.keras.preprocessing.image import ImageDataGenerator

在Torch中, 需要类似导入:

from torchvision import transforms, datasets
from torch.utils.data import DataLoader

数据读取/处理部分的api差异

在数据读取部分, 我感觉还是Keras比较方便一些1, Torch主要还是使用的模块化的导入方式, 需要先实例化一个类, 然后用该对象进行对图像的处理.

下面先来看看TF的读取图片数据的代码:

# 导入数据
if backend.image_data_format() == 'channels_first':input_shape = (3, img_width, img_height)
else:input_shape = (img_width, img_height, 3)# 训练集图像增强
train_datagen = ImageDataGenerator(rescale=1. / 255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True)# 测试集图像增强(only rescaling)
test_datagen = ImageDataGenerator(rescale=1. / 255)train_generator = train_datagen.flow_from_directory(train_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical')  # 多分类validation_generator = test_datagen.flow_from_directory(validation_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical')  # 多分类

接下来是Torch的代码:

# 导入数据
input_shape = (img_width, img_height, 3)# 训练集图像增强
train_datagen = transforms.Compose([transforms.ToTensor(),transforms.RandomHorizontalFlip(),transforms.Resize((img_width, img_height))
])# 测试集图像增强(only rescaling)
test_datagen = transforms.Compose([  # 对读取的图片进行以下指定操作transforms.ToTensor(), # 这步相当于Keras的rescale为1/255transforms.Resize((img_width, img_height))
])train_generator = datasets.ImageFolder(train_data_dir, transform=train_datagen)validation_generator = datasets.ImageFolder(validation_data_dir,transform=test_datagen)train_loader = torch.utils.data.DataLoader(train_generator, batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(validation_generator,batch_size=batch_size,shuffle=False)

模型构建部分的api差异

下面谈谈最重要的, 模型的构建部分的api调用的区别, 在TF中直接进行model.add的调用, 就可以方便地创建一个CNN识别模型了, 这里需要注意数据流维度的对应, 下面是代码. 简练直观.

# 创建模型
model = Sequential()
model.add(Conv2D(filters=6, kernel_size=(5, 5), padding='valid', input_shape=input_shape, activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(filters=16, kernel_size=(5, 5), padding='valid', activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(120, activation='tanh'))
model.add(Dense(84, activation='tanh'))
model.add(Dense(4, activation='softmax'))#编译模型
model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])

在Torch中, 也有类似的方法, 不过不需要进行模型的编译, 代码如下:

# 创建模型
model = Sequential(Conv2d(in_channels=3, out_channels=6, kernel_size=(5, 5), padding='valid'),MaxPool2d(kernel_size=(2, 2)),Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5), padding='valid'),MaxPool2d(kernel_size=(2, 2)),Flatten(),Linear(400, 120),Linear(120, 84), Linear(84, 4)
)# 这里设置了损失函数为交叉熵函数
criterion = CrossEntropyLoss()
# 设置优化器为随机梯度下降算法
optimizer = SGD(model.parameters(), lr=0.001)

这里在api方面还是有一些区别, 例如全连接层的写法以及参数, 还有卷积层的一些区别. 一样地, 还是要非常注意数据维度.

模型训练部分的api差异

在TF中, 由于引入了Keras这个强有力而且语法简洁的api, 训练起来模型也十分简单, 代码如下:

#训练模型
history=model.fit_generator(train_generator,steps_per_epoch=nb_train_samples // batch_size,epochs=epochs,validation_data=validation_generator,validation_steps=nb_validation_samples // batch_size)

但是在Torch中, 还需要自己一步步进行搭建, 略显繁琐

n_total_steps = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 5 == 0:print(f'''Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}''')torch.save(model.state_dict(), './ckpt')

小结

善用搜索引擎, 官方文档都有这两个框架的详细api使用方法.

主要参考


  1. 图像预处理 - Keras 中文文档; ↩︎

CNN图像分类Keras代码转换pytorch思路与实现相关推荐

  1. 【深度学习】Keras vs PyTorch vs Caffe:CNN实现对比

    作者 | PRUDHVI VARMA 编译 | VK 来源 | Analytics Indiamag 在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署.这些深度 ...

  2. Keras vs PyTorch vs Caffe:CNN实现对比

    作者|PRUDHVI VARMA 编译|VK 来源|Analytics Indiamag 在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署.这些深度学习框架提供 ...

  3. Pytorch和CNN图像分类

    Pytorch和CNN图像分类 PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够实现强大的GPU加速, ...

  4. 2_初学者快速掌握主流深度学习框架Tensorflow、Keras、Pytorch学习代码(20181211)

    初学者快速掌握主流深度学习框架Tensorflow.Keras.Pytorch学习代码 一.TensorFlow 1.资源地址: 2.资源介绍: 3.配置环境: 4.资源目录: 二.Keras 1.资 ...

  5. Keras vs PyTorch,哪一个更适合做深度学习?

    选自Medium 作者:Karan Jakhar 机器之心编译 参与:小舟.魔王 如何选择工具对深度学习初学者是个难题.本文作者以 Keras 和 Pytorch 库为例,提供了解决该问题的思路. 当 ...

  6. 【图像分类】 基于Pytorch的细粒度图像分类实战

    欢迎大家来到<图像分类>专栏,今天讲述基于pytorch的细粒度图像分类实战! 作者&编辑 | 郭冰洋 1 简介 针对传统的多类别图像分类任务,经典的CNN网络已经取得了非常优异的 ...

  7. 【干货】Keras vs PyTorch,哪一个更适合做深度学习?

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 如何选择工具对深度学习初学者是个难题.本文作者以 Keras 和 ...

  8. 初学者指南:使用 Numpy、Keras 和 PyTorch 实现最简单的机器学习模型线性回归

    来源:DeepHub IMBA 本文约5100字,建议阅读10分钟 本文将使用 Python 中最著名的三个模块来实现一个简单的线性回归模型. 机器学习是人工智能的一门子科学,其中计算机和机器通常学会 ...

  9. Keras vs PyTorch:谁是第一深度学习框架?

    「第一个深度学习框架该怎么选」对于初学者而言一直是个头疼的问题.本文中,来自 deepsense.ai 的研究员给出了他们在高级框架上的答案.在 Keras 与 PyTorch 的对比中,作者还给出了 ...

最新文章

  1. 数据访问层之数据库访问设计(转)
  2. Access violation at address 0x77f96c94
  3. java设计模式之装饰器模式
  4. 自定义图片字段调用的问题出现{dede:img ..}
  5. 关于Python课程
  6. ALV OO的栏位属性
  7. 【MOSS】Sharepoint大附件上传
  8. 视频 | 在小程序竞争激烈的今天,淘票票如何脱颖而出?
  9. Kafka分区分配策略(1)——RangeAssignor
  10. 小说中场景的功能_《流浪地球》:从小说到电影
  11. c++11或c++14或c++17参数包的使用
  12. 在设计四人抢答器中灯全亮_数字电子技术课程设计报告(四人抢答器).doc
  13. Android中的PopUpWindow
  14. 使用kubeadm搭建的k8s集群修改node节点主机名
  15. [19保研]四川大学网络空间安全学院 关于举办2018年优秀大学生暑期夏令营的通知...
  16. Android实例精讲——通过ListView构造微信聊天界面视图
  17. pybind11学习 | 面向对象编程
  18. 零基础可以学python么
  19. 九大数据分析方法:结构分析法
  20. 蓝桥杯 青蛙跳杯子【第八届】【省赛】【C组】 BFS 广搜

热门文章

  1. 支付宝APP支付(基于Java实现支付宝APP支付)
  2. 【转载】前后端分离的思考与实践(五)
  3. Java 中 modifer #39;public#39; is reduntant for interface methods
  4. UVa 1620 懒惰的苏珊(逆序数)
  5. 管理和维护RHCS集群
  6. 我TM快疯了,在博客园开博短短2个月,经历博客园数次故障。。。
  7. C++ STL set集合的使用
  8. PAT乙级(1002 写出这个数 )
  9. 今晚直播预告丨Oracle 19c避雷经验分享
  10. openGauss北京Meetup成功举办,“产学研用”合力共建主流根社区(附:视频回放PPT)...