糖尿病视网膜病变检测 (Diabetic Retinopathy Detection)
文章目录
- 前言
- 一、任务目标
- 二、数据处理
- 1.数据分析
- 2.模型训练
- 2.1模型准备
- 2.2参数设置
- 2.3读取数据并转换为tensor类型
- 2.4开始训练
- 2.5测试模型准确率
- 三、未完成的问题
前言
我们今年是一个实践就是 糖尿病视网膜病变检测
这个题目是kaggle的一个竞赛原题,Diabetic Retinopathy Detection。
只不过kaggle上是五分类,而我们是四分类。
一、任务目标
这次任务的数据集是1000的糖网的4个等级的眼底图像,我们需要利用深度学习框架pytorch 来根据眼底图像预测其分类。
二、数据处理
1.数据分析
通过对数据统计可以得到(已经划分的训练集):
图片种类的分布是有一点不均匀的,同时图片数量也有一点少,所以我们先简单的对图片数据进行一下扩充,这里我们使用最简单的图片反转作为数据增强的方式。我们对一类的图片进行的左右翻转和上下翻转,扩充为原来的三倍。对二类和三类的图片我们做了上下翻转,扩充到原来的二倍。对零类图片不做任何处理。
这是处理后的训练集分布。
我是7:3分割的训练集和验证集
2.模型训练
2.1模型准备
使用的模型是torchvision.model里的经典模型和预训练好的参数。
from torchvision import models as models
# inception_v3,ResNet50
model = models.resnet50(pretrained=True)
#将pretrained置为true,意思是使用已经预训练好的参数。
model.fc#打印模型全连接层的输入和输出参数
#Linear(in_features=2048, out_features=1000, bias=True)
因为我们是四分类所以调整模型输出为:
model.fc = torch.nn.Linear(in_features=2048, out_features=4, bias=True)
model.aux_logits = False #这个设置是InceptionV3这个模型需要设置的,
#不知道什么意思,但不设置会报错。
2.2参数设置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device=device)
learning_rate = 1e-4
num_epochs = 10
batch_size = 32
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
loss_criterion = torch.nn.CrossEntropyLoss()
2.3读取数据并转换为tensor类型
这一部分我是定义了一个类mydataset继承父类Dataset来获取可迭代的数据对象,同时对图片的处理和transform转换也在这里面实现。这里就不多说,对dataset不懂得可以看我之前写的dataset类。直接贴代码。
my_transform = transforms.Compose([transforms.Resize((299,299)),transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
#这里进行transform是因为inception_v3模型的输入是(299*299)
#resnet就不需要是(299*299)了
class retinaDataset(Dataset):def __init__(self, imagepath=r"D:\course\junior_2\deep_learning\third\train", csv_path=" ",transform=my_transform):self.df = pd.read_csv(csv_path)# if (total is not None):# self.df = self.df[:total]self.transform = transformself.imagepath = imagepathdef __len__(self):return len(self.df)def __getitem__(self, index):img_path = os.path.join(self.imagepath, self.df.iloc[index].image +".png")img = Image.open(img_path)if(self.transform):img = self.transform(img)return img, torch.tensor(self.df.iloc[index].Retinopathy_grade)
train_dataset = retinaDataset(csv_path=r"D:\course\junior_2\deep_learning\mythird\train.csv")train_dataloader = DataLoader(dataset=train_dataset,
batch_size=batch_size, shuffle=True)
#这里调用Dataloder函数对数据进行分组并打乱顺序。
2.4开始训练
for epoch in range(num_epochs):for data, target in tqdm(train_dataloader):data = data.to(device=device)target = target.to(device=device)score = model(data)optimizer.zero_grad()loss = loss_criterion(score, target)loss.backward()optimizer.step()print(f"for epoch {epoch}, loss : {loss}")
2.5测试模型准确率
def f_check_accuracy(model_i,model_r, loader):model_i.eval() #模型inception_v3model_r.eval() #模型resnet50num0=0num1=0num2=0num3=0total0=0total1=0total2=0total3=0correct_output = 0total_output = 0with torch.no_grad(): #反向传播时不再自动求导,节省显存。for x, y in tqdm(loader):x = x.to(device=device)y = y.to(device=device)score_i = model_i(x)score_r = model_r(x)_,predictions_i = score_i.max(1)_,predictions_r = score_r.max(1)for i in range (len(y)):if(y[i]==0):total0=total0+1if(predictions_i[i]==0):num0=num0+1elif(y[i]==1):total1=total1+1if(predictions_r[i]==1):num1=num1+1elif(y[i]==2):total2=total2+1if(predictions_r[i]==2):num2=num2+1 elif(y[i]==3):total3=total3+1if(predictions_i[i]==3):num3=num3+1 correct_output =num0+num1+num2+num3total_output =total0+total1+total2+total3# model.train()print("0类准确率",num0/total0,"correct:",num0,"total:",total0)print("1类准确率",num1/total1,"correct:",num1,"total:",total1)print("2类准确率",num2/total2,"correct:",num2,"total:",total2)print("3类准确率",num3/total3,"correct:",num3,"total:",total3)print(f"out of {total_output} , total correct: {correct_output} with an accuracy of {float(correct_output/total_output)*100}")
解释一下我为什么要把inception_v3和resnet50结合到一起。
我们可以看到v3和resnet50再不同种类的准确率不同。所以可以把这两个模型结合在一起来提高准确率。
可以看到准确率提升的效果很好,大约20个点左右。
我们还可以从kaggle中下载一些数据来补充训练集,因为1000图片在划分之后对于四分类的任务来说是不够的。
三、未完成的问题
这行代码的作用。
model.aux_logits = False
糖尿病视网膜病变检测 (Diabetic Retinopathy Detection)相关推荐
- 基于逻辑回归(Logistic Regression)的糖尿病视网膜病变(Diabetic Retinopathy)检测
基于逻辑回归的糖尿病视网膜病变检测 说明 数据集 探索性数据分析 方法 结果 代码 说明 这是我学机器学习的一个项目, 基于逻辑回归(Logistic Regression)的糖尿病视网膜病变(Dia ...
- Python基于逻辑回归的糖尿病视网膜病变检测(数据集messidor_features.arff)
一. 引言 本项目基于逻辑回归理论,运用Python语言对数据集messidor_features.arff进行分析,实现对糖尿病视网膜病变的检测.糖尿病视网膜病变(DR)是糖尿病最常见的微血管并发症 ...
- Idx推出AI系统检测糖尿病视网膜病变
文章来源:ATYUN AI平台 AI正在成为几乎所有行业的关键工具,但AI的一个特别强大的应用是医疗保健,人们已经看到了其潜力. 位于爱荷华州的Idx是一家使用AI来检测特定医疗状况的早期迹象的初创公 ...
- 糖尿病视网膜病变预测模型-机器学习-人工智能-企业科研
糖尿病性视网膜病变是糖尿病的一种并发症,由高血糖水平损害眼睛后部(视网膜)引起.如果不加以诊断和治疗,可能会导致失明.任何患有 1 型糖尿病或 2 型糖尿病的人都有可能患上糖尿病性视网膜病变. 然而, ...
- 糖尿病视网膜病变预测模型-机器学习-人工智能
糖尿病性视网膜病变是糖尿病的一种并发症,由高血糖水平损害眼睛后部(视网膜)引起.如果不加以诊断和治疗,可能会导致失明.任何患有 1 型糖尿病或 2 型糖尿病的人都有可能患上糖尿病性视网膜病变. 然而, ...
- 利用EfficientNet-B5从眼底摄影检测糖尿病视网膜病变的严重程度
Abstract 糖尿病性视网膜病变(DR)在许多糖尿病患者中普遍存在.它是糖尿病患者视网膜组织损伤的一种非常重要的疾病.甚至,在极端情况下,它可能会导致长期患DR的患者永久失明.因此,有必要尽快诊断 ...
- kaggle糖尿病视网膜病变失明检测top5解决方案
比赛背景 想象一下,在失明发生之前就能够发现病变.数以百万计的人患有糖尿病性视网膜病变,这是导致老年人失明的主要原因.印度的Aravind眼科医院希望在农村地区的人们中发现并预防这种疾病,而那里的医疗 ...
- 糖尿病视网膜病变的深度学习系统笔记
糖尿病视网膜病变的深度学习系统笔记 论文地址:A deep learning system for detecting diabetic retinopathy across the disease ...
- Eyenuk宣布FDA核准EyeArt自主AI系统用于糖尿病视网膜病变筛查
EyeArt是FDA首次核准用于自主检测轻度以上和威胁视力的糖尿病视网膜病变的AI技术 洛杉矶--(美国商业资讯)--Eyenuk, Inc.是一家全球性人工智能(AI)医疗技术和服务公司,是AI眼病 ...
最新文章
- JVM虚拟机参数配置官方文档
- 从oracle中读取图片,从oracle数据库中读取图片,在jsp?
- 八个最好的开源机器学习框架和库
- window的文件能在linux,在Linux和Window系统中生成任意大小文件
- C# ASP.NET MVC 图片上传的多种方式(存储至服务器文件夹,阿里云oss)
- C语言必须写main函数?最简单的 Hello world 你其实一点都不懂!
- franz ubuntu_重新审视Unix理念,持续测试,Franz,Gitbase,Python,Linux等
- html 中用canvas加载图片,【实例】使用canvas缓缓加载一个图片到web页面中
- CentOS 与 Ubuntu:哪个更适合做服务器?
- 你可能不知道的10条SQL技巧,涨知识了!
- d3js fill与class优先级
- 《HBase权威指南》读书笔记(二)
- TRANSCAD基础技巧——质心连杆生成不了?
- ”记录集为只读“怎么解决?请高手帮忙看看。感激不尽……
- Unity中Obi绳子设置
- Google卫星地图定位(Resources)
- 响铃:只做“连接器”,企业微信如何实现“人即服务”
- 【项目实战】C/C++轻松实现4399小游戏:围住神经猫
- 斐波那契数列(前30)Python
- 漏洞扫描的应用范围和场景