训练模型并保存

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,transforms,models
from torch.utils.data import Dataset
import sys
# 数据预处理
transform = transforms.Compose([transforms.RandomResizedCrop(224),# 对图像进行随机裁剪transforms.RandomRotation(20),# 随机旋转角度transforms.RandomHorizontalFlip(p=0.5),# 随机水平翻转transforms.ToTensor()# 变成tensor格式
]) # 数据增强# 读取数据
root = "image"
train_dataset = datasets.ImageFolder(root + "/train",transform)
test_dataset = datasets.ImageFolder(root + "/test",transform)# 导入数据
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=8,shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=8,shuffle=True)
classes = train_dataset.classes
classes_index = train_dataset.class_to_idx
print(classes)
print(classes_index)

model = models.vgg16(pretrained=True)# 载入vgg16预训练模型
print(model)

for param in model.parameters():param.requires_grad = False
# 构建新的全连接层
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088,100),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(100,2))
LR = 0.0003
# 定义代价函数
entropy_loss = nn.CrossEntropyLoss()
# 定义优化器
optimizer = optim.Adam(model.parameters(),LR)
def train():model.train()for i,data in enumerate(train_loader):# 获得数据和对应的标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 交叉熵代价函数out(batch.C),labels(batch)loss = entropy_loss(out,labels)# 梯度清零optimizer.zero_grad()# 计算梯度loss.backward()# 修改权值optimizer.step()def test():model.eval()correct = 0for i,data in enumerate(test_loader):# 获得数据和对应的标签inputs,labels = data# 获得模型预测结果out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted == labels).sum()print("test acc:{0}".format(correct.item()/len(test_dataset)))correct = 0for i,data in enumerate(train_loader):# 获得数据和对应的标签inputs,labels = data# 获得模型预测结果out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted == labels).sum()print("train acc:{0}".format(correct.item()/len(train_dataset)))
for  epoch in range(5):print("epoch:",epoch)train()test()
torch.save(model.state_dict(),"cat_dog.pth") # 保存模型

加载模型进行预测

import torch
import numpy as np
from PIL import Image
from torchvision import transforms,models
model = models.vgg16(pretrained=True)
# 构建新的全连接层
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088,100),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(100,2))
model.load_state_dict(torch.load("cat_dog.pth")) # 加载模型
model.eval() # 预测模式

label = np.array(["cat","dog"])
# 数据预处理
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor()
])
# 预测函数
def predict(image_path):# 打开图片img = Image.open(image_path)# 数据处理,增加一个维度img = transform(img).unsqueeze(0)# 预测得到的结果outputs = model(img)# 获得最大值所在位置_,predicted  = torch.max(outputs,1)# 转换为类别名称print(label[predicted.item()])
predict("image/test/cat/cat.1490.jpg")

PyTorch基础-猫狗分类实战-10相关推荐

  1. Pytorch+CNN+猫狗分类实战

    文章目录 0.前言 1.猫狗分类数据集 1.1数据集下载(可选部分) 1.2数据集分析 2.猫狗分类数据集预处理 2.1训练集和测试集划分 2.2训练集和测试集读取 3.剩余代码 4.总结 0.前言 ...

  2. 基于Pytorch的猫狗分类

    无偿分享~ 猫狗二分类文件下载地址 在下一章说        猫狗分类这个真是困扰我好几天,找了好多资料都是以TensorFlow的猫狗分类,但我们要求的是以pytorch的猫狗分类.刚开始我找到了也 ...

  3. 基于Pytorch实现猫狗分类

    基于Pytorch实现猫狗分类 一.环境配置 二.数据集准备 三.猫狗分类的实例 四.实现分类预测测试 五.参考资料 一.环境配置 1.环境使用 Anaconda 2.配置Pytorch pip in ...

  4. 【学习笔记】pytorch迁移学习-猫狗分类实战

    1.迁移学习入门 什么是迁移学习:在深度神经网络算法的引用过程中,如果我们面对的是数据规模较大的问题,那么在搭建好深度神经网络模型后,我们势必要花费大量的算力和时间去训练模型和优化参数,最后耗费了这么 ...

  5. PyTorch深度学习实战 | 猫狗分类

    本文内容使用TensorFlow和Keras建立一个猫狗图片分类器. 图1 猫狗图片 01.安装TensorFlow和Keras库 TensorFlow是一个采用数据流图(data flow grap ...

  6. 【深度学习】ResNet残差网络 ResidualBlock残差块实现(pytorch) | 跟着李沐学AI笔记 | ResNet18进行猫狗分类

    文章目录 前言 一.卷积的相关计算公式(复习) 二.残差块ResidualBlock复现(pytorch) 三.残差网络ResNet18复现(pytorch) 四.直接调用方法 五.具体实践(ResN ...

  7. AlexNet 实现猫狗分类(keras and pytorch)

    AlexNet 实现猫狗分类 前言 在训练网络过程中遇到了很多问题,先在这里抱怨一下,没有硬件条件去使用庞大的ImageNet2012 数据集 .所以在选择合适的数据集上走了些弯路,最后选择有kagg ...

  8. Kaggle深度学习与卷积神经网络项目实战-猫狗分类检测数据集

    Kaggle深度学习与卷积神经网络项目实战-猫狗分类检测数据集 一.相关介绍 二.下载数据集 三.代码示例 1.导入keras库,并显示版本号 2.构建网络 3.数据预处理 4.使用数据增强 四.使用 ...

  9. Java软件研发工程师转行之深度学习(Deep Learning)进阶:手写数字识别+人脸识别+图像中物体分类+视频分类+图像与文字特征+猫狗分类

    本文适合于对机器学习和数据挖掘有所了解,想深入研究深度学习的读者 1.对概率基本概率有所了解 2.具有微积分和线性代数的基本知识 3.有一定的编程基础(Python) Java软件研发工程师转行之深度 ...

最新文章

  1. PostgreSQL · 实现分析 · PostgreSQL 10.0 并行查询和外部表的结合
  2. windows mobile 鼠标等待
  3. PHP中的常见魔术方法功能作用及用法实例
  4. oracle 10G windows启动与关闭另类方法
  5. SIFT讲解(SIFT的特征点选取以及描述是重点)
  6. 机器学习笔记(八)——决策树模型的特征选择
  7. 联邦快递就华为包裹被转运致歉 称有关货件正退回发货方
  8. python 数据结构与算法 day04 快速排序
  9. 拓扑排序---AOV图
  10. 5G简介【华为ICT学堂】笔记
  11. Jenkins build之后清理workspace
  12. mysql建表测试_测试必备mysql技能2:mysql建表
  13. 动态表格的实现(layui动态表格实现)
  14. centos yum清华镜像
  15. python3字符串格式化
  16. 快捷下载中国原创音乐基地音乐(包括金豆和无法下载音乐)
  17. Xv6学习之kinit1
  18. Hdu 5873 2016 ACM/ICPC Asia Regional Dalian Online 1006(兰道定理)
  19. java情剑天涯,求超低内存的手机游戏,越多越好
  20. 支付宝支付-常用支付API详解(查询、退款、提现等)(转)

热门文章

  1. Yii的各种query
  2. visual studio学习python_python3从零学习-开发环境搭建之Visual Studio Code篇
  3. java 语法 冒号_java中生僻的冒号跳转语法
  4. java jlable添加gif,Java动画GIF而不使用JLabel
  5. 文件服务器定时开关机,如何配置作服务器定时开关机.ppt
  6. mysql的代码大全_MySql数目字函数大全
  7. chroot环境怎么重启linux,linux下简易chroot环境的塔建
  8. windows cmd 如果失败了则暂停
  9. vuex中store 的mutation
  10. mysql 隔离级别 快照_「数据库架构」三分钟搞懂事务隔离级别和脏读