此次是学校老师布置的作业,结果是个大坑,飞桨的数据集是zip的的,数据集要自己搞,而且数据集的加载要飞桨的框架,并且这个数据集的训练集和验证集的加载还不一样。训练集的label在每张图片名上,而验证集的label在专门的csv下。

更正一下:飞桨这个项目可以用

不过要加上这几个

iChange-PM数据集kaggle地址:https://www.kaggle.com/datasets/alexiuschao/palmpm

这里验证集加载参考了Pytorch创建自己的数据集(一)_生活所迫^_^的博客-CSDN博客_pytorch数据集制作x

训练集的加载参考了不同标签和数据类型匹配的数据集在PyTorch的加载(超详细保姆级别教学)_Moon_Boy_Li的博客-CSDN博客_多标签数据集加载

首先定义transform 。后面训练集和验证集的加载都会用到它

import numpy as np
from torchvision import transforms as T
import matplotlib.pyplot as plt
transform = T.Compose([T.Resize(224),  # 缩放图片,保持长宽比不变,最短边为32像素T.CenterCrop(224),  # 从图片中间开始切出224*224的图片T.ToTensor(),  # 将图片(Image)转成Tensor,归一化至[0,1]T.Normalize(mean=[0.492, 0.461, 0.417], std=[0.256, 0.248, 0.251])  # 正则化操作,标准化至[-1,1],规定均值和标准差]
)

再定义一个图像显示函数,方便后边调试使用

# 定义一个显示图像的函数
def imshow(img):img = img / 2 + 0.5 #unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1,2,0)))plt.show()

训练集加载

import os
import torch
import torch.nn as nn
from torchvision import transformsfrom torch.utils.data import  DataLoader,Dataset
from torchvision import transforms, utils, datasets
import pandas as pd
import numpy as np
from PIL import Image
from torchvision import transforms, utils, datasets
import torchvision
import matplotlib.pyplot as plt# 通过继承Dataset类来进行数据加载
class Train_Dataset(Dataset): # 继承Datasetdef __init__(self, path_dir, transform=None):  # 初始化一些属性,获取数据集所在路径的数据列表self.path_dir = path_dir  # 文件路径self.transform = transform  # 对象进行数据处理self.images = os.listdir(self.path_dir)  # 把路径下的所有文件放在一个列表里;即在self.images这个张量中存储path_dir路径的所有文件的名称和后缀名def __len__(self): # 返回整个数据集的大小return len(self.images)def __getitem__(self, index):  # 根据索引index返回图像及标签,索引是根据文件夹内的文件顺序进行排列,从0开始递增image_index = self.images[index]  # 根据索引获取图像文件名称img_path = os.path.join(self.path_dir, image_index)  # 获取index在确定数值下图片的路径或者目录#print("img_path:"+img_path+"\n")#../input/palmpm/PALM-Training400/P0072.jpgimg = Image.open(img_path).convert('RGB')  # 读取图像#plt.imshow(img)#plt.show()# 根据目录名称获取图像标签   H高度近视 为0  P病理疾病  为1   N正常为0label = img_path.split('/')[-1].split('.')[0]  # 绝对路径后加\\, '\\'的后一位, '.'的前一位就是标签,如H0001.jpg, 标签就是catlabel_token=labellabel = 1 if 'P' in list(label)[0] else 0#print(label)if self.transform is not None:img = self.transform(img)#print(img.shape)return img, labelpath_dir = "../input/palmpm/PALM-Training400"
images = os.listdir(path_dir)
#print(images)
len(images) # 读取数据集长度
print(len(images))
#实例化对象
Train_dataset = Train_Dataset(path_dir,transform=transform)
#将数据集导入DataLoader,进行shuffle以及选取batch_size
Traindata_loader = DataLoader(Train_dataset,batch_size=4,shuffle=None,num_workers=0)

训练集的随机测试

# 随机获取部分训练数据
Train_dataiter = iter(Traindata_loader)
images, labels = Train_dataiter.next()
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(''.join('%s' % ["病变" if labels[m].item()==1 else "正常" for m in range(4)])) 

验证集加载

import torch
import torchvision
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import pandas as pd
#路径是自己电脑里所对应的路径
datapath = r'../input/palmpm/PALM-Validation400'
txtpath = r'../input/palmpm/PALM-Validation-GT/labels.csv'class Vaild_Dataset(Dataset):def __init__(self,txtpath, transform=None):#创建一个list用来储存图片和标签信息imgs = []#打开第一步创建的txt文件,按行读取,将结果以元组方式保存在imgs里csvfile = open(txtpath,encoding='utf-8')df = pd.read_csv(csvfile,engine='python')#print(len(datainfo))有401行#print(df)#print(df['imgName'])for i in range(len(df)-1):imgs.append((str(df["imgName"][i]),df["Label"][i]))#print(df["Label"][i],df["ID"][i],type(df["Label"][i]))#print(imgs)self.imgs = imgsself.transform = transform#返回数据集大小def __len__(self):return len(self.imgs)#打开index对应图片进行预处理后return回处理后的图片和标签def __getitem__(self, index):pic,label = self.imgs[index]pic = Image.open(datapath+'/'+pic)#pic = transforms.RandomResizedCrop(224)(pic)#pic = transforms.ToTensor()(pic)if self.transform is not None:pic = self.transform(pic)        return pic,label
#实例化对象
Valid_data = Vaild_Dataset(txtpath,transform=transform)
#将数据集导入DataLoader,进行shuffle以及选取batch_size
Valid_data_loader = DataLoader(Valid_data,batch_size=4,shuffle=True,num_workers=0)
#Windows里num_works只能为0,其他值会报错

验证集的测试

# 随机获取部分训练数据
Valid_dataiter = iter(Valid_data_loader)
images, labels = Valid_dataiter.next()
# 显示图像
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(''.join('%s' % ["病变" if labels[m].item()==1 else "正常" for m in range(4)])) 

现在开始构建Lenet

from keras.models import Sequential
from keras.layers.core import Dense
import tensorflow as tf
from torchvision import transforms
from keras.layers.convolutional_recurrent import ConvLSTM2D
from keras.layers.convolutional import Convolution2D
from keras.layers import LSTM
from keras.layers import Dense, Dropout, Activation
from keras.layers import Convolution1D, MaxPooling1D,MaxPool2D,Flatten,AvgPool2D
np.random.seed(seed=7)
import torch.nn.functional as F
torch.set_default_tensor_type(torch.DoubleTensor)
class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()#原始图像32*32*3(我用的是224*224)self.conv1 = nn.Conv2d(3, 6, 5)#输出:28x28x6(我的是220*220)self.pool = nn.MaxPool2d(2, 2)#输出:14x14x6(我的是110*110)self.conv2 = nn.Conv2d(6, 16, 5)#输出:10x10x16(我的是106*106)#池化后,输出:5x5x16(我的是53*53)self.fc1 = nn.Linear(16 * 53 * 53, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 2)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]  # all dimensions except the batch dimensionnum_features = 1for s in size:num_features *= sreturn num_features

from torch import optim
#创建模型,部署gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device)
#定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)

训练函数,这个随便找一个改一改就行

def train_runner(model, device, trainloader, optimizer, epoch):#训练模型, 启用 BatchNormalization 和 Dropout, 将BatchNormalization和Dropout置为Truemodel.train()total = 0correct =0.0#enumerate迭代已加载的数据集,同时获取数据和数据下标for i, data in enumerate(trainloader, 0):inputs, labels = data#把模型部署到device上inputs, labels = inputs.to(device), labels.to(device)#初始化梯度optimizer.zero_grad()#保存训练结果outputs = model(inputs)#计算损失和#多分类情况通常使用cross_entropy(交叉熵损失函数), 而对于二分类问题, 通常使用sigmodloss = F.cross_entropy(outputs, labels)#获取最大概率的预测结果#dim=1表示返回每一行的最大值对应的列下标predict = outputs.argmax(dim=1)total += labels.size(0)correct += (predict == labels).sum().item()#反向传播loss.backward()#更新参数optimizer.step()if i % 100 == 0:#loss.item()表示当前loss的数值print("Train Epoch{} \t Loss: {:.6f}, accuracy: {:.6f}%".format(epoch, loss.item(), 100*(correct/total)))Loss.append(loss.item())Accuracy.append(correct/total)return loss.item(), correct/total
def test_runner(model, device, testloader):#模型验证, 必须要写, 否则只要有输入数据, 即使不训练, 它也会改变权值#因为调用eval()将不启用 BatchNormalization 和 Dropout, BatchNormalization和Dropout置为Falsemodel.eval()#统计模型正确率, 设置初始值correct = 0.0test_loss = 0.0total = 0#torch.no_grad将不会计算梯度, 也不会进行反向传播with torch.no_grad():for data, label in testloader:data, label = data.to(device), label.to(device)output = model(data)test_loss += F.cross_entropy(output, label.long()).item() #此处label续变为label.long() 否则会报错predict = output.argmax(dim=1)#计算正确数量total += label.size(0)correct += (predict == label).sum().item()#计算损失值print("test_avarage_loss: {:.6f}, accuracy: {:.6f}%".format(test_loss/total, 100*(correct/total)))

调用执行

#调用
epoch = 5
Loss = []
Accuracy = []
for epoch in range(1, epoch+1):loss, acc = train_runner(model, device, Traindata_loader, optimizer, epoch)Loss.append(loss)Accuracy.append(acc)test_runner(model, device, Valid_data_loader)print('Finished Training')
plt.subplot(2,1,1)
plt.plot(Loss)
plt.title('Loss')
plt.show()
plt.subplot(2,1,2)
plt.plot(Accuracy)
plt.title('Accuracy')
plt.show()

运行结果

kaggle代码直通车:https://www.kaggle.com/alexiuschao/lenet-ichallenge-pm/edit

pytorch实战:采用Lenet运行iChallenge-PM数据集相关推荐

  1. (!详解 Pytorch实战:①)kaggle猫狗数据集二分类:加载(集成/自定义)数据集

    这系列的文章是我对Pytorch入门之后的一个总结,特别是对数据集生成加载这一块加强学习 另外,这里有一些比较常用的数据集,大家可以进行下载: 需要注意的是,本篇文章使用的PyTorch的版本是v0. ...

  2. 【深度学习】李宏毅2021/2022春深度学习课程笔记 - Auto Encoder 自编码器 + PyTorch实战

    文章目录 一.Basic Idea of Auto Encoder 1.1 Auto Encoder 结构 1.2 Auto Encoder 降维 1.3 Why Auto Encoder 1.4 D ...

  3. Pytorch实战1:LeNet手写数字识别 (MNIST数据集)

    版权说明:此文章为本人原创内容,转载请注明出处,谢谢合作! Pytorch实战1:LeNet手写数字识别 (MNIST数据集) 实验环境: Pytorch 0.4.0 torchvision 0.2. ...

  4. PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析

    PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析 目录 输出结果 核心代码 输出结果 核心代码 #PyTorch:采用skle ...

  5. Pytorch采用AlexNet实现猫狗数据集分类(训练与预测)

    Pytorch采用AlexNet实现猫狗数据集分类(训练与预测) 介绍 AlexNet网络模型 猫狗数据集 AlexNet网络训练 训练全代码 预测 预测图片 介绍 AlexNet模型是CNN网络中经 ...

  6. 【Pytorch实战4】基于CIFAR10数据集训练一个分类器

    参考资料: <深度学习之pytorch实战计算机视觉> Pytorch官方教程 Pytorch中文文档 先是数据的导入与预览. import torch import torchvisio ...

  7. 我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!

    大家好,我是红色石头! 在上三篇文章: 这可能是神经网络 LeNet-5 最详细的解释了! 我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)! 我用 PyTorch ...

  8. 【深度学习】我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!

    在上三篇文章: 这可能是神经网络 LeNet-5 最详细的解释了! 我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)! 我用 PyTorch 复现了 LeNet-5 ...

  9. 【Pytorch实战6】一个完整的分类案例:迁移学习分类蚂蚁和蜜蜂(Res18,VGG16)

    参考资料: <深度学习之pytorch实战计算机视觉> Pytorch官方教程 Pytorch官方文档 本文是采用pytorch进行迁移学习的实战演练,实战目的是为了进一步学习和熟悉pyt ...

  10. Pytorch实战第一步--用经典神经网络实现猫狗大战

    文章目录 前言 一.猫狗大战数据集 二.pytorch实战 1.程序整体结构 2.读入数据 3.网络结构 4.网络结构 5测试 总结 总结 前言 随着人工智能的不断发展,机器学习这门技术也越来越重要, ...

最新文章

  1. Oracle truncate、 delete、 drop区别
  2. IMail Server 8.22安装、注册
  3. cas跨域单点登录原理_CAS实现SSO单点登录原理
  4. 按值传递时 php必须复制值,PHP笔试题汇总
  5. JAVA 操作系统已经来到第五个版本了 现陆续放出三个版本 这是第二个版本
  6. mysql获取一个表的数据作为值插入_请问如何在mysql中得到一个即将插入数据表中的那条数据的id值(id自增长)?...
  7. CardLayout布局练习(小的图片浏览器)
  8. python二分法查找程序_查找Python程序的输出| 套装2(基础)
  9. idea显示左边project栏和隐藏project栏的快捷键
  10. 深入浅出python中文版pdf-深入浅出Python 巴里著 中文 PDF版 [37M]
  11. 服务器启动显示fr 01,X3850X5服务器无法开机故障处理-微码升级
  12. NC6 UAP流程平台适配 类 nc.itf.scmpub.reference.uap.pf.PfServiceScmUtil
  13. 生成专题2 | 图像生成评价指标FID
  14. 计算机辅助地理教学的利和弊,浅谈多媒体在高效地理课堂中的利和弊
  15. 【Maya开发基础】全局缩放补偿
  16. centos kvm镜像
  17. linux内核原子操作的实现
  18. Mongodb和ElasticSearch(ES)---未完待续
  19. 香港理工大学智能计算实验室招收进化计算/机器学习/类脑计算方向全奖博士生/研究助理/博士后...
  20. 计算机科学导论实验(一)

热门文章

  1. react源码分析:babel如何解析jsx
  2. linux下对IP地址的转发和端口的伪装----利用iptables部署
  3. 为大众而写的程序员小说——从 简单易懂的现代魔法 说开去
  4. adb通过局域网连接手机
  5. FoxBarcode(一维码生成库)使用教程
  6. 视觉检测系统设计过程中遇到的问题
  7. sony相机二次开发sdK C语言,sdk与开放API协议支持二次开发的摄像头
  8. 360Lib整体介绍
  9. login.defs文件基础
  10. ABTest之最常见的八个错误