作者 | Manpreet Singh Minhas

编译 | VK
来源 | Towards Data Science

深度学习/机器学习工作流程通常不同于人们对正常软件开发过程的期望。但这并不意味着人们不应该从这些年来不断发展的软件开发中汲取灵感并进行实践。

在本文中,我将讨论单元测试以及为什么以及如何在代码中包含这些测试。我们将首先简要介绍单元测试,然后是一个深度学习中的单元测试示例,以及如何通过命令行和VS代码测试资源管理器运行这些测试。

介绍

单元测试是软件开发人员熟悉的概念。这是一种非常有用的技术,可以帮助你防止代码中出现明显的错误和bug。它包括测试源代码的各个单元,如函数、方法和类,以确定它们是否满足要求并具有预期的行为。

单元测试通常很小,执行起来不需要太多时间。测试的输入范围很广,通常包括边界和边缘情况。这些输入的输出通常由开发人员手动计算,以测试被测试单元的输出。

例如,对于加法器函数,我们将有如下测试用例。(稍后我们将看到一个深度学习的示例。)

你可以用正输入、零输入、负输入、正输入和负输入测试用例。

如果我们正在测试的函数/方法的输出与单元测试中为所有输入案例定义的输出相等,那么你的单元将通过测试,否则它将失败。你将确切地知道哪个测试用例失败。可以进一步调查,找出问题所在。

如果有多个开发人员正在处理一个大型项目。假设有人基于某些假设和数据大小编写了一段代码,而新的开发人员更改了代码库中不再满足这些假设的内容。那么代码肯定会失败。单元测试允许避免这种情况。

下面是单元测试的一些好处。

  • 强制你编写具有明确定义的输入和输出的模块化和可重用代码。因此,你的代码将更易于集成。

  • 提高了更改/维护代码的信心。它有助于识别代码更改引入的bug。

  • 提高了对单元本身的信心,因为如果它通过了单元测试,我们可以确定逻辑没有明显的错误,并且单元按预期运行。

  • 调试变得更容易,因为你可以知道哪个单元失败了,以及哪些特定的测试用例失败了。

Python中的单元测试

每种语言都有自己的工具和包可用于进行单元测试。Python还提供了一些单元测试框架。unittest包是标准Python库的一部分。

我将讨论如何通过命令行/bash和VS Code UI界面来使用这个框架。它的灵感来自JUnit,与其他语言中的主要单元测试框架有相似的风格。它支持测试自动化、共享测试的设置和关闭代码、将测试聚合到集合中以及独立于测试的报告框架[4]。

在这个框架中,单元测试的基本构建块是测试用例——必须设置并检查其正确性的场景。在unittest中,测试用例是unittest.TestCase。要生成测试用例,必须编写TestCase的子类。

TestCase实例的测试用例应该是自包含的,这样它可以单独运行,也可以与任何数量的其他测试用例任意组合运行。TestCase子类的测试方法应该在名称中有test前缀,并执行特定的测试代码。

为了执行测试,TestCase基类有几个assert方法,允许你对照被测试单元的输出检查测试用例的输出。如果测试失败,将引发异常并给出解释性消息,unittest将测试用例标识为失败。任何其他异常都将被视为错误。

有两种类型的setup方法可用于为测试设置类。

  1. setUp -这将在类中的每个测试方法之前调用。

  2. setUpClass-整个类只运行一次。这是你应该用来做深度学习测试的方法。在此方法中加载模型,以避免在执行每个测试方法之前重新加载模型。这将节省模型重新加载时间。

请注意,各种测试的运行顺序是通过根据字符串的内置顺序对测试方法名称进行排序来确定的。

现在让我们看看我为一个项目的PyTorch数据加载器而创建的单元测试。代码如下所示。

import unittest
from pathlib import Pathimport torch
from PIL import Image
from segdataset import SegmentationDataset
from torch.utils.data import DataLoader
from torchvision import transformsclass Test_TestSegmentationDataset(unittest.TestCase):@classmethoddef setUpClass(cls) -> None:seg_dataset = SegmentationDataset("CrackForest","Images","Masks",transforms=transforms.Compose([transforms.ToTensor()]))seg_dataloader = DataLoader(seg_dataset,batch_size=4,shuffle=False,num_workers=8)cls.samples = next(iter(seg_dataloader))def test_image_tensor_dimensions(self):image_tensor_shape = Test_TestSegmentationDataset.samples['image'].shapeself.assertEqual(image_tensor_shape[0], 4)self.assertEqual(image_tensor_shape[1], 3)self.assertEqual(image_tensor_shape[2], 320)self.assertEqual(image_tensor_shape[3], 480)def test_mask_tensor_dimensions(self):mask_tensor_shape = Test_TestSegmentationDataset.samples['mask'].shapeself.assertEqual(mask_tensor_shape[0], 4)self.assertEqual(mask_tensor_shape[1], 1)self.assertEqual(mask_tensor_shape[2], 320)self.assertEqual(mask_tensor_shape[3], 480)def test_mask_img_pair(self):ref_image_tensor = transforms.ToTensor()(Image.open(Path("CrackForest/Images/001.jpg")))ref_mask_tensor = transforms.ToTensor()(Image.open(Path("CrackForest/Masks/001_label.PNG")))datagen_image_tensor = Test_TestSegmentationDataset.samples['image'][0]datagen_mask_tensor = Test_TestSegmentationDataset.samples['mask'][0]self.assertTrue(torch.equal(ref_image_tensor, datagen_image_tensor))self.assertTrue(torch.equal(ref_mask_tensor, datagen_mask_tensor))
© 2021 GitHub, Inc.

被测试的分割数据集需要批量加载相应的图像和mask对。将正确的图像映射到正确的mask是至关重要的。

为此,通常,图像和mask的名称中都有相同的数字。如果你正在通过一些增强来调整图像的大小,那么你的结果大小应该与预期的一样。对于PyTorch,数据加载器返回的张量应该是BxCxHxW形式,其中B是批大小,C是通道数,H是高度,W是宽度。

现在,我来解释代码中发生了什么。我创建了一个从unittest.TestCase测试用例基类。如前所述,我创建了一个setUpClass方法,它是一个类方法,用于确保初始化只执行一次。

class Test_TestSegmentationDataset(unittest.TestCase):@classmethoddef setUpClass(cls) -> None:seg_dataset = SegmentationDataset("CrackForest","Images","Masks",transforms=transforms.Compose([transforms.ToTensor()]))seg_dataloader = DataLoader(seg_dataset,batch_size=4,shuffle=False,num_workers=8)cls.samples = next(iter(seg_dataloader))

这里需要注意的一点是,为了测试,我在dataloader中禁用了shuffle。因为我希望名称中带有001的映像和mask出现在dataloader创建的第一批的索引0中。

从不同的批次中检查不同的样本索引将是一个更好的测试,因为你将确保不同批次的顺序是一致的。我把第一批储存在cls作为类属性。

现在初始化完成了,我们来看看各个测试。

在第一个测试中,我检查dataloader返回的图像张量维度。因为我没有调整大小的图像,我希望大小为320x480和这些图像正在读取为RGB,所以应该有3个通道。在setUpClass方法中,我将批大小指定为4,因此张量的第一个维度应该是4。如果尺寸有问题,这个测试就会失败。

    def test_image_tensor_dimensions(self):image_tensor_shape = Test_TestSegmentationDataset.samples['image'].shapeself.assertEqual(image_tensor_shape[0], 4)self.assertEqual(image_tensor_shape[1], 3)self.assertEqual(image_tensor_shape[2], 320)self.assertEqual(image_tensor_shape[3], 480)

下一个测试是完全相同的,除了它是为mask张量。在这个特定的数据集中,mask只有一个通道。所以我希望通道数是1。批量大小应为4。mask形状应为320x480。

    def test_mask_tensor_dimensions(self):mask_tensor_shape = Test_TestSegmentationDataset.samples['mask'].shapeself.assertEqual(mask_tensor_shape[0], 4)self.assertEqual(mask_tensor_shape[1], 1)self.assertEqual(mask_tensor_shape[2], 320)self.assertEqual(mask_tensor_shape[3], 480)

最后一个测试检查两件事。首先是通过手动应用dataloader中指定的变换获得的张量是否产生与dataloader相同的结果。其次是图像和mask对是正确的。

要直接应用torchvision变换,需要实例化transform并将图像作为输入传递给该实例。如果transform需要一个PIL图像或numpy数组(对于ToTensor就是这种情况),任何其他格式都会导致错误。

    def test_mask_img_pair(self):ref_image_tensor = transforms.ToTensor()(Image.open(Path("CrackForest/Images/001.jpg")))ref_mask_tensor = transforms.ToTensor()(Image.open(Path("CrackForest/Masks/001_label.PNG")))datagen_image_tensor = Test_TestSegmentationDataset.samples['image'][0]datagen_mask_tensor = Test_TestSegmentationDataset.samples['mask'][0]self.assertTrue(torch.equal(ref_image_tensor, datagen_image_tensor))self.assertTrue(torch.equal(ref_mask_tensor, datagen_mask_tensor))

现在我们已经准备好了unittest,让我们先看看如何通过命令行运行这个测试。

可以使用以下命令:

python -m unittest discover -s Tests -p "test_*"

一旦指定了搜索目录和搜索模式,Unittest就可以发现测试。

-s或--start directory directory:它指定开始发现目录。在我们的例子中,由于测试位于tests文件夹中,所以我们将该文件夹指定为该标志的值。-p或--pattern:它指定匹配模式。我指定了一个自定义模式,只是为了向你展示这个功能是可用的。因为默认模式是test*.py,所以它在默认情况下适用于我们的测试脚本。-v或--verbose:如果你指定这个值,你将获得测试类中每个测试方法的输出。

非详细输出和详细输出如下所示。如果所有的测试方法都通过了,那么最后会收到一条OK消息。

但是,如果任何一个测试方法失败,你将得到一条失败消息,其中指定了失败的测试。你会知道哪个断言失败了。如前所述,这对调试和查找破坏代码的原因非常有帮助。在本例中,我更改了正在读取的图像,但没有更改正在比较的张量,这导致了错误。

你可以将此测试执行行包含在任何自动批处理或bash文件中,这些文件可用于自动部署。例如,我们在GitHub操作中使用类似的测试,在更新版本自动推送到包存储库之前自动验证代码是否工作。

接下来,我将向你展示如何使用VS代码测试资源管理器通过UI运行这些测试。

在VS Code[3]中运行Python单元测试

在VS代码中,Python中的测试在默认情况下是禁用的。

要启用测试,请在命令Pallete上使用Python:configuretests命令。此命令提示你选择测试框架、包含测试的文件夹以及用于标识测试文件的模式。

最后两个输入与我们用于通过命令行运行单元测试的输入完全相同。Unittest框架不需要进一步安装。但是,如果你选择的框架包没有安装在你的环境中,VS代码会提示你安装它。

一旦发现被正确设置,我们将在VS代码活动栏中看到带有图标的测试资源管理器。测试资源管理器帮助你可视化、导航和运行测试。

你还可以在测试脚本中看到直接可用的运行测试和调试测试选项。你可以从该视图运行所有或单个测试,还可以导航到不同类中的单个测试方法。

如果测试失败,我会出现一个红色的十字而不是绿色的勾号。如果你想节省时间,你可以选择只运行失败的测试,而不是再次运行所有测试。

结论

本文结束了关于深度学习单元测试的文章。我们简要地了解了什么是单元测试以及它们的好处。

接下来,我们介绍了一个使用unittest包框架用PyTorch编写的数据加载器单元的实际示例。我们学习了如何通过命令行和Python测试资源管理器从VS代码运行这些测试。

我希望你开始为代码编写单元测试并从中获益!谢谢你阅读这篇文章。代码位于:https://github.com/msminhas93/deeplabv3finetunning

参考引用

[1]https://softwaretestingfundamentals.com/unit-testing/

[2]https://www.tutorialspoint.com/unittest_framework/unittest_framework_overview.htm

[3]https://code.visualstudio.com/docs/python/testing

[4]https://docs.python.org/3/library/unittest.html

[5]https://stackoverflow.com/questions/23667610/what-is-the-difference-between-setup-and-setupclass-in-python-unittest/23670844

往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑温州大学《机器学习课程》视频
本站qq群851320808,加入微信群请扫码:

【深度学习】深度学习中的单元测试相关推荐

  1. 深度学习在计算机视觉中的应用长篇综述

    深度学习在计算机视觉中的应用长篇综述 前言 2012年ImageNet比赛,使深度学习在计算机视觉领域在全世界名声大震,由此人工智能的全球大爆发.第一个研究CNN的专家使Yann LeCun,现就职于 ...

  2. 微信高级研究员解析深度学习在NLP中的发展和应用 | 公开课笔记

    作者 | 张金超(微信模式识别中心的高级研究员) 整理 | Just 出品 | 人工智能头条(公众号ID:AI_Thinker) 近年来,深度学习方法极大的推动了自然语言处理领域的发展.几乎在所有的 ...

  3. 今晚8点开播 | 微信高级研究员解析深度学习在NLP中的发展和应用

    近年来,深度学习方法极大的推动了自然语言处理领域的发展.几乎在所有的 NLP 任务上我们都能看到深度学习技术的应用,并且在很多的任务上,深度学习方法的表现大大的超过了传统方法.可以说,深度学习方法给 ...

  4. 公开课 | 微信高级研究员解析深度学习在NLP中的发展和应用

    近年来,深度学习方法极大的推动了自然语言处理领域的发展.几乎在所有的 NLP 任务上我们都能看到深度学习技术的应用,并且在很多的任务上,深度学习方法的表现大大的超过了传统方法.可以说,深度学习方法给 ...

  5. 读“深度学习在图像处理领域中的应用综述”有感

    摘 要 随着大数据时代的到来,一系列深度学习网络结构已在图像处理领域展现出巨大的优势,为了能够及时跟踪深度学习在图像领域的最新发展,本文章针对深度学习在图像处理领域的相关研究进行综述. 关键词: 深度 ...

  6. Nat. Mach. Intell. | 集成深度学习在生物信息学中的发展与展望

    本期给大家介绍悉尼大学Jean Yang教授课题组发表在Nature machine intelligence的文章"Ensemble deep learning in bioinforma ...

  7. [王晓刚]深度学习在图像识别中的研究进展与展望(转发)

    [王晓刚]深度学习在图像识别中的研究进展与展望(转发) (2015-06-04 08:27:56) 转载▼     深度学习是近十年来人工智能领域取得的最重要的突破之一.它在语音识别.自然语言处理.计 ...

  8. Lesson 12.1 深度学习建模实验中数据集生成函数的创建与使用

    Lesson 12.1 深度学习建模实验中数据集生成函数的创建与使用   为了方便后续练习的展开,我们尝试自己创建一个数据生成器,用于自主生成一些符合某些条件.具备某些特性的数据集.相比于传统的机器学 ...

  9. 微信研究员解析深度学习在NLP中的发展和应用

    微信研究员解析深度学习在NLP中的发展和应用 深度学习在自然语言(NLP)中的发展和应用视频教程,深度学习方法的表现大大的超过了传统方法.可以说,深度学习方法给NLP带来了一场重要的变革.在本课程中, ...

最新文章

  1. 重构-改善既有代码的设计:重新组织数据的16种方法(六)
  2. 说了这么多次 I/O,可你知道其中的原理么
  3. iOS判断UIWebView加载完成的方法
  4. C++模板类注意事项
  5. PyTorch 《动手学深度学习》学习笔记(Dive-into-DL-Pytorch)
  6. 浅析ios开发中Block块语法的妙用
  7. 操作系统笔记(六)调度
  8. python允许无止境的循环_ParisGabriel:Python无止境 day03
  9. Android8.0 学习(15)---适配Android 8.0
  10. 营销系统优惠券模板设计
  11. csrss.exe和winlogon.exe引起cpu居高不下的解决办法
  12. Javaweb常见面试题
  13. 网红品牌终将祛魅,而伊利、康师傅这些老司机们却仍然历久弥新
  14. 计算机如何安装cpu风扇,从零开始学装机 教你如何安装CPU风扇
  15. 用python制作相册影集_影集制作APP哪个好?就用这些APP把照片做成相册!
  16. 光端机连接示意图详细连接方式图解
  17. 人工智能:图像数字化相关的知识介绍
  18. 微信按住说话HTML实现
  19. 机动战士敢达ol服务器链接中断,机动战士敢达OL延迟掉线画面卡解决办法
  20. 商标45类分类表明细表_45类(2017)-商标类别明细

热门文章

  1. SqlServerException:拒绝对表对象的select,insert权限解决(新建账号导致的问题)
  2. (转)Javascript 面向对象编程(一):封装
  3. WebBrowser1.Navigate重复载入同一页面时载入的是旧页面
  4. 纯ActionScript3.0打造的工作流程编辑器(WorkFlowEdit V1.0Bata1.0)
  5. [Git] 001 初识 Git 与 GitHub 之新建仓库
  6. python flask 学习与实战
  7. C#与C++的几个不同之处知识点
  8. mysql登录报错 ERROR 1045 (28000)
  9. Dev-GridView-对于gridview的列值的合计
  10. BZOJ3163 [Heoi2013]Eden的新背包问题