PyTorch基础-猫狗分类实战-10
训练模型并保存
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,transforms,models
from torch.utils.data import Dataset
import sys
# 数据预处理
transform = transforms.Compose([transforms.RandomResizedCrop(224),# 对图像进行随机裁剪transforms.RandomRotation(20),# 随机旋转角度transforms.RandomHorizontalFlip(p=0.5),# 随机水平翻转transforms.ToTensor()# 变成tensor格式
]) # 数据增强# 读取数据
root = "image"
train_dataset = datasets.ImageFolder(root + "/train",transform)
test_dataset = datasets.ImageFolder(root + "/test",transform)# 导入数据
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=8,shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=8,shuffle=True)
classes = train_dataset.classes
classes_index = train_dataset.class_to_idx
print(classes)
print(classes_index)
model = models.vgg16(pretrained=True)# 载入vgg16预训练模型
print(model)
for param in model.parameters():param.requires_grad = False
# 构建新的全连接层
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088,100),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(100,2))
LR = 0.0003
# 定义代价函数
entropy_loss = nn.CrossEntropyLoss()
# 定义优化器
optimizer = optim.Adam(model.parameters(),LR)
def train():model.train()for i,data in enumerate(train_loader):# 获得数据和对应的标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 交叉熵代价函数out(batch.C),labels(batch)loss = entropy_loss(out,labels)# 梯度清零optimizer.zero_grad()# 计算梯度loss.backward()# 修改权值optimizer.step()def test():model.eval()correct = 0for i,data in enumerate(test_loader):# 获得数据和对应的标签inputs,labels = data# 获得模型预测结果out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted == labels).sum()print("test acc:{0}".format(correct.item()/len(test_dataset)))correct = 0for i,data in enumerate(train_loader):# 获得数据和对应的标签inputs,labels = data# 获得模型预测结果out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted == labels).sum()print("train acc:{0}".format(correct.item()/len(train_dataset)))
for epoch in range(5):print("epoch:",epoch)train()test()
torch.save(model.state_dict(),"cat_dog.pth") # 保存模型
加载模型进行预测
import torch
import numpy as np
from PIL import Image
from torchvision import transforms,models
model = models.vgg16(pretrained=True)
# 构建新的全连接层
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088,100),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(100,2))
model.load_state_dict(torch.load("cat_dog.pth")) # 加载模型
model.eval() # 预测模式
label = np.array(["cat","dog"])
# 数据预处理
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor()
])
# 预测函数
def predict(image_path):# 打开图片img = Image.open(image_path)# 数据处理,增加一个维度img = transform(img).unsqueeze(0)# 预测得到的结果outputs = model(img)# 获得最大值所在位置_,predicted = torch.max(outputs,1)# 转换为类别名称print(label[predicted.item()])
predict("image/test/cat/cat.1490.jpg")
PyTorch基础-猫狗分类实战-10相关推荐
- Pytorch+CNN+猫狗分类实战
文章目录 0.前言 1.猫狗分类数据集 1.1数据集下载(可选部分) 1.2数据集分析 2.猫狗分类数据集预处理 2.1训练集和测试集划分 2.2训练集和测试集读取 3.剩余代码 4.总结 0.前言 ...
- 基于Pytorch的猫狗分类
无偿分享~ 猫狗二分类文件下载地址 在下一章说 猫狗分类这个真是困扰我好几天,找了好多资料都是以TensorFlow的猫狗分类,但我们要求的是以pytorch的猫狗分类.刚开始我找到了也 ...
- 基于Pytorch实现猫狗分类
基于Pytorch实现猫狗分类 一.环境配置 二.数据集准备 三.猫狗分类的实例 四.实现分类预测测试 五.参考资料 一.环境配置 1.环境使用 Anaconda 2.配置Pytorch pip in ...
- 【学习笔记】pytorch迁移学习-猫狗分类实战
1.迁移学习入门 什么是迁移学习:在深度神经网络算法的引用过程中,如果我们面对的是数据规模较大的问题,那么在搭建好深度神经网络模型后,我们势必要花费大量的算力和时间去训练模型和优化参数,最后耗费了这么 ...
- PyTorch深度学习实战 | 猫狗分类
本文内容使用TensorFlow和Keras建立一个猫狗图片分类器. 图1 猫狗图片 01.安装TensorFlow和Keras库 TensorFlow是一个采用数据流图(data flow grap ...
- 【深度学习】ResNet残差网络 ResidualBlock残差块实现(pytorch) | 跟着李沐学AI笔记 | ResNet18进行猫狗分类
文章目录 前言 一.卷积的相关计算公式(复习) 二.残差块ResidualBlock复现(pytorch) 三.残差网络ResNet18复现(pytorch) 四.直接调用方法 五.具体实践(ResN ...
- AlexNet 实现猫狗分类(keras and pytorch)
AlexNet 实现猫狗分类 前言 在训练网络过程中遇到了很多问题,先在这里抱怨一下,没有硬件条件去使用庞大的ImageNet2012 数据集 .所以在选择合适的数据集上走了些弯路,最后选择有kagg ...
- Kaggle深度学习与卷积神经网络项目实战-猫狗分类检测数据集
Kaggle深度学习与卷积神经网络项目实战-猫狗分类检测数据集 一.相关介绍 二.下载数据集 三.代码示例 1.导入keras库,并显示版本号 2.构建网络 3.数据预处理 4.使用数据增强 四.使用 ...
- Java软件研发工程师转行之深度学习(Deep Learning)进阶:手写数字识别+人脸识别+图像中物体分类+视频分类+图像与文字特征+猫狗分类
本文适合于对机器学习和数据挖掘有所了解,想深入研究深度学习的读者 1.对概率基本概率有所了解 2.具有微积分和线性代数的基本知识 3.有一定的编程基础(Python) Java软件研发工程师转行之深度 ...
最新文章
- PostgreSQL · 实现分析 · PostgreSQL 10.0 并行查询和外部表的结合
- windows mobile 鼠标等待
- PHP中的常见魔术方法功能作用及用法实例
- oracle 10G windows启动与关闭另类方法
- SIFT讲解(SIFT的特征点选取以及描述是重点)
- 机器学习笔记(八)——决策树模型的特征选择
- 联邦快递就华为包裹被转运致歉 称有关货件正退回发货方
- python 数据结构与算法 day04 快速排序
- 拓扑排序---AOV图
- 5G简介【华为ICT学堂】笔记
- Jenkins build之后清理workspace
- mysql建表测试_测试必备mysql技能2:mysql建表
- 动态表格的实现(layui动态表格实现)
- centos yum清华镜像
- python3字符串格式化
- 快捷下载中国原创音乐基地音乐(包括金豆和无法下载音乐)
- Xv6学习之kinit1
- Hdu 5873 2016 ACM/ICPC Asia Regional Dalian Online 1006(兰道定理)
- java情剑天涯,求超低内存的手机游戏,越多越好
- 支付宝支付-常用支付API详解(查询、退款、提现等)(转)
热门文章
- Yii的各种query
- visual studio学习python_python3从零学习-开发环境搭建之Visual Studio Code篇
- java 语法 冒号_java中生僻的冒号跳转语法
- java jlable添加gif,Java动画GIF而不使用JLabel
- 文件服务器定时开关机,如何配置作服务器定时开关机.ppt
- mysql的代码大全_MySql数目字函数大全
- chroot环境怎么重启linux,linux下简易chroot环境的塔建
- windows cmd 如果失败了则暂停
- vuex中store 的mutation
- mysql 隔离级别 快照_「数据库架构」三分钟搞懂事务隔离级别和脏读