cats vs dogs——resnet18

  • 数据
  • Net
  • train+test

数据

这是一个在kaggle上的竞赛,原数据提供了25000张图片,本文所使用的数据集来自其中train的8000张,包括4000张猫和4000张狗

#将数据导入
data_dir = 'E:\\code\\Python\\catanddog\\train'
test_dir = 'E:\\code\\Python\\catanddog\\test'
class Data(data.Dataset):def __init__(self, path, transform = None, train = True, test = False):self.test = testself.train = trainself.transform = transform#imgs = [os.path.join(path, img) for img in os.listdir(path)]imgs = [os.path.join(data_dir, img) for img in path]#imgs存的是每张图片的总路径if self.test:#test模式self.imgs = imgselse:#train模式random.shuffle(imgs)#数据打乱self.imgs = imgsdef __getitem__(self, index):img = self.imgs[index]if self.test:label = 2else:label = 0 if 'cat' in img.split('\\')[-1] else 1#cat = 0, dog = 1image = Image.open(img)image = self.transform(image)return image, labeldef __len__(self):return len(self.imgs)
#train和val的transform处理
#对train进行了随机裁剪翻转的操作
transform_train = transforms.Compose([transforms.Resize((256, 256)), transforms.RandomCrop((224, 224)), transforms.RandomHorizontalFlip(), transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.2225))
])transform_val = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_dirs, valid_dirs = train_test_split(os.listdir(data_dir), test_size = 0.2, random_state = 2021)
trainset = Data(train_dirs, transform = transform_train)
valset = Data(valid_dirs, transform = transform_val)
trainloader = torch.utils.data.DataLoader(trainset, batch_size = 20, shuffle = True, num_workers = 0)
valloader = torch.utils.data.DataLoader(valset, batch_size = 20, shuffle = False, num_workers = 0)

Net

model = resnet18(pretrained=True)
model.fc = nn.Linear(512, 2,bias=True)
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)  # 设置训练细节
scheduler = StepLR(optimizer, step_size=3)
criterion = nn.CrossEntropyLoss()

train+test

#训练并保存模型
for epoch in range(1):train(epoch)val(epoch)
torch.save(model, 'catvsdog_model.pth')  # 保存模型
#一次epoch在验证集上准确度有0.988

#对图片进行预测
classes = ['cat', 'dog']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('catvsdog_model.pth')  # 加载模型
model = model.to(device)
model.eval()  # 把模型转为test模式img = cv2.imread("E:\\code\\Python\\catanddog\\test\\dog.3886.jpg")  # 读取要预测的图片
cv2.imshow("img", img)
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
trans = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),])img = trans(img)
img = img.to(device)
img = img.unsqueeze(0)  # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]
output = model(img)
prob = F.softmax(output, dim=1)  # prob是2个分类的概率
value, predicted = torch.max(output.data, 1)
pred_class = classes[predicted.item()]
print(pred_class)

猫狗大战——pytorch+resnet18相关推荐

  1. python猫狗大战pytorch_深度学习实战---猫狗大战(pytorch实现)

    数据准备 猫狗大战数据集下载链接 微软的数据集已经分好类,直接使用就行, 数据划分 我们将猫和狗的图片分别移动到训练集和验证集中,其中90%的数据作为训练集,10%的图片作为验证集,使用shutil. ...

  2. 深度学习实战---猫狗大战(pytorch实现)

    数据准备 猫狗大战数据集下载链接 微软的数据集已经分好类,直接使用就行, 数据划分 我们将猫和狗的图片分别移动到训练集和验证集中,其中90%的数据作为训练集,10%的图片作为验证集,使用shutil. ...

  3. 图片分类 猫狗大战 pytorch VGG

    使用pytorch实现猫狗大战 一.简介 二.理论 三.实现 1️⃣.实现准备 2️⃣.创建VGG16模型 3️⃣.训练模型 4️⃣.在验证集上测试训练的模型 5️⃣.在测试集上运行 四.总结 五.我 ...

  4. 深度学习--猫狗大战pytorch实战

    文章目录 数据准备&处理 模型构建 训练 kaggle上的一个经典项目,拿来做做算是当CNN入门了,做的比较粗糙简单 我把整个项目分成了四块 config用来配置一些参数,Dataset用来构 ...

  5. 猫狗大战pytorch实现

    目录 评估函数,计算 图片多分类的准确率 topK 保存准确率信息 完整代码 评估函数,计算 图片多分类的准确率 topK ## topk的准确率计算 def accuracy(output, lab ...

  6. CNN入门+猫狗大战(Dogs vs. Cats)+PyTorch入门

    一些修改(修改后的代码) 修改原网络的输出方式.原网络采用的交叉熵torch.nn.CrossEntropyLoss()进行Loss计算,而这个函数内部是已经进行了softmax处理的(参考),所以网 ...

  7. Python分类实例之猫狗大战

    目 录 作者介绍 编程实战指南 比赛数据集介绍(Dogs vs cats) 环境配置 模型定义 数据加载 训练和测试 结果展示 参考 作者介绍 周新龙,男,西安工程大学电子信息学院,2019级研究生, ...

  8. cleverhans与foolbox的对比使用(pytorch+python3)

    cleverhans与foolbox的对比使用(pytorch+python3) 一.最新版cleverhans Although CleverHans is likely to work on ma ...

  9. PC端 Rockchip RKNN-Toolkit 连接 Rockchip NPU 设备

    PC端 Rockchip RKNN-Toolkit 连接 Rockchip NPU 设备 flyfish 安装Windows版的Rockchip RKNN-Toolkit 可以使用anaconda简化 ...

最新文章

  1. 《 硬件创业:从产品创意到成熟企业的成功路线图》——导读
  2. Javascript及Jquery获取元素节点以及添加和删除操作
  3. 太阳能板如何串联_还在犹豫用不用太阳能灯?这些分析让你少花钱,更省钱。...
  4. windows 任务管理器,查看进程id,进程标识符pid
  5. 【C/C++语法外功】C/C++头文件一览[轉]
  6. Java继承个人的理解_我对java继承的理解
  7. docker中创建RabbitMQ并在管理端界面打开
  8. poi 顺序解析word_JavaPOI解析word提取数据到excel
  9. 【LOJ】#2887. 「APIO2015」雅加达的摩天楼 Jakarta Skyscrapers
  10. 手把手教你用Scrapy爬取知乎大V粉丝列表
  11. 哪些手机支持双wifi?
  12. Unity中的Time
  13. WireGuard 教程:使用 DNS-SD 进行 NAT-to-NAT 穿透
  14. Android 在一个APP里打开另一个APP
  15. 2020计算机保研实录
  16. 万能的小镇市场能否成为悟空问答的救命良药?
  17. Leetcode_24_Swap Nodes in Pairs
  18. Pytorch学习之神经网络参数管理
  19. led灯亮度渐变实现
  20. 梆梆安全加固企业版分析

热门文章

  1. 质量意识:质量管理发展三阶段
  2. 使用Python进行ADSL宽带拨号连接等操作
  3. CTFHub信息泄露
  4. 反转字符串、反转字符串中的元音字母、两个数组的交集,springboot工作原理面试
  5. Parallels (干货在最底部)
  6. 电商项目——商品服务-API-属性分组——第十一章——上篇
  7. Django 序列化和反序列化(九)
  8. 阿里云【7天实践训练营】进阶路线——Day2:阿里云云计算助理工程师认证(ACA)课程1 ~ 2章
  9. 身份证,银行卡丢失后
  10. vue 使用百度地图以及地图样式、绘制 点、点聚合和多边形区域、自定义点样式、聚合样式等