人民币二分类模型

数据-模型-损失函数-优化器-迭代训练

  • 数据收集 img label
  • 数据划分 train valid test
  • 数据读取 Dataloader [sampler-生成索引 dataset-img,label]
  • 数据预处理 transforms

DataLoader

import torch

torch.utils.data.DataLoader()

  • 功能:构建可以可迭代的数据装载器
  • 参数:

    dataset Dataset类,决定数据从哪里读取和如何读取

    batchsize 批大小

    num_works 是否多进程读取数据

    shuffle每个epoch是否乱序

    drop_last 当样本数不能被batchsize整除时,是否舍弃最后一批数据

torch.utils.data.Dataloader(

dataset,

batch_size=1,

shuffle=False,

sampler=None,

batch_sampler=None,

num_workers=0,

drop_last=False

)

Epoch:训练样本都输入到模型中,称为一个Epoch

Iteration:一批样本输入到模型中,称之为一个Iteration

Batchsize:批大小,决定一个Epoch有多少个Iteration

例子:

样本总80,batchsize 8 ,1Epoch = 10 Iteration

1 Epoch = 10 Iteration ? drop_last = True

1 Epoch = 11 Iteration ? drop_last = False

# 功能Dataset抽象类,所有自定义的Dataset需要基础它,并且复写
# __getitem__()
# getitem:接收一个索引,返回一个样本class Dataset(object):def __getitem__(self, index):raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self,other])
import os
dir_test = os.path.join('..','..','data')
print(dir_test)
..\..\data

Transforms

常见的处理方法有:

  • 数据中心化
  • 数据标准化
  • 缩放
  • 剪裁
  • 旋转
  • 翻转
  • 填充
  • 噪声添加
  • 灰度变换
  • 线性变换
  • 仿射变换
  • 亮度、饱和度和对比度变换

transforms.Normalize(mean,std,inplace=False)

数据标准化,能加速模型收敛

数据增强方法

import os
import numpy as np
import torch
import random
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot  as plt

1.tranforms–CenterCrop

transforms.CenterCrop()中心裁剪

参数 size:所需要的尺寸

2.tranforms–RandomCrop

transforms.RandomCrop()随机裁剪

参数

size:所需要的尺寸

padding:填充大小

pad_if_need:如图像小于设定的size,则填充

padding_mode:填充模式,constant像素值fill设定,edge像素值由边缘图像决定,reflect镜像填充,最有个像素不镜像,symmertric镜像填充,最有一个像素镜像

fill:constant设置填充的像素值,如图像小于设定的size,则填充

3.tranforms–RandomResizedCrop

transforms.RandomResizedCrop()中心裁剪

参数 size:所需要的尺寸

参数 scale:随机裁剪面积比例(0.08,1)

参数 ratio:随机长宽比(3/4,4/3)

参数 interpolation:插值方法
* PIL.Image.NEAREST
* PIL.Image.BILINEAR
* PIL.Image.BICUBIC

4.tranforms–FiveCrop

transforms.FiveCrop()在图像的上下左右以及中心裁剪出尺寸为size的10张图片

参数 size:所需要的尺寸

5.tranforms–TenCrop

transforms.TenCrop()在图像的上下左右以及中心裁剪出尺寸为size的10张图片

参数 size:所需要的尺寸

参数 vertical_flip:是否翻转

6.tranforms–RandomHorizontalFlip

transforms.RandomHorizontalFlip()依概率水平翻转【左右】

参数 p:翻转概率

7.tranforms–RandomVerticalFlip

transforms.RandomVerticalFlip()依概率水平垂直【上下】

参数 p:翻转概率

8.tranforms–RandomRotation

transforms.RandomRotation()依概率旋转

参数 degresss:旋转角度

参数 resample:重采样

参数 expand:扩大图片,保持原图信息

参数 center:旋转中心

9.tranforms–Pad

transforms.Pad()对图片边缘填充

参数 padding:设置填充大小

参数 padding_mode:填充模式,分别是constant、edge、reflect、symmetric

参数 fill:为constant时填充像素值

10.tranforms–ColorJitter

transforms.ColorJitter()调节亮度、对比度、饱和度和色相

参数 brightness:调节亮度因子

参数 constrast:调节对比度参数

参数 saturation:调节饱和度

参数 hue:调节色相参数

11.tranforms–Grayscale

transforms.Grayscale()依概率图片转换为灰度

参数 num_output_channels:输出通道数智能设1或3

参数 p:转化为灰度的概率

11.tranforms–RandomGrayscale

transforms.RandomGrayscale()依概率图片转换为灰度

参数 num_output_channels:输出通道数智能设1或3

参数 p:转化为灰度的概率

12.tranforms–RandomAffine

transforms.RandomAffine()对图像进行仿射变换,仿射变换是二维的线性变换,有五种基本原子变换构成,分别是旋转、平移、缩放、错切、翻转

参数 degrees:旋转角度设置

参数 translate:平移区间设置 a设置宽width,b设置高height

参数 scale:缩放比例

参数 fill_color:填充颜色设置

参数 shear:错切角度设置,有水平错切和垂直错切。(a=X轴角度,b=Y轴角度)

参数 resample:重采样

13.tranforms–RandomErasing

transforms.RandomErasing()对图像进行随机遮挡

参数 p:执行遮挡的概率

参数 scale:遮挡区域面积

参数 p:遮挡区域长宽比

参数 p:设置遮挡区域的像素值

14.tranforms–Lambda

transforms.Lambda()用户自定义Lambda方法

表达式:

lambda[arg1[,arg2,…,argn]]:expression

transforms.Tencrop(200,vertical_filp=True)
transforms.Lambda(lambda crops:torch.stack([transforms.Totensor()(crop) for crop in crops]))

15.tranforms–RandomChoice

transforms.RandomChoice()随机选择一个transforms方法

transforms.RandomChoice([方法1,方法2,方法3])

16.tranforms–RandomApply

transforms.RandomApply()依概率执行一组transforms方法

transforms.RandomChoice([方法1,方法2,方法3],p=0.5)

17.tranforms–RandomOrder

transforms.RandomOrder()对一组transforms操作打乱顺序

transforms.RandomChoice([方法1,方法2,方法3])

18 自定义transforms方法

  • 1.仅接收一个参数,返回一个参数
  • 2.注意上下游的输出与输入
class Compose(object):def __call__(self,img):for t in self.transforms:img = t(img)return img

通过类实现多参数传入

class YourTransforms(object):def __init__(self,...):...def __call__(self,img):...return img
  • 椒盐噪声:又叫做脉冲噪声,是一种随机出现的白点或者黑点,白点叫盐噪声,黑点叫椒噪声
  • 信噪比(Signal-Noise Rate,SNR)是衡量噪声的比例,图像中为图像像素的占比

一个例子

Class AddPepperNoise(object):"""Args:snr(float):signal noise ratep(float):概率值,依概率执行操作"""def __init__(self,snr,p=0.9):assert isinstance(snr,float) or (isinstance(p,float))self.snr = snrself.p = pdef __call__(self,img):if random.uniform(0,1) < self.p:img_ = np.array(img).copy()h ,w ,c = img_.shapesignal_pct = self.snrnoise_pct = (1 - self.snr)mask = np.random.choice((0,1,2),size = (h,w,1),p=[signal_pct,noise_pct/2.,noise_pct/2])mask = np.repeat(mask,c,axis=2)img_[mask == 1] = 255 # 盐噪声img_[mask == 1] = 0 # 椒噪声return Image.fromarray(img_.astype('utf-8')).convert('RGB')else:return img

数据读取机制Dataloader和Dataset和Transforms相关推荐

  1. 系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)

    Pytorch官方英文文档:https://pytorch.org/docs/stable/torch.html? Pytorch中文文档:https://pytorch-cn.readthedocs ...

  2. TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制

    TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和tf.data.Dataset机制 之前写了一篇博客,关于<Tensorflow生成自己的 ...

  3. PyTorch框架学习八——PyTorch数据读取机制(简述)

    PyTorch框架学习八--PyTorch数据读取机制(简述) 一.数据 二.DataLoader与Dataset 1.torch.utils.data.DataLoader 2.torch.util ...

  4. linux 读取大量图片 内存,10 张图帮你搞定 TensorFlow 数据读取机制

    导读 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解 ...

  5. tensorflow 1.0 学习:十图详解tensorflow数据读取机制

    本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...

  6. 十图详解TensorFlow数据读取机制(附代码)

    在学习TensorFlow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下 ...

  7. PyTorch 入坑六 数据处理模块Dataloader、Dataset、Transforms

    深度学习中的数据处理概述 深度学习三要素:数据.算力和算法 在工程实践中,数据的重要性越来越引起人们的关注.在数据科学界流传着一种说法,"数据决定了模型的上限,算法决定了模型的下限" ...

  8. tensorflow数据读取机制

    原博客地址:https://zhuanlan.zhihu.com/p/27238630 代码地址:https://github.com/hzy46/Deep-Learning-21-Examples/ ...

  9. Cassandra数据读取机制

    数据读取流程 Cassandra会根据需要读取的ColumnFamily查询该ColumnFamily下的Memtable以及所有的SSTable,合并查询结果,将最新的结果返回给客户端.Cassan ...

最新文章

  1. 说说.net事件和委托。
  2. Cassandra-Java(增删查改)
  3. mysql 5.7 存储引擎_mysql5.7——innodb存储引擎总结
  4. D3DX 9.9 LEARNERNOTO
  5. 4.5 搭建深层神经网络块-深度学习-Stanford吴恩达教授
  6. Mina集成Spring --- 在配置文件中配置sessionconfig
  7. Redhat ssh服务登录慢
  8. bootstrap简单使用
  9. 实现进程守护 脚本命令
  10. Linux文件系统管理命令(第二版)
  11. android各层调用关系,架构流程
  12. HttpClient4.x之Post请求示例
  13. 苹果vs剪辑下载_适合mac的视频剪辑软件
  14. 最简单的直播礼物连刷特效制作(带源码)
  15. MySQL在线DDL gh-ost使用说明
  16. 设置合适的密码策略chage命令
  17. 想成为一名数据科学家?你得先读读这篇文章
  18. 银行中台与互联网中台有什么不同?该怎么建?
  19. Vue前端实现微信扫码登录
  20. 51学工坊整理|甲骨文Oracle数据库 21c来了,来看看有哪些创新技术

热门文章

  1. Magic Leap开发指南(1)--开发前准备
  2. 耗时一个月上架了一款微信小程序,赚了2022年的第一笔副收入
  3. 【运行报错--Hadoop】修改后的新分发脚本不生效
  4. jquery语法三ajax+echarts插件的使用
  5. mysql和oracle有什么区别
  6. 什么是 CSRF 攻击?
  7. 【存储技术发展趋势】
  8. 雅虎地图与谷歌地图坐标_打开Yahoo! 将与Google玩得很好,不竞争
  9. PR视频剪辑教程_02_导入素材与新建序列
  10. 利用Pycharm断点调试Python程序