PyTorch是Facebook开发的AI框架,其最新代码在GitHub进行更新。自2017年以来,它的使用率稳步一直保持稳定增长。相对于TensorFlow框架入门更为简单,也可以很方便的进行网络的构建以完成网络的训练,从而帮助我们很快的复现论文,是一个非常值得学习的框架。

本文主要介绍PyTorch的入门知识,从构建网络模型开始,到如何创建自定义的数据加载器,然后更新网络权重以完成模型的训练。

构建网络

PyTorch提供了一种构建自己模型的标准方法,整个定义应保留在对象中,该对象是nn.Module类的子类。在该类中,一般包含__init__和forward方法,为了更形象的解释如何使用这些基本概念,我在下面给出了一个神经网络模型构建的示例,该网络包含3个全连接层和2个RELU层。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()# Defining 3 linear layers but NOT the way they should be connected# Receives an array of length 240 and outputs one with length 120self.fc1 = nn.Linear(240, 120)# Receives an array of length 120 and outputs one with length 60self.fc2 = nn.Linear(120, 60)# Receives an array of length 60 and outputs one with length 10self.fc3 = nn.Linear(60, 10)def forward(self, x):# Defining the way that the layers of the model should be connected# Performs RELU on the output of layer 'self.fc1 = nn.Linear(240, 120)'x = F.relu(self.fc1(x))# Performs RELU on the output of layer 'self.fc2 = nn.Linear(120, 60)'x = F.relu(self.fc2(x))# Passes the array through the last linear layer 'self.fc3 = nn.Linear(60, 10)'x = self.fc3(x)return xnet = Net()

· __init__

与其他Python类一样,__init__方法用于定义类的属性和初始化卷积的一些参数。在PyTorch上下文中,始终调用super()方法来初始化父类。除此之外,还可以定义所有具有可优化参数的网络层,对于网络层的定义不需要按照在网络中使用的顺序,因为此处仅完成对网络层的定义。

· forward

表示网络的前向传播过程,即表示各层如何连接的方法,用来构建网络层的先后运算步骤。从上述示例中可以看到,在其中调用__init__内定义的网络层,然后返回代表网络输出的值。

值得注意的是,在forward方法中应用了一些其他函数,这些函数在__init__方法中未定义,但也可以称作网络层。以F.relu()函数为例,我们没有在__init__方法中定义它,是因为它没有任何可训练的参数。换句话说,如果给F.relu()函数提供相同的输入,它将始终提供相同的输出,网络的训练不会影响其行为。因此,根据经验,可以将没有任何权重更新的网络层放入forward方法中。换而言之,将所有具有权重的网络层放在__init__中。

加载数据:Dataset和DataLoader

Dataset和DataLoader是PyTorch中的两个工具,可以定义如何访问数据,便于读者使用自己的数据完成对模型的训练。下面的代码提供了一个使用简单的Dataset / DataLoader类的示例,说明如何定义自己的Dataset类,并完成数据的加载过程。

import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoaderclass ExampleDataset(Dataset):
"""Example Dataset"""def __init__(self, csv_file):
""" csv_file (string): Path to the csv file containing data."""self.data_frame = pd.read_csv(csv_file)def __len__(self):
return len(self.data_frame)def __getitem__(self, idx):
return self.data_frame[idx]# instantiates the dataset
example_dataset = ExampleDataset('my_data_file.csv')# batch size: number of samples returned per iteration
# shuffle: Flag to shuffle the data before reading so you don't read always in the same order
# num_workers: used to load the data in parallel
example_data_loader = DataLoader(example_dataset, , batch_size=4, shuffle=True, num_workers=4)# Loops over the data 4 samples at a time
for batch_index, batch in enumerate(example_data_loader):print(batch_index, batch)

上述Dataset类中使用了3种方法:

· __init__

在初始化过程中,应该输入数据目录信息和其他允许访问的信息。例如上述示例是从csv文件加载数据,也可以使用加载文件名列表,其中每个文件名代表一个数据。注意:在该过程中还未加载数据。

· __len__

该方法用于返回数据集的大小。例如,如果某些目录中有一些图像,则必须实现一种对构成该数据集文件总数进行计数的方法。上述示例中只是获得数据帧的长度。

· __getitem__

该方法用于接收一个索引idx,并返回数据集中对应的数据和标签,是数据加载的核心方法。

为了更有效地加载数据集,我们可以使用DataLoader类。该类可以并行读取一批数据,同时可以选择是否对数据进行重新排序。所有上述操作技巧都可以帮助我们更好地完成模型的训练。

·  END  ·

RECOMMEND

推荐阅读

1. 效率提升的软件大礼包

2. 那么多可选编程语言,Why Python?

3. 学习Python,你选对书了吗?

4. 90%初学者会混淆的Python概念

深度学习 — — 入门PyTorch(一)相关推荐

  1. PyTorch深度学习入门与实战(案例视频精讲)

    作者:孙玉林,余本国 著 出版社:中国水利水电出版社 品牌:智博尚书 出版时间:2020-07-01 PyTorch深度学习入门与实战(案例视频精讲)

  2. PyTorch深度学习入门

    作者:曾芃壹 出版社:人民邮电出版社 品牌:iTuring 出版时间:2019-09-01 PyTorch深度学习入门

  3. 干货|《深度学习入门之Pytorch》资料下载

    深度学习如今已经成为了科技领域中炙手可热的技术,而很多机器学习框架也成为了研究者和业界开发者的新宠,从早期的学术框架Caffe.Theano到如今的Pytorch.TensorFlow,但是当时间线来 ...

  4. (翻译)60分钟入门深度学习工具-PyTorch

    60分钟入门深度学习工具-PyTorch 作者:Soumith Chintala 原文翻译自: https://pytorch.org/tutorials/beginner/deep_learning ...

  5. 深度学习入门之PyTorch学习笔记:卷积神经网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...

  6. 深度学习入门之PyTorch学习笔记:多层全连接网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...

  7. 一篇文章入门深度学习框架PyTorch

    一篇文章入门深度学习框架PyTorch 1 Tensor(张量) 2 Variable(变量) 3 Dataset(数据集) 4 nn.Module(模组) 5 torch.optim(优化) 一阶优 ...

  8. 深度学习入门之PyTorch学习笔记:深度学习框架

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 2.1 深度学习框架介绍 2.1.1 TensorFlow 2.1.2 Caffe 2.1.3 Theano 2.1.4 ...

  9. 深度学习入门之PyTorch学习笔记:深度学习介绍

    深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...

最新文章

  1. freebsd mysql 安装_Freebsd中mysql安装及使用笔记-阿里云开发者社区
  2. oracle共享时监听,Oracle监听---共享连接参数配置介绍
  3. linux 块编辑,vim中的可视块编辑
  4. Java 平台有哪几个版本?
  5. 细思极恐丨几个有趣的科学实验
  6. 科学计算机clr,科学计算器按键功能汇总
  7. Apache会比路虎的应急效果更好
  8. windows 代理软件_MacOS好用软件推荐(一)
  9. nginx与php处理用户请求,配置 NGINX 处理 PHP 的请求《 LEMP 网站应用运行环境 》
  10. 追求神乎其技的程序设计之道(一)
  11. 1.JsDroid命令行调试命令
  12. Android常用播放器对比,谁更好用?四款Android音乐播放器对比
  13. python斐波那契数列计算_python计算斐波那契数列
  14. 920quiz+922复杂度+927quiz2
  15. c语言预处理命令12个,C语言编译预处理和预处理命令
  16. mysql主从 主机宕机_MySQL主从宕机的解决方法
  17. 亲爱的老狼-清除浮动float的5种方法
  18. 无线网怎么建立虚拟服务器,Win7创建虚拟WiFi热点共享的教程
  19. linux 开发板相关命令
  20. 计算机中cmos设置程序,计算机CMOS设置详解.doc

热门文章

  1. 川大2019计算机硕土论文盲审,关于2019年研究生学位论文答辩和学位申请工作安排的通知...
  2. Python输出田字格
  3. 档案管理html,人员档案管理.html
  4. Linux 网络tc,linux下使用tc和netem模拟复杂网络环境
  5. Java 在线订餐系统
  6. 四川大学(新生赛)羊工八刀(为解决)
  7. NLP冻手之路(1)——中文/英文字典与分词操作(Tokenizer)
  8. js 图片打碎_html5 tweenmax.js打碎玻璃图片轮播切换特效
  9. 用cocos2d 2.1制作一个过河小游戏(1): 总概
  10. 光线传输Review