第一次这么认真分享,是因为,我找了好久也没找到和自己目标一致的,只好参考别人的,自行修改了一下下。我的目标是对单通道、只包含一个分割目标的数据集进行归一化。如果想要了解多个分割目标的归一化,可以参考下面的链接:
https://hulin.blog.csdn.net/article/details/116600119?spm=1001.2014.3001.5506
1、准备数据集(dataset)

import os
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
class MyData(Dataset):def __init__(self,root_dir,label_dir):self.root_dir = root_dirself.label_dir = label_dirself.img_path = os.listdir(root_dir)#获取训练数据列表self.label_path = os.listdir(label_dir)self.transformer = torchvision.transforms.Compose([transforms.Resize((256,256)),transforms.CenterCrop(256),transforms.RandomRotation(180),transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize(mean=0.4573,std=0.3618)#输入是tensor类型,0.4573和0.3618是计算出来的,后面说如何计算,目前准备阶段可以先忽略这一行语句])def __getitem__(self, index):#read img:image_name = self.img_path[index]#获取每一个训练数据名称image_path = os.path.join(self.root_dir, image_name)#获取每一个训练数据的路径!!!很重要,不然会打不开文件!!!!img_pil = Image.open(image_path)#read label:label_name = self.label_path[index]label_path = os.path.join(self.label_dir,label_name)label_pil = Image.open(label_path)#data enhance:img_tran = self.transformer(img_pil)label_tran = self.transformer(label_pil)label_tran = torch.squeeze(label_tran)#CEloss要求label:[b,h,w]image = img_tran.float()#指定img为floattensor型labels = label_tran.long()#指定label为longtensor型return image,labelsdef __len__(self):return len(self.img_path)
#训练和验证数据的路径:
train_image_dir = 'E:\\pycharm\\UNet\\train_image'
train_label_dir = 'E:\\pycharm\\UNet\\train_label'
valid_image_dir = 'E:\\pycharm\\UNet\\valid_image'
valid_label_dir = 'E:\\pycharm\\UNet\\valid_label'
#调用MyData类,创建dataset
train_dataset = MyData(train_image_dir,train_label_dir)
valid_dataset = MyData(valid_image_dir,valid_label_dir)

以上就准备好了要归一化的数据集,使用pytorch中的transforms.Normalize(mean,std)进行归一化,要求input是tensor类型,所以前面我将该语句放在了transforms.ToTensor()后面。
接下来的问题就是,如何知道我们的这个数据集的mean(均值)和std(标准差)。
2、计算mean和std:
该方法适用于单通道、只包含一个分割目标的数据集,如果是多个目标,则需要对每个目标(一个目标对应一个通道)进行计算mean和std,参照开头链接。

def getstat(dataset):print(len(dataset))loader = torch.utils.data.DataLoader(dataset,batch_size=1,shuffle=False,num_workers=0,pin_memory = True)mean = torch.zeros(1)#因为我的数据集是单通道的,只包含目标(1)和背景(0),所以我只需要计算一个通道的mean和stdstd = torch.zeros(1)for x,_ in loader:#计算loader中所有数据的mean和atd的累积mean += x.mean()std += x.std()mean = torch.div(mean,len(dataset))#得到整体数据集mean的平均std = torch.div(std,len(dataset))return list(mean.numpy()),list(std.numpy())#返回mean和std的listmean,std = getstat(train_dataset)#调用getstat
mean_,std_ = getstat(valid_dataset)
print(mean,std)
print(mean_,std_)

结果如下:

如有问题,还望指教。(小声说:作者也是刚开始学习)

pytorch实现:数据集归一化处理相关推荐

  1. Pytorch自定义数据集

    简述 Pytorch自定义数据集方法,应该是用pytorch做算法的最基本的东西. 往往网络上给的demo都是基于torch自带的MNIST的相关类.所以,为了解决使用其他的数据集,在查阅了torch ...

  2. 从零开始的图像语义分割:FCN快速复现教程(Pytorch+CityScapes数据集)

    从零开始的图像语义分割:FCN复现教程(Pytorch+CityScapes数据集) 前言 一.图像分割开山之作FCN 二.代码及数据集获取 1.源项目代码 2.CityScapes数据集 三.代码复 ...

  3. pytorch 读取数据集(LiTS-肝肿瘤分割挑战数据集)

    pytorch 读取数据集 我的数据集长这样: xx.png和xx_mask.png是对应的待分割图像和ground truth 读取数据集 数据集对象被抽象为Dataset类,实现自定义的数据集需要 ...

  4. pytorch自定义数据集DataLoder

    pytorch官方例程: DATA LOADING AND PROCESSING TUTORIAL torch.utils.data.Dataset 是dataset的抽象类,我们可以同过继承Data ...

  5. 数据集制作_轻松学Pytorch自定义数据集制作与使用

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 大家好,这是轻松学Pytorch系列的第六篇分享,本篇你将学会如何从头开始制作自己的数据集,并通过DataLo ...

  6. 【小白学习PyTorch教程】十七、 PyTorch 中 数据集torchvision和torchtext

    @Author:Runsen 对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext. 之前使用 torchDataLoader类直接加载图像并将其转换为张量. ...

  7. 【问题记录】pytorch自定义数据集 No such file or directory, invalid index of a 0-dim

    保存模型: : 保存整个神经网络的结构和模型参数 torch.save(mymodel, 'mymodel.pkl') 只保存神经网络的模型参数 torch.save(mymodel.state_di ...

  8. pytorch: 自定义数据集加载

    很多网络在数据加载方式 pytorch 的输入流水线的操作顺序是这样的: 创建一个 Dataset 对象     创建一个 DataLoader 对象     不停的 循环 这个 DataLoader ...

  9. pytorch对数据集进行重新采样

    背景: 当不同类型数据的数量差别巨大的时候,比如猫有200张训练图片,而狗有2000张,很容易出现模型只能学到狗的特征,导致准确率无法提升的情况. 这时候,一种可行的方法就是对原始数据集进行采样,从而 ...

最新文章

  1. [LeetCode] 130. Surrounded Regions Java
  2. 2018-08-12 长大
  3. 游戏必备组件有哪些_面试必备:2019Vue经典面试题总结(含答案)
  4. 线程队列-queue
  5. android 微信分享gif图,android后台动态创建图片并实现微信分享
  6. java 开发微信中回调验证一直提示 解密失败处理(Java)
  7. 事业编和公务员哪个好?
  8. java生成json字符串,真香
  9. Linux开机自动启动Tomcat
  10. 自然辩证法 题目2
  11. NO.5 计算数组中三个数的最大乘积
  12. 关于绝对路径与相对路径(详细)
  13. 计算机键盘锁不了怎么办,键盘锁住了怎么解锁?键盘锁死了怎么办?
  14. HTTP协议详解+经典面试题
  15. python程序设计入门书籍推荐_python刚刚入门,接下来这几本python的书会让你成为别人眼里的大神!...
  16. python输出希腊字母
  17. 厦门大学校区计算机考试,厦门大学计算机等级考试报名
  18. 无线网络连接后总是提示可能需要其他登陆信息
  19. 信息系统项目管理师论文范例5:成本管理
  20. 清晰度18级,最新能到2022年的历史影像来瞅一眼(高分卫星、资源卫星)

热门文章

  1. 用Python采集《雪中悍刀行》弹幕做成词云实例
  2. 【推荐系统】搜狐个性化视频推荐架构设计和实践
  3. VRay官方2019建筑表现集锦:世界顶级工作室作品展
  4. [附源码]java毕业设计高校学院主页系统
  5. win10突然蓝牙鼠标连不上,右下角不显示蓝牙图标
  6. 微信3.1.0.58逆向-微信3.1.0.58HOOK接口(WeChatHelper3.1.0.58.dll)使用说明-获取群成员
  7. 机器学习中的正则化——L1范数和L2范数
  8. Android debuggerd 源码分析
  9. 定时采用ajax方式获得数据库,ajax定时刷新数据库
  10. 鸿蒙系统配在华为什么手机上,鸿蒙系统什么时候能用 鸿蒙系统哪些手机可以用...