(CNS Reproduction) CLAM - Chapter_03

CLAM: A Deep-Learning-based Pipeline for Data Efficient and Weakly Supervised Whole-Slide-level Analysis

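Contents

  • (CNS Reproduction) CLAM - Chapter_03
  • Preface
    • Step-01: imports
    • Step-02: Initialize the configuration
    • Step-03: Debug the code
      • Step-03.01 Initialize the functions
      • Step-03.02 Initialize the model
      • Step-03.03 Debug the main function
  • Summary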

Preface

(CNS Reproduction) CLAM - Chapter_00

(CNS Reproduction) CLAM - Chapter_01

(CNS Reproduction) CLAM - Chapter_02

(CNS Reproduction) CLAM - Chapter_03

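An important point raised in the previous chapter:

Because every WSI has a different size, the number of patches (i.e. features/channels) also differs from slide to slide, which makes building the model considerably harder.

There are generally two ways to handle this:

  1. Build an adaptive model that generates a different model for each input

  2. Build two models: the first extracts features, and the second performs classification

The second approach is clearly the easier one to implement.

In the official instructions, this corresponds to extract_features_fp.py, which performs the feature extraction and processing.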
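The two-model idea can be sketched as follows. A frozen encoder maps every patch to a vector of the same length, so a slide with N patches becomes an N×D matrix. N still varies between slides, but D does not, so the second model only needs an aggregation step that works for any N. In this sketch the encoder is a hypothetical random projection standing in for ResNet-50, and mean pooling is used purely for illustration; it is not CLAM's aggregation method.

```python
import numpy as np

rng = np.random.default_rng(0)

PATCH_PIXELS = 16 * 16 * 3   # toy patch size, far smaller than a real 256x256 RGB patch
FEATURE_DIM = 1024           # fixed embedding width, matching the features in this chapter

# Hypothetical stand-in for the frozen stage-1 encoder: a fixed random projection
W_encoder = rng.standard_normal((PATCH_PIXELS, FEATURE_DIM)) / np.sqrt(PATCH_PIXELS)


def encode_slide(patches: np.ndarray) -> np.ndarray:
    """Stage 1: map (N, PATCH_PIXELS) flattened patches to an (N, FEATURE_DIM) bag."""
    return patches @ W_encoder


def pool_bag(bag: np.ndarray) -> np.ndarray:
    """Stage 2 input: collapse a bag of any length N into one FEATURE_DIM vector."""
    return bag.mean(axis=0)


# Two toy "slides" with different numbers of patches
slide_a = rng.random((7, PATCH_PIXELS))
slide_b = rng.random((3, PATCH_PIXELS))

bag_a, bag_b = encode_slide(slide_a), encode_slide(slide_b)
print(bag_a.shape, bag_b.shape)                      # (7, 1024) (3, 1024)
print(pool_bag(bag_a).shape, pool_bag(bag_b).shape)  # (1024,) (1024,)
```

Only the first model ever sees raw pixels; the second works on the fixed-width bags, which is what makes the split easier to implement.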


Step-01: imports

# imports
import torch
import torch.nn as nn
from math import floor
import os
import random
import numpy as np
import pandas as pd
import pdb
import time
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils, models
from PIL import Image
import h5py
import openslide
import warnings

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# my imports
from models.resnet_custom import resnet50_baseline  # ResNet-50 backbone pretrained on ImageNet
from utils.file_utils import save_hdf5  # helper for saving .h5 files, used in the previous chapter

# other options
warnings.filterwarnings("ignore")

Step-02: Initialize the configuration

The settings below mirror the information given in args.

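# Feature Extraction
data_h5_dir = '/media/yuansh/14THHD/CLAM/toy_test'  # directory containing the patch .h5 files
data_slide_dir = '/media/yuansh/14THHD/CLAM/DataSet/toy_example'  # directory of the raw slides
slide_ext = '.svs'  # file extension of the raw slides
csv_path = '/media/yuansh/14THHD/CLAM/Step_2.csv'  # path to Step_2.csv; how to generate it is covered in Chapter_00
feat_dir = '/media/yuansh/14THHD/CLAM/FEATURES_DIRECTORY'  # output directory
batch_size = 512  # batch size for feature extraction (inference, not training)
no_auto_skip = False  # False: automatically skip slides that were already processed
custom_downsample = 1  # custom downsample factor (not used here)
target_patch_size = -1  # target patch size for rescaling (not used here)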
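In extract_features_fp.py these values arrive as command-line arguments. The sketch below shows one way to expose the same variables with argparse; the flag names are assumed to mirror the variable names and the defaults are illustrative, so check them against the script's own parser before relying on them.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Flag names are assumed to mirror the variables above; verify against the real script
    parser = argparse.ArgumentParser(description='Feature Extraction (sketch)')
    parser.add_argument('--data_h5_dir', type=str, default=None)
    parser.add_argument('--data_slide_dir', type=str, default=None)
    parser.add_argument('--slide_ext', type=str, default='.svs')
    parser.add_argument('--csv_path', type=str, default=None)
    parser.add_argument('--feat_dir', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--no_auto_skip', default=False, action='store_true')
    parser.add_argument('--custom_downsample', type=int, default=1)
    parser.add_argument('--target_patch_size', type=int, default=-1)
    return parser


# Parse an explicit argument list instead of sys.argv, as a notebook-friendly example
args = build_parser().parse_args([
    '--csv_path', 'Step_2.csv',
    '--feat_dir', 'FEATURES_DIRECTORY',
    '--batch_size', '512',
])
print(args.batch_size, args.no_auto_skip, args.target_patch_size)  # 512 False -1
```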

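Step-03: Debug the code

This part only debugs compute_w_loader; its definition is given below.

It relies on several nested helper functions, though, so we will go through them one at a time.

Step-03.01 Initialize the functions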

# main function
def compute_w_loader(file_path, output_path, wsi, model,
                     batch_size=8, verbose=0, print_every=20, pretrained=True,
                     custom_downsample=1, target_patch_size=-1):
    """
    args:
        file_path: directory of bag (.h5 file)
        output_path: directory to save computed features (.h5 file)
        model: pytorch model
        batch_size: batch_size for computing features in batches
        verbose: level of feedback
        pretrained: use weights pretrained on imagenet
        custom_downsample: custom defined downscale factor of image patches
        target_patch_size: custom defined, rescaled image size before embedding
    """
    dataset = Whole_Slide_Bag_FP(file_path=file_path, wsi=wsi, pretrained=pretrained,
                                 custom_downsample=custom_downsample,
                                 target_patch_size=target_patch_size)
    x, y = dataset[0]
    kwargs = {'num_workers': 4, 'pin_memory': True} if device.type == "cuda" else {}
    loader = DataLoader(dataset=dataset, batch_size=batch_size, **kwargs,
                        collate_fn=collate_features)

    if verbose > 0:
        print('processing {}: total of {} batches'.format(file_path, len(loader)))

    mode = 'w'
    for count, (batch, coords) in enumerate(loader):
        with torch.no_grad():
            if count % print_every == 0:
                print('batch {}/{}, {} files processed'.format(count, len(loader), count * batch_size))
            batch = batch.to(device, non_blocking=True)
            mini_bs = coords.shape[0]

            features = model(batch)
            features = features.cpu().numpy()

            asset_dict = {'features': features, 'coords': coords}
            save_hdf5(output_path, asset_dict, attr_dict=None, mode=mode)
            mode = 'a'

    return output_path


# Reads the sample (slide) IDs from the CSV file
class Dataset_All_Bags(Dataset):

    def __init__(self, csv_path):
        self.df = pd.read_csv(csv_path)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        return self.df['slide_id'][idx]


class Whole_Slide_Bag_FP(Dataset):
    def __init__(self,
                 file_path,
                 wsi,
                 pretrained=False,
                 custom_transforms=None,
                 custom_downsample=1,
                 target_patch_size=-1):
        """
        Args:
            file_path (string): Path to the .h5 file containing patched data.
            pretrained (bool): Use ImageNet transforms
            custom_transforms (callable, optional): Optional transform to be applied on a sample
            custom_downsample (int): Custom defined downscale factor (overruled by target_patch_size)
            target_patch_size (int): Custom defined image size before embedding
        """
        self.pretrained = pretrained
        self.wsi = wsi
        if not custom_transforms:
            self.roi_transforms = eval_transforms(pretrained=pretrained)
        else:
            self.roi_transforms = custom_transforms

        self.file_path = file_path

        with h5py.File(self.file_path, "r") as f:
            print('\n')
            dset = f['coords']
            print(dset)
            print('\n')
            self.patch_level = f['coords'].attrs['patch_level']
            self.patch_size = f['coords'].attrs['patch_size']
            self.length = len(dset)
            if target_patch_size > 0:
                self.target_patch_size = (target_patch_size, ) * 2
            elif custom_downsample > 1:
                self.target_patch_size = (self.patch_size // custom_downsample, ) * 2
            else:
                self.target_patch_size = None
        self.summary()

    def __len__(self):
        return self.length

    def summary(self):
        # context manager so the .h5 file handle is closed afterwards
        with h5py.File(self.file_path, "r") as hdf5_file:
            dset = hdf5_file['coords']
            for name, value in dset.attrs.items():
                print(name, value)

        print('\nfeature extraction settings')
        print('target patch size: ', self.target_patch_size)
        print('pretrained: ', self.pretrained)
        print('transformations: ', self.roi_transforms)

    def __getitem__(self, idx):
        with h5py.File(self.file_path, 'r') as hdf5_file:
            coord = hdf5_file['coords'][idx]
        img = self.wsi.read_region(coord, self.patch_level, (self.patch_size, self.patch_size)).convert('RGB')

        if self.target_patch_size is not None:
            img = img.resize(self.target_patch_size)
        img = self.roi_transforms(img).unsqueeze(0)
        return img, coord


# Collate function: merges single patches into one batch tensor plus a coordinate array
def collate_features(batch):
    img = torch.cat([item[0] for item in batch], dim=0)
    coords = np.vstack([item[1] for item in batch])
    return [img, coords]


# Prints the model architecture and its parameter counts
def print_network(net):
    num_params = 0
    num_params_train = 0
    print(net)

    for param in net.parameters():
        n = param.numel()
        num_params += n
        if param.requires_grad:
            num_params_train += n

    print('Total number of parameters: %d' % num_params)
    print('Total number of trainable parameters: %d' % num_params_train)


# Image normalization
def eval_transforms(pretrained=False):
    if pretrained:
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    trnsfrms_val = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ]
    )
    return trnsfrms_val
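To make eval_transforms and collate_features concrete, here is a NumPy-only sketch of what they do to the data. ToTensor rescales an H×W×3 uint8 image to [0, 1] and moves channels first, Normalize applies (x - mean) / std per channel, and the collate step concatenates the (1, 3, H, W) arrays returned by __getitem__ into one (B, 3, H, W) batch while stacking the coordinates into (B, 2). The helper names are hypothetical NumPy stand-ins, not torchvision code.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def to_tensor_and_normalize(img_hwc: np.ndarray) -> np.ndarray:
    """Mimic ToTensor + Normalize: uint8 HxWx3 -> float 3xHxW, normalized per channel."""
    x = img_hwc.astype(np.float64) / 255.0      # ToTensor: scale to [0, 1]
    x = x.transpose(2, 0, 1)                    # HWC -> CHW
    return (x - IMAGENET_MEAN) / IMAGENET_STD   # Normalize


def collate(items):
    """Mimic collate_features: items are (1x3xHxW image, (x, y) coord) pairs."""
    imgs = np.concatenate([img for img, _ in items], axis=0)  # like torch.cat(dim=0)
    coords = np.vstack([coord for _, coord in items])         # (B, 2)
    return imgs, coords


rng = np.random.default_rng(0)
items = []
for i in range(4):
    patch = rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8)  # toy 8x8 patch
    img = to_tensor_and_normalize(patch)[np.newaxis]              # like .unsqueeze(0)
    items.append((img, np.array([i * 256, 0])))

batch, coords = collate(items)
print(batch.shape, coords.shape)  # (4, 3, 8, 8) (4, 2)
```

The unsqueeze(0) in __getitem__ is what makes torch.cat(dim=0) the right operation in collate_features: each item already carries a batch axis of length 1.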

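Step-03.02 Initialize the model

The paper uses an ImageNet-pretrained ResNet-50 backbone (loaded here through resnet50_baseline).

The visualization of the model architecture is too large to show, so it is omitted here.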
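Why are the features 1024-dimensional? Assuming resnet50_baseline keeps only the stem and the first three residual stages of ResNet-50, followed by global average pooling (which is how CLAM's models/resnet_custom.py appears to define it), the output size follows from the standard formula out = floor((in + 2·pad - kernel) / stride) + 1. The sketch traces a 256×256 patch through those layers; the channel count 1024 after layer3 is the standard ResNet-50 value.

```python
def conv_out(size: int, kernel: int, stride: int, pad: int) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * pad - kernel) // stride + 1


def resnet50_trunk_shape(size: int):
    """Trace (channels, H, W) through the stem and layer1-layer3 of ResNet-50."""
    s = conv_out(size, kernel=7, stride=2, pad=3)  # conv1: 7x7, stride 2
    s = conv_out(s, kernel=3, stride=2, pad=1)     # maxpool: 3x3, stride 2
    # layer1 keeps the spatial size (stride 1) and outputs 256 channels
    s = conv_out(s, kernel=3, stride=2, pad=1)     # layer2 halves H and W, 512 channels
    s = conv_out(s, kernel=3, stride=2, pad=1)     # layer3 halves H and W, 1024 channels
    return 1024, s, s


c, h, w = resnet50_trunk_shape(256)
print((c, h, w))  # (1024, 16, 16)
# Global average pooling then reduces each 16x16 channel map to a single number,
# leaving one 1024-dimensional feature vector per patch
```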

%%capture
print('initializing dataset')
csv_path = csv_path
if csv_path is None:
    raise NotImplementedError

# Dataset_All_Bags subclasses torch's Dataset;
# it reads the slide IDs from the CSV file
bags_dataset = Dataset_All_Bags(csv_path)

os.makedirs(feat_dir, exist_ok=True)
os.makedirs(os.path.join(feat_dir, 'pt_files'), exist_ok=True)
os.makedirs(os.path.join(feat_dir, 'h5_files'), exist_ok=True)
dest_files = os.listdir(os.path.join(feat_dir, 'pt_files'))

print('loading model checkpoint')
# ResNet-50 pretrained on ImageNet
model = resnet50_baseline(pretrained=True)

# Visualize the model architecture
import hiddenlayer as h
vis_graph = h.build_graph(model, torch.zeros([1, 3, 256, 256]))  # build the graph object to draw
vis_graph.theme = h.graph.THEMES["blue"].copy()                  # set the color theme
vis_graph.save("/home/yuansh/Desktop/demo1.png")                 # output image path

model = model.to(device)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

model.eval()
total = len(bags_dataset)

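Step-03.03 Debug the main function

This part contains a for loop that iterates over all samples, so it is enough to debug a single sample.

This step reads two inputs:

  1. The original WSI file, with the .svs extension

  2. The patch file for that WSI, with the .h5 extension, which stores the patch coordinates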

bag_candidate_idx = 1
slide_id = bags_dataset[bag_candidate_idx].split(slide_ext)[0]
bag_name = slide_id+'.h5'
h5_file_path = os.path.join(data_h5_dir, 'patches', bag_name)
slide_file_path = os.path.join(data_slide_dir, slide_id+slide_ext)
output_path = os.path.join(feat_dir, 'h5_files', bag_name)

Initialize the WSI object

time_start = time.time()
wsi = openslide.open_slide(slide_file_path)

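Next, we need to unpack two nested functions:

  1. compute_w_loader

  2. Whole_Slide_Bag_FP

Initialize the configuration parameters

file_path: path to the patch file (.h5)

output_path: path of the output .h5 file that stores the extracted features

model: the custom model (ResNet is used below)

custom_downsample: custom downscale factor for the image patches

target_patch_size: custom size to which patches are rescaled before embedding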
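How these two parameters interact inside Whole_Slide_Bag_FP can be isolated into a small pure function (resolve_target_size is a hypothetical name). target_patch_size wins if it is positive; otherwise a custom_downsample greater than 1 divides the stored patch size; otherwise no resizing happens. With the values used in this post (1 and -1) the result is None, which is why both settings are effectively unused here.

```python
from typing import Optional, Tuple


def resolve_target_size(patch_size: int,
                        custom_downsample: int = 1,
                        target_patch_size: int = -1) -> Optional[Tuple[int, int]]:
    """Mirror the branch in Whole_Slide_Bag_FP.__init__ that picks the resize target."""
    if target_patch_size > 0:
        return (target_patch_size,) * 2                # explicit size overrides everything
    if custom_downsample > 1:
        return (patch_size // custom_downsample,) * 2  # shrink by an integer factor
    return None                                        # keep patches at their stored size


print(resolve_target_size(256))                                              # None
print(resolve_target_size(256, custom_downsample=2))                         # (128, 128)
print(resolve_target_size(256, custom_downsample=2, target_patch_size=224))  # (224, 224)
```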

file_path = h5_file_path
output_path
wsi
batch_size = 256
verbose = 1
print_every = 20
custom_downsample=1
target_patch_size=-1

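Build the dataset that iterates over the patches. It is essentially the same as the Dataset classes you write when building an ordinary model; only the WSI object needs special handling.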

# Whole_Slide_Bag_FP
class Whole_Slide_Bag_FP(Dataset):
    def __init__(self,
                 file_path,
                 wsi,
                 pretrained=False,
                 custom_transforms=None,
                 custom_downsample=1,
                 target_patch_size=-1):
        # whether to use ImageNet-pretrained normalization
        self.pretrained = pretrained
        # the OpenSlide WSI object
        self.wsi = wsi
        # default preprocessing: convert to a torch tensor and apply the normalization used for ResNet-50
        if not custom_transforms:
            self.roi_transforms = eval_transforms(pretrained=pretrained)
        else:
            self.roi_transforms = custom_transforms

        # path to the patch .h5 file
        self.file_path = file_path
        # read the .h5 file
        with h5py.File(self.file_path, "r") as f:
            # patch coordinates
            dset = f['coords']
            # patch attributes
            self.patch_level = f['coords'].attrs['patch_level']
            self.patch_size = f['coords'].attrs['patch_size']
            self.length = len(dset)
            # With the settings used here both conditions are False, so this can be ignored;
            # otherwise each patch is rescaled or downsampled before embedding
            if target_patch_size > 0:
                self.target_patch_size = (target_patch_size, ) * 2
            elif custom_downsample > 1:
                self.target_patch_size = (self.patch_size // custom_downsample, ) * 2
            else:
                self.target_patch_size = None
        self.summary()

    # number of samples; every Dataset subclass must implement this
    def __len__(self):
        return self.length

    # summary
    def summary(self):
        with h5py.File(self.file_path, "r") as hdf5_file:
            dset = hdf5_file['coords']
            for name, value in dset.attrs.items():
                print(name, value)

        print('\nfeature extraction settings')
        print('target patch size: ', self.target_patch_size)
        print('pretrained: ', self.pretrained)
        print('transformations: ', self.roi_transforms)

    # return one sample
    def __getitem__(self, idx):
        with h5py.File(self.file_path, 'r') as hdf5_file:
            coord = hdf5_file['coords'][idx]
        # read the WSI region covered by this patch:
        # arguments are the patch coordinate, the pyramid level and the patch size; then convert to RGB
        img = self.wsi.read_region(coord, self.patch_level, (self.patch_size, self.patch_size)).convert('RGB')

        if self.target_patch_size is not None:
            img = img.resize(self.target_patch_size)
        img = self.roi_transforms(img).unsqueeze(0)
        return img, coord


# Quick check: crop one patch from the WSI by hand
with h5py.File(file_path, "r") as hdf5_file:
    coord = hdf5_file['coords'][1]
    patch_level = hdf5_file['coords'].attrs['patch_level']
    patch_size = hdf5_file['coords'].attrs['patch_size']
wsi.read_region(coord, patch_level, (patch_size, patch_size)).convert('RGB')

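From the result above, we can see that Whole_Slide_Bag_FP uses the previously saved patch coordinates to crop regions out of the whole WSI, producing a set of patches with a fixed size and a fixed number of channels.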
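One detail worth keeping in mind: OpenSlide's read_region(location, level, size) takes location in level-0 (full-resolution) pixel coordinates, while size is measured at the requested level. A patch read at patch_level therefore covers patch_size × downsample pixels of the level-0 image. The sketch below computes that footprint; level0_box is a hypothetical helper, and the downsample factors are made-up example values, whereas a real slide reports its own through wsi.level_downsamples.

```python
def level0_box(coord, patch_level, patch_size, level_downsamples):
    """Return (x0, y0, x1, y1) in level-0 pixels covered by one patch."""
    x, y = coord                             # level-0 top-left corner, as stored in 'coords'
    scale = level_downsamples[patch_level]   # level-0 pixels per pixel at patch_level
    span = int(round(patch_size * scale))    # side length in level-0 pixels
    return x, y, x + span, y + span


# Hypothetical pyramid: level 0 full resolution, level 1 = 4x, level 2 = 16x downsampled
downsamples = (1.0, 4.0, 16.0)
print(level0_box((1024, 2048), 0, 256, downsamples))  # (1024, 2048, 1280, 2304)
print(level0_box((1024, 2048), 1, 256, downsamples))  # (1024, 2048, 2048, 3072)
```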

Next, let's continue with the remaining steps.

# dataset is the Whole_Slide_Bag_FP from the previous step: it yields the patch images and their coordinates
dataset = Whole_Slide_Bag_FP(file_path=file_path, wsi=wsi, pretrained=True,
                             custom_downsample=custom_downsample,
                             target_patch_size=target_patch_size)
x, y = dataset[0]

# Load the data in batches
kwargs = {'num_workers': 12, 'pin_memory': True} if device.type == "cuda" else {}
loader = DataLoader(dataset=dataset, batch_size=batch_size, **kwargs, collate_fn=collate_features)

# Print progress
if verbose > 0:
    print('processing {}: total of {} batches'.format(file_path, len(loader)))

mode = 'w'
for count, (batch, coords) in enumerate(loader):
    with torch.no_grad():
        if count % print_every == 0:
            print('batch {}/{}, {} files processed'.format(count, len(loader), count * batch_size))
        batch = batch.to(device, non_blocking=True)
        mini_bs = coords.shape[0]

        # Compute the model's features for this batch and save them
        features = model(batch)
        features = features.cpu().numpy()

        asset_dict = {'features': features, 'coords': coords}
        save_hdf5(output_path, asset_dict, attr_dict=None, mode=mode)
        mode = 'a'

‘/media/yuansh/14THHD/CLAM/FEATURES_DIRECTORY/h5_files/C3L-00503-21.h5’

At this point, the data preprocessing is complete.
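Looking back at the loop: the first batch is written with mode='w' (creating or overwriting the file) and every later batch with mode='a', so the per-batch (B, 1024) feature arrays are appended along the first axis until the file holds one row per patch. As far as I can tell, save_hdf5 does this with resizable HDF5 datasets; the sketch below only imitates the same w-then-a pattern with an in-memory dict instead of h5py, so save_batch is a hypothetical stand-in. The patch count 1703 is the one reported for C3L-00503-21 later in this post.

```python
import math
import numpy as np


def save_batch(store: dict, asset_dict: dict, mode: str) -> None:
    """In-memory stand-in for save_hdf5: 'w' starts fresh, 'a' appends along axis 0."""
    if mode == 'w':
        store.clear()
    for key, value in asset_dict.items():
        if key in store:
            store[key] = np.concatenate([store[key], value], axis=0)
        else:
            store[key] = value


n_patches, batch_size, feat_dim = 1703, 512, 1024
rng = np.random.default_rng(0)
store = {}
mode = 'w'
n_batches = math.ceil(n_patches / batch_size)
for count in range(n_batches):
    start = count * batch_size
    mini_bs = min(batch_size, n_patches - start)   # the last batch is smaller
    features = rng.standard_normal((mini_bs, feat_dim)).astype(np.float32)
    coords = np.stack([np.arange(start, start + mini_bs), np.zeros(mini_bs, dtype=int)], axis=1)
    save_batch(store, {'features': features, 'coords': coords}, mode)
    mode = 'a'                                     # every later batch appends

print(n_batches, store['features'].shape, store['coords'].shape)  # 4 (1703, 1024) (1703, 2)
```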

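Honestly, though, just reading the code doesn't make it clear what is happening.

For example, at this point I've claimed that the features of every image have been extracted, 1024 in total.

That is puzzling: how exactly were those features selected?

So I went a step further and looked at the structure of the feature matrix for each slide.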

for bag_candidate_idx in range(5):
    slide_id = bags_dataset[bag_candidate_idx].split(slide_ext)[0]
    bag_name = slide_id + '.h5'
    h5_file_path = os.path.join(data_h5_dir, 'patches', bag_name)
    slide_file_path = os.path.join(data_slide_dir, slide_id + slide_ext)
    print(slide_id)

    if not no_auto_skip and slide_id + '.pt' in dest_files:
        continue

    output_path = os.path.join(feat_dir, 'h5_files', bag_name)
    time_start = time.time()
    wsi = openslide.open_slide(slide_file_path)
    output_file_path = compute_w_loader(h5_file_path, output_path, wsi,
                                        model=model, batch_size=batch_size, verbose=1, print_every=20,
                                        custom_downsample=custom_downsample,
                                        target_patch_size=target_patch_size)
    time_elapsed = time.time() - time_start

    # read the features back and close the file afterwards
    with h5py.File(output_file_path, "r") as file:
        features = file['features'][:]
        print('features size: ', features.shape)
        print('coordinates size: ', file['coords'].shape)

    features = torch.from_numpy(features)
    bag_base, _ = os.path.splitext(bag_name)
    torch.save(features, os.path.join(feat_dir, 'pt_files', bag_base + '.pt'))

C3L-00081-26
features size: (2681, 1024)
coordinates size: (2681, 2)

C3L-00503-21
features size: (1703, 1024)
coordinates size: (1703, 2)

C3L-00503-22
features size: (1755, 1024)
coordinates size: (1755, 2)

C3L-00568-21
features size: (1924, 1024)
coordinates size: (1924, 2)

C3L-00568-22
features size: (1525, 1024)
coordinates size: (1525, 2)
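The loop above skips a slide when a .pt file with the same slide ID already exists in pt_files (unless no_auto_skip is set), so an interrupted run can simply be restarted. A minimal sketch of that check as a pure function; slides_to_process is a hypothetical helper, and the file names are taken from the output above only as examples.

```python
def slides_to_process(slide_files, dest_files, slide_ext='.svs', no_auto_skip=False):
    """Return slide IDs that still need features, mirroring the skip check in the loop."""
    done = set(dest_files)                   # set membership is O(1) per lookup
    todo = []
    for name in slide_files:
        slide_id = name.split(slide_ext)[0]  # same ID extraction as in the loop
        if not no_auto_skip and slide_id + '.pt' in done:
            continue                         # features already saved, skip
        todo.append(slide_id)
    return todo


slides = ['C3L-00081-26.svs', 'C3L-00503-21.svs', 'C3L-00503-22.svs']
already = ['C3L-00081-26.pt']
print(slides_to_process(slides, already))                     # ['C3L-00503-21', 'C3L-00503-22']
print(slides_to_process(slides, already, no_auto_skip=True))  # all three IDs
```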

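Summary

From the feature shapes printed above, each 256×256 patch is compressed into a 1×1024 feature vector.

In other words, this step only brings every slide to the same width: each slide becomes an N×1024 matrix, where the number of rows N (one per patch) still differs between slides.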
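The reduction per patch is easy to quantify: a 256×256 RGB patch holds 256·256·3 values and the encoder keeps 1024 of them, a 192-fold reduction, while a slide with N patches becomes an N×1024 bag.

```latex
\frac{256 \times 256 \times 3}{1024} = \frac{196608}{1024} = 192,
\qquad
\text{slide with } N \text{ patches} \;\longrightarrow\; X \in \mathbb{R}^{N \times 1024}
```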

This raises an interesting question: how are these compressed features actually used in the later analysis?

Take five seconds to think it over:
5
4
3
2
1


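The final goal of the paper is to learn an attention module, which makes it a kind of weakly supervised semantic segmentation model in disguise: every row (patch) of a slide's feature matrix receives its own attention score, and different branches of the model produce different outputs. This is supported by the simple feed-forward network the paper ends up using: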

CLAM_MB(
  (attention_net): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): ReLU()
    (2): Attn_Net_Gated(
      (attention_a): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Tanh()
      )
      (attention_b): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Sigmoid()
      )
      (attention_c): Linear(in_features=256, out_features=2, bias=True)
    )
  )
  (classifiers): ModuleList(
    (0): Linear(in_features=512, out_features=1, bias=True)
    (1): Linear(in_features=512, out_features=1, bias=True)
  )
  (instance_classifiers): ModuleList(
    (0): Linear(in_features=512, out_features=2, bias=True)
    (1): Linear(in_features=512, out_features=2, bias=True)
  )
  (instance_loss_fn): CrossEntropyLoss()
)
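Reading the printed CLAM_MB structure top to bottom suggests the forward pass sketched below in NumPy with random weights. All weight names are hypothetical, and dropout and the instance-level branch (instance_classifiers) are left out. Each 1024-d patch row is projected to 512-d, a gated attention net (tanh branch times sigmoid branch) scores every row once per class, a softmax over the N rows turns those scores into weights, and each class gets its own 512-d slide vector and its own Linear(512, 1) classifier. The shapes follow the printout; CLAM's actual code may differ in details.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D_IN, D_HID, D_ATT, N_CLASSES = 1703, 1024, 512, 256, 2


def linear(d_in, d_out):
    """Random weights and bias for a hypothetical Linear(d_in, d_out) layer."""
    return rng.standard_normal((d_in, d_out)) * 0.02, np.zeros(d_out)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


W_fc, b_fc = linear(D_IN, D_HID)     # (0): Linear(1024, 512), followed by ReLU
W_a, b_a = linear(D_HID, D_ATT)      # attention_a: Linear(512, 256) + Tanh
W_b, b_b = linear(D_HID, D_ATT)      # attention_b: Linear(512, 256) + Sigmoid
W_c, b_c = linear(D_ATT, N_CLASSES)  # attention_c: Linear(256, 2)
classifiers = [linear(D_HID, 1) for _ in range(N_CLASSES)]  # one Linear(512, 1) per class

bag = rng.standard_normal((N, D_IN))  # one slide: N patch features

h = np.maximum(bag @ W_fc + b_fc, 0.0)                     # (N, 512)
gated = np.tanh(h @ W_a + b_a) * sigmoid(h @ W_b + b_b)    # (N, 256)
scores = (gated @ W_c + b_c).T                             # (2, N): one score per class per patch
scores = scores - scores.max(axis=1, keepdims=True)        # numerical stability for the softmax
A = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)  # softmax over the N patches
M = A @ h                                                  # (2, 512): attention-weighted slide vectors
logits = np.array([M[c] @ W + b for c, (W, b) in enumerate(classifiers)]).ravel()  # (2,)

print(A.shape, M.shape, logits.shape)  # (2, 1703) (2, 512) (2,)
```

Each row of A is a per-patch weight map for one class; mapping those weights back onto the patch coordinates is what produces CLAM's attention heatmaps, which is the sense in which the model behaves like a weak segmentation model.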

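Overall, then, the paper is fairly simple; most of the difficulty lies in the data processing.
Still, the idea is genuinely interesting and worth borrowing.

That wraps up this reproduction of the Nature Biomedical Engineering paper!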

If you have thought it through carefully and still cannot follow this post, you can contact me as follows:
Best Regards,
Yuan.SH
---------------------------------------
School of Basic Medical Sciences,
Fujian Medical University,
Fuzhou, Fujian, China.
Please contact me via the following:
(a) e-mail :yuansh3354@163.com
