计算数据集的均值和标准差

import os
import cv2
import numpy as np
from torch.utils.data import Dataset
from PIL import Imagedef compute_mean_and_std(dataset):# 输入PyTorch的dataset,输出均值和标准差mean_r = 0mean_g = 0mean_b = 0for img, _ in dataset:img = np.asarray(img) # change PIL Image to numpy arraymean_b += np.mean(img[:, :, 0])mean_g += np.mean(img[:, :, 1])mean_r += np.mean(img[:, :, 2])mean_b /= len(dataset)mean_g /= len(dataset)mean_r /= len(dataset)diff_r = 0diff_g = 0diff_b = 0N = 0for img, _ in dataset:img = np.asarray(img)diff_b += np.sum(np.power(img[:, :, 0] - mean_b, 2))diff_g += np.sum(np.power(img[:, :, 1] - mean_g, 2))diff_r += np.sum(np.power(img[:, :, 2] - mean_r, 2))N += np.prod(img[:, :, 0].shape)std_b = np.sqrt(diff_b / N)std_g = np.sqrt(diff_g / N)std_r = np.sqrt(diff_r / N)mean = (mean_b.item() / 255.0, mean_g.item() / 255.0, mean_r.item() / 255.0)std = (std_b.item() / 255.0, std_g.item() / 255.0, std_r.item() / 255.0)return mean, std

视频数据基本信息

import cv2
video = cv2.VideoCapture(mp4_path)
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
fps = int(video.get(cv2.CAP_PROP_FPS))
video.release()

读取并预处理CIFAR10

import torchvision
import torchvision.transforms as transforms# torchvision数据集的输出是在[0, 1]范围内的PILImage图片。
# 我们此处使用归一化的方法将其转化为Tensor,数据范围为[-1, 1]transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
'''注:这一部分需要下载部分数据集 因此速度可能会有一些慢 同时你会看到这样的输出Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting tar file
Done!
Files already downloaded and verified
'''

常用训练和验证数据预处理

#其中 ToTensor 操作会将 PIL.Image 或形状为 H×W×D,数值范围为 [0, 255] 的 np.ndarray 转换为形状#为 D×H×W,数值范围为 [0.0, 1.0] 的 torch.Tensor。
train_transform = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(size=224,scale=(0.08, 1.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),])val_transform = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
])

分类模型训练代码

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i ,(images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeroptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print('Epoch: [{}/{}], Step: [{}/{}], Loss: {}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

分类模型测试代码

# Test the model
model.eval()  # eval mode(batch norm uses moving mean/variance #instead of mini-batch mean/variance)
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

自定义loss

继承torch.nn.Module类写自己的loss。

class MyLoss(torch.nn.Moudle):def __init__(self):super(MyLoss, self).__init__()def forward(self, x, y):loss = torch.mean((x - y) ** 2)return loss

pytorch实战从入门到精通第三部分——数据处理相关推荐

  1. pytorch实战从入门到精通第二部分——卷积神经网络

    两层卷积网络的示例 # convolutional neural network (2 convolutional layers) class ConvNet(nn.Module):def __ini ...

  2. Pytorch实战从入门到精通第一部分——手写字符识别全流程

    下面是用MNIST手写字符数据从数据loader到全连接网络设计.模型训练.模型测试.模型存储的全过程完整代码,仔细品味可供学习使用. import torch import torch.nn as ...

  3. CUDA从入门到精通(三):必备资料

    CUDA从入门到精通(三):必备资料 2013-07-23 09:20 3676人阅读 评论(0) 收藏 举报  分类: GPU(29)  版权声明:本文为卜居原创文章,未经博主允许不得转载.卜居博客 ...

  4. 【Python】Python实战从入门到精通之四 -- 教你使用Python中字典

    本文是Python实战–从入门到精通系列的第四篇文章: Python实战从入门到精通第一讲–Python中的变量和数据类型 Python实战从入门到精通第二讲–Python中列表操作详解 Python ...

  5. 【Python】Python实战从入门到精通之三 -- 教你使用Python中条件语句

    本文是Python实战–从入门到精通系列的第三篇文章: Python实战从入门到精通第1讲–Python中的变量和数据类型 Python实战从入门到精通第2讲–Python中列表操作详解 Python ...

  6. 黑客零基础入门教程:「黑客攻防实战从入门到精通(第二版)」堪称黑客入门天花板

    前言 您知道在每天上网时,有多少黑客正在浏览您计算机中的重要数据吗﹖黑客工具的肆意传播,使得即使是稍有点计算机基础的人,就可以使用简单的工具对网络中一些疏于防范的主机进行攻击,在入侵成功之后,对其中的 ...

  7. unity应用开发实战案例_Unity3D游戏引擎开发实战从入门到精通

    Unity3D游戏引擎开发实战从入门到精通(坦克大战项目实战.NGUI开发.GameObject) 一.Unity3D游戏引擎开发实战从入门到精通是怎么样的一门课程(介绍) 1.1.Unity3D游戏 ...

  8. .net MVC5+EF6+bootstrap搭建框架,从入门到精通(三)——之(Bootstrap Fileinput)多图片上传

    .net MVC5+EF6+bootstrap搭建框架,从入门到精通(三)--之(Bootstrap Fileinput)多图片上传 前言废话 .net mvc 实战多图片上传 前言废话 人生最大的b ...

  9. php flock 都是true_PHP从入门到精通(三)PHP语言基础

    PHP从入门到精通(三)PHP语言基础 一.PHP标记风格 PHP支持4种标记风格 1.XML风格.(推荐使用) <?phpecho "这是XML分割的标记"; ?> ...

最新文章

  1. 刚刚!美国官宣100000名 IT 人失业,感觉很慌 !
  2. 机器学习基础-数据降维
  3. 【编程】二叉树的先序、中序、后序遍历
  4. mysql delimiter 作用
  5. android中的broadcastReceiver
  6. 轻松八句话 教会你完全搞定MySQL数据库(基础)
  7. 关于win7启动看不到桌面的解决方法
  8. python3的encode()和decode()
  9. [Node.js]001.安装与环境配置
  10. 日照科技中等专业学校 远程预付费系统的设计与应用
  11. springMVC实现json 返回到页面
  12. Python基础七 元组、字典、集合
  13. 微信分享自定义图标大小限制_微信分享时安卓的自定义参数无效的解决办法
  14. 数据挖掘项目——Airbnb 新用户的民宿预定结果预测
  15. 动态规划从入门到放弃【1】
  16. 第四课:点亮LED灯
  17. 《OSPF和IS-IS详解》一1.7 独立且平等
  18. Ehcache缓存时间设置
  19. 罕见水星凌日直播,QQ物联携手腾讯云带你连接宇宙
  20. 开源软件及国内发展趋势

热门文章

  1. Umbrella Network与Linear Finance合作,将专业金融数据带入DeFi
  2. 二代征信在小额线上贷款风控领域应用探索
  3. 检测到目标服务器启用了TRACE方法
  4. 金币(NOIP2015 普及组第一题)
  5. 读书笔记 - 《软件业的成功奥秘》
  6. POJ3250(单调栈)
  7. Unity教程之再谈Unity中的优化技术
  8. objective-c 类别
  9. 安卓渗透测试工具——Drozer(安装和使用)
  10. 单调栈与单调队列简单例题