文章目录

  • 前言
  • 一、任务目标
  • 二、数据处理
    • 1.数据分析
    • 2.模型训练
      • 2.1模型准备
      • 2.2参数设置
      • 2.3读取数据并转换为tensor类型
      • 2.4开始训练
      • 2.5测试模型准确率
  • 三、未完成的问题

前言

我们今年是一个实践就是 糖尿病视网膜病变检测
这个题目是kaggle的一个竞赛原题,Diabetic Retinopathy Detection。
只不过kaggle上是五分类,而我们是四分类。


一、任务目标

这次任务的数据集是1000的糖网的4个等级的眼底图像,我们需要利用深度学习框架pytorch 来根据眼底图像预测其分类。

二、数据处理

1.数据分析

通过对数据统计可以得到(已经划分的训练集):

图片种类的分布是有一点不均匀的,同时图片数量也有一点少,所以我们先简单的对图片数据进行一下扩充,这里我们使用最简单的图片反转作为数据增强的方式。我们对一类的图片进行的左右翻转和上下翻转,扩充为原来的三倍。对二类和三类的图片我们做了上下翻转,扩充到原来的二倍。对零类图片不做任何处理。

这是处理后的训练集分布。
我是7:3分割的训练集和验证集

2.模型训练


2.1模型准备

使用的模型是torchvision.model里的经典模型和预训练好的参数。

from torchvision import models as models
# inception_v3,ResNet50
model = models.resnet50(pretrained=True)
#将pretrained置为true,意思是使用已经预训练好的参数。
model.fc#打印模型全连接层的输入和输出参数
#Linear(in_features=2048, out_features=1000, bias=True)

因为我们是四分类所以调整模型输出为:

model.fc = torch.nn.Linear(in_features=2048, out_features=4, bias=True)
model.aux_logits = False #这个设置是InceptionV3这个模型需要设置的,
#不知道什么意思,但不设置会报错。

2.2参数设置

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device=device)
learning_rate = 1e-4
num_epochs = 10
batch_size = 32
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
loss_criterion = torch.nn.CrossEntropyLoss()

2.3读取数据并转换为tensor类型

这一部分我是定义了一个类mydataset继承父类Dataset来获取可迭代的数据对象,同时对图片的处理和transform转换也在这里面实现。这里就不多说,对dataset不懂得可以看我之前写的dataset类。直接贴代码。

my_transform = transforms.Compose([transforms.Resize((299,299)),transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
#这里进行transform是因为inception_v3模型的输入是(299*299)
#resnet就不需要是(299*299)了
class retinaDataset(Dataset):def __init__(self, imagepath=r"D:\course\junior_2\deep_learning\third\train", csv_path=" ",transform=my_transform):self.df = pd.read_csv(csv_path)# if (total is not None):#     self.df = self.df[:total]self.transform = transformself.imagepath = imagepathdef __len__(self):return len(self.df)def __getitem__(self, index):img_path = os.path.join(self.imagepath, self.df.iloc[index].image +".png")img = Image.open(img_path)if(self.transform):img = self.transform(img)return img, torch.tensor(self.df.iloc[index].Retinopathy_grade)
train_dataset = retinaDataset(csv_path=r"D:\course\junior_2\deep_learning\mythird\train.csv")train_dataloader = DataLoader(dataset=train_dataset,
batch_size=batch_size, shuffle=True)
#这里调用Dataloder函数对数据进行分组并打乱顺序。

2.4开始训练

for epoch in range(num_epochs):for data, target in tqdm(train_dataloader):data = data.to(device=device)target = target.to(device=device)score = model(data)optimizer.zero_grad()loss = loss_criterion(score, target)loss.backward()optimizer.step()print(f"for epoch {epoch}, loss : {loss}")

2.5测试模型准确率

def f_check_accuracy(model_i,model_r, loader):model_i.eval() #模型inception_v3model_r.eval() #模型resnet50num0=0num1=0num2=0num3=0total0=0total1=0total2=0total3=0correct_output = 0total_output = 0with torch.no_grad(): #反向传播时不再自动求导,节省显存。for x, y in tqdm(loader):x = x.to(device=device)y = y.to(device=device)score_i = model_i(x)score_r = model_r(x)_,predictions_i = score_i.max(1)_,predictions_r = score_r.max(1)for i in range (len(y)):if(y[i]==0):total0=total0+1if(predictions_i[i]==0):num0=num0+1elif(y[i]==1):total1=total1+1if(predictions_r[i]==1):num1=num1+1elif(y[i]==2):total2=total2+1if(predictions_r[i]==2):num2=num2+1   elif(y[i]==3):total3=total3+1if(predictions_i[i]==3):num3=num3+1 correct_output =num0+num1+num2+num3total_output =total0+total1+total2+total3# model.train()print("0类准确率",num0/total0,"correct:",num0,"total:",total0)print("1类准确率",num1/total1,"correct:",num1,"total:",total1)print("2类准确率",num2/total2,"correct:",num2,"total:",total2)print("3类准确率",num3/total3,"correct:",num3,"total:",total3)print(f"out of {total_output} , total correct: {correct_output} with an accuracy of {float(correct_output/total_output)*100}")

解释一下我为什么要把inception_v3和resnet50结合到一起。

我们可以看到v3和resnet50再不同种类的准确率不同。所以可以把这两个模型结合在一起来提高准确率。

可以看到准确率提升的效果很好,大约20个点左右。
我们还可以从kaggle中下载一些数据来补充训练集,因为1000图片在划分之后对于四分类的任务来说是不够的。

三、未完成的问题

这行代码的作用。

model.aux_logits = False

糖尿病视网膜病变检测 (Diabetic Retinopathy Detection)相关推荐

  1. 基于逻辑回归(Logistic Regression)的糖尿病视网膜病变(Diabetic Retinopathy)检测

    基于逻辑回归的糖尿病视网膜病变检测 说明 数据集 探索性数据分析 方法 结果 代码 说明 这是我学机器学习的一个项目, 基于逻辑回归(Logistic Regression)的糖尿病视网膜病变(Dia ...

  2. Python基于逻辑回归的糖尿病视网膜病变检测(数据集messidor_features.arff)

    一. 引言 本项目基于逻辑回归理论,运用Python语言对数据集messidor_features.arff进行分析,实现对糖尿病视网膜病变的检测.糖尿病视网膜病变(DR)是糖尿病最常见的微血管并发症 ...

  3. Idx推出AI系统检测糖尿病视网膜病变

    文章来源:ATYUN AI平台 AI正在成为几乎所有行业的关键工具,但AI的一个特别强大的应用是医疗保健,人们已经看到了其潜力. 位于爱荷华州的Idx是一家使用AI来检测特定医疗状况的早期迹象的初创公 ...

  4. 糖尿病视网膜病变预测模型-机器学习-人工智能-企业科研

    糖尿病性视网膜病变是糖尿病的一种并发症,由高血糖水平损害眼睛后部(视网膜)引起.如果不加以诊断和治疗,可能会导致失明.任何患有 1 型糖尿病或 2 型糖尿病的人都有可能患上糖尿病性视网膜病变. 然而, ...

  5. 糖尿病视网膜病变预测模型-机器学习-人工智能

    糖尿病性视网膜病变是糖尿病的一种并发症,由高血糖水平损害眼睛后部(视网膜)引起.如果不加以诊断和治疗,可能会导致失明.任何患有 1 型糖尿病或 2 型糖尿病的人都有可能患上糖尿病性视网膜病变. 然而, ...

  6. 利用EfficientNet-B5从眼底摄影检测糖尿病视网膜病变的严重程度

    Abstract 糖尿病性视网膜病变(DR)在许多糖尿病患者中普遍存在.它是糖尿病患者视网膜组织损伤的一种非常重要的疾病.甚至,在极端情况下,它可能会导致长期患DR的患者永久失明.因此,有必要尽快诊断 ...

  7. kaggle糖尿病视网膜病变失明检测top5解决方案

    比赛背景 想象一下,在失明发生之前就能够发现病变.数以百万计的人患有糖尿病性视网膜病变,这是导致老年人失明的主要原因.印度的Aravind眼科医院希望在农村地区的人们中发现并预防这种疾病,而那里的医疗 ...

  8. 糖尿病视网膜病变的深度学习系统笔记

    糖尿病视网膜病变的深度学习系统笔记 论文地址:A deep learning system for detecting diabetic retinopathy across the disease ...

  9. Eyenuk宣布FDA核准EyeArt自主AI系统用于糖尿病视网膜病变筛查

    EyeArt是FDA首次核准用于自主检测轻度以上和威胁视力的糖尿病视网膜病变的AI技术 洛杉矶--(美国商业资讯)--Eyenuk, Inc.是一家全球性人工智能(AI)医疗技术和服务公司,是AI眼病 ...

最新文章

  1. JVM虚拟机参数配置官方文档
  2. 从oracle中读取图片,从oracle数据库中读取图片,在jsp?
  3. 八个最好的开源机器学习框架和库
  4. window的文件能在linux,在Linux和Window系统中生成任意大小文件
  5. C# ASP.NET MVC 图片上传的多种方式(存储至服务器文件夹,阿里云oss)
  6. C语言必须写main函数?最简单的 Hello world 你其实一点都不懂!
  7. franz ubuntu_重新审视Unix理念,持续测试,Franz,Gitbase,Python,Linux等
  8. html 中用canvas加载图片,【实例】使用canvas缓缓加载一个图片到web页面中
  9. CentOS 与 Ubuntu:哪个更适合做服务器?
  10. 你可能不知道的10条SQL技巧,涨知识了!
  11. d3js fill与class优先级
  12. 《HBase权威指南》读书笔记(二)
  13. TRANSCAD基础技巧——质心连杆生成不了?
  14. ”记录集为只读“怎么解决?请高手帮忙看看。感激不尽……
  15. Unity中Obi绳子设置
  16. Google卫星地图定位(Resources)
  17. 响铃:只做“连接器”,企业微信如何实现“人即服务”
  18. 【项目实战】C/C++轻松实现4399小游戏:围住神经猫
  19. 斐波那契数列(前30)Python
  20. 漏洞扫描的应用范围和场景

热门文章

  1. cmw测试ble_如何测试CC2640的BLE射频指标(一)
  2. 1 微信公众平台数据统计功能的作用是什么?
  3. [转]区块链代码快速学习实践
  4. MUR1060AC-ASEMI超快恢复二极管、10A快恢复二极管
  5. opencv 播放mp4
  6. Windows 11家庭版
  7. DAS\NAS\SAN\IPSAN区别
  8. 致2020年的高考:教育改变命运
  9. 微信公众号开启开发者模式
  10. CVBS/AHD 转换 BT656/BT601