使用深度学习进行生存分析
本文转自:使用深度学习进行生存分析
相关资源
原论文地址:here
论文中使用的深度生存分析库:DeepSurv,是基于Theano 和 Lasagne库实现的,支持训练网络模型,预测等功能。
考虑到DeepSurv库中存在着一些错误以及未实现的功能,博主使用目前主流的深度学习框架Tensorflow实现了深度生存分析库:TFDeepSurv。欢迎有兴趣的同学Star和Fork,指出错误,相互交流!
TFDeepSurv简介:基于tensorflow的深度生存分析框架,经过模拟数据和真实数据的测试。支持生存分析数据事件时间出现ties的建模,自定义神经网络结构及参数,可视化训练过程,输入训练数据特征重要性分析,病人生存函数的估计。还有支持使用科学的贝叶斯超参数优化方法来调整网络参数。
博主有空会给出TFDeepSurv各个功能实现参考的源论文!
前言
本文主要的目的为了介绍深度学习是如何运用到生存分析中的,包括其基本原理。然后介绍目前实现了利用深度学习进行生存分析的开源软件包 DeepSurv
,它实现了生存分析模型,使用Deep Neural Networks来训练学习参数,并且还实现了风险人群的划分。
还是一样的,强烈建议你去读一下原论文DeepSurv: Personalized Treatment Recommender System Using A Cox Proportional Hazards Deep Neural Network.,相信你会收获很大,至少比看我的好一万倍。写这篇文章的动力有几点,一是不想自己学过的知识什么很快就忘了,感觉记录一下比较重要(博主比较蠢),当作是看论文的笔记吧;二是看完文章之后,觉得我们平时还是要多思考,论文里的思想其实也不是完全原创的,神经网络不是,生存分析cox比例风险模型1972年就有了,但是别人就能洞察到使用深度学习的思想去学习COX模型中需要估计的参数,个人觉得这是一个有科学素养的人才能做到的吧;三是博主在为了PR收集生存分析资料的时候,深感不易,这方面的中文的介绍很少,所以为了方便大家的交流讨论,还是写一下吧。
博主知识水平有限,不吝赐教!欢迎提出错误!
问题来源
假设你已经知道了生存分析主要是在做哪些工作。我们都知道在进行生存分析的时候,有这么几种方法:
- 参数法:当生存时间符合某一个已知分布时,知道了分布函数,那么剩下的就是求解该分布的参数了。
- 非参数法:用KM估计去求生存函数,作生存曲线,这里面不涉及任何参数,主要思想就是频率代替概率。
- 半参数法:也就是使用COX比例风险模型来求生存函数,这个也是本文的重点。
关于COX比例风险模型是怎么提出的,这个是1972年前辈的智慧,本文不打算介绍,这里给出一个链接:hazard-curve,可以帮助你快速了解生存分析和生存函数以及风险曲线的数学定义,然后你就可以去看COX比例风险模型是怎么提出来的了。确保自己懂了COX比例风险模型的原理,可以问自己几个问题:比例两个字是体现在那个地方?为什么风险函数会是h(t)=h0(t)⋅eθ⋅xh(t) = h_0(t)\cdot e^{\theta \cdot x}h(t)=h0(t)⋅eθ⋅x 这种形式?
COX比例风险模型,直接给出了风险函数的数学表达式(假设你已经学会了懂了其背后的数学原理):
h(t)=h0(t)⋅eθ⋅xh(t) = h_0(t)\cdot e^{\theta \cdot x}h(t)=h0(t)⋅eθ⋅x
其中,θ=(θ1,…,θm)θ=(θ_1,…,θ_m)θ=(θ1,…,θm)是线性模型的系数或未知参数,h0(t)h_0 (t)h0(t)是基准风险函数。eθxe^{θx}eθx描述了患者观察到回归变量xxx时的死亡风险比例。对∀i∈N,θ_i>0,表示该协变量是危险因素,越大使生存时间越短。 ∀i∈N,θ_i < 0表示该协变量是保护因素,越大使得生存时间越长。
现在需要去求取参数θ\thetaθ,其思想就是偏似然估计法。假定在某死亡时间没有重复事件发生,设t1<t2<⋯<tkt_1<t_2<⋯<t_kt1<t2<⋯<tk 表示在观察数据中有kkk个不同的死亡事件。设xix_ixi的观察协变量。设R(ti)R(t_i)R(ti)时间仍然处于观察研究的个体集合。则风险函数h(t)h(t)h(t)的参数估计可以用以下偏似然概率估计方法:
pl(θ)=∏i=1keθxi∑j∈R(ti)eθxjpl(\theta) = \prod_{i=1}^{k}\frac{e^{\theta x_i}}{\sum_{j \in R(t_i)}e^{\theta x_j}}pl(θ)=i=1∏k∑j∈R(ti)eθxjeθxi
其中qi=eθxi∑j∈R(ti)eθxjq_i = \frac{e^{\theta x_i}}{\sum_{j \in R(t_i)}e^{\theta x_j}}qi=∑j∈R(ti)eθxjeθxi个死亡个体,其死亡条件概率。其实通俗一点的解释就是:我已经观察到时间tit_iti了,现在有一群人,我可以利用风险公式h(ti)h(t_i)h(ti) 求出这群人每一个个体的死亡风险,其中有一个人恰好在tit_iti时刻发生了死亡事件,那么这个人的死亡条件概率就写为:
qi=hi(ti)∑j∈R(ti)hj(ti)=eθxi∑j∈R(ti)eθxjq_i = \frac{h_i(t_i)}{\sum_{j \in R(t_i)}h_j(t_i)} = \frac{e^{\theta x_i}}{\sum_{j \in R(t_i)}e^{\theta x_j}}qi=∑j∈R(ti)hj(ti)hi(ti)=∑j∈R(ti)eθxjeθxi
现在就是利用偏似然估计的思想,将所有死亡时刻t1,t2,...,tkt_1,t_2,...,t_kt1,t2,...,tk的死亡条件概率相乘,求取是这个乘积最大的参数值θ\thetaθ,把它作为估计量。
注意COX模型给出的前提:假设协变量的总影响可以表示为它们的线性组合。例如,我评价一个人的颜值vvv,你告诉我颜值可以这么计算v=2x1+9x2+1.3x3v=2x_1 + 9x_2+1.3x_3v=2x1+9x2+1.3x3,x1,x2,x3x_1,x_2,x_3x1,x2,x3表示眼睛大小,脸型,鼻子高度(当然,这里是打个比方QAQ)。事实上,很多情况下,协变量的线性组合不能准确衡量它们对某个目标值的影响! 关于这点例子很多(例XOR问题),就不一一介绍了。
问题根源就是在θ⋅x\theta \cdot xθ⋅x,我们把它记为rrr。那么我们可不可以把它表示为非线性组合呢?但是好像它的数学表达式公式不太好给出,无论我们怎么表示rrr,其目标都是使pl(θ)pl(\theta)pl(θ)最小。这个时候,神经网络的作用就显现出来了,它对于表示一组协变量的非线性组合简直太擅长了!假设网络的输入为一组协变量x=(x1,x2,...,xn)x=(x_1,x_2,...,x_n)x=(x1,x2,...,xn),那么网络的输出表示为r^w,b\hat r_{w,b}r^w,b为神经网络的参数。然后,损失函数就很显而易见了:
L=−log(pl(r^w,b))=−log(∏i=1ker^w,bi∑j∈R(ti)er^w,bj)L = -log(pl(\hat r_{w,b})) = -log(\prod_{i=1}^{k}\frac{e^{\hat r_{w,b}^i}}{\sum_{j \in R(t_i)}e^{\hat r_{w,b}^j}})L=−log(pl(r^w,b))=−log(i=1∏k∑j∈R(ti)er^w,bjer^w,bi)
将h(x)h(x)h(x)表示为网络的输出,那么上式可以化简为:
L=−∑i=1k[hi(x)−log(∑j∈R(ti)ehj(x))]L = -\sum_{i=1}^{k}[h_i(x)-log(\sum_{j\in R(t_i)}e^{h_j(x)})]L=−i=1∑k[hi(x)−log(j∈R(ti)∑ehj(x))]
剩下的工作就交给神经网络去训练样本学习到这样的非线性组合<r^w,b=f(x,θ)<\hat r_{w,b}=f(x,\theta)<r^w,b=f(x,θ)。其实思路还是很好懂的嘛。
DeepSurv网络框架实现
DeepSurv 的工作就是实现了上面介绍的所有内容(最重要的是损失函数),还实现了一些其他的功能(比如划分风险人群)。下面介绍一下这个框架的实现。
这里是DeepSurv类下面定义的方法:
class DeepSurv:def __init__()# 计算Loss function值def _negative_log_likelihood()# 得到当前网络的loss值同时更新网络参数def _get_loss_updates()# 得到可调用的函数:训练集上,网络进行一次正向和反向传播# 验证集上,一遍正向传播,计算Loss function值def _get_train_valid_fn()# 计算评估指标:C Indexdef get_concordance_index()def _standardize_x()def prepare_data()def train()def to_json()def save_model()def save_weights()def load_weights()# 得到网络的输出值def risk()def predict_risk()# 划分风险人群def recommend_treatment()def plot_risk_surface()
初始化函数:初始化网络结构,并且记录一些参数
def __init__(self, n_in,learning_rate, hidden_layers_sizes = None,lr_decay = 0.0, momentum = 0.9,L2_reg = 0.0, L1_reg = 0.0,activation = "rectify",dropout = None,batch_norm = False,standardize = False,):
按照给定hidden_layers_sizes
的搭建指定的网络结构:
输入层:network = lasagne.layers.InputLayer(shape=(None,n_in),input_var = self.X)
隐藏层:network = lasagne.layers.DenseLayer(network, num_units = n_layer, nonlinearity activation_fn, W = W_init)
(参数决定该层是否dropout或者BatchNorm)
输出层:network = lasagne.layers.DenseLayer(network, num_units = 1, nonlinearity = lasagne.nonlinearities.linear, W = lasagne.init.GlorotUniform())
训练函数:在给定的训练数据上进行训练,并且在验证集上进行评估
def train(self,train_data, valid_data= None,n_epochs = 500,validation_frequency = 250,patience = 2000, improvement_threshold = 0.99999, patience_increase = 2,logger = None,update_fn = lasagne.updates.nesterov_momentum,verbose = True,**kwargs):
训练函数里的内容就是通用的一套了:
- 准备好训练数据
- 每个epoch迭代训练网络
- 计算Loss,反向传播更新网络参数
具体地,源代码里还有很多细节的地方,自己亲身学习一下还不错啊!
原博客作者另外还写了一篇实战的总结:【论文笔记】Deep Survival: A Deep Cox Proportional Hazards Network ,值得借鉴。
使用深度学习进行生存分析相关推荐
- 利用深度学习进行生存分析——DeepSuv模型小结
生存分析是一种典型的医疗时间事件(time-event)分析场景,其主要分析序列研究中事件(如复发.死亡.治愈等)随着时间变化的统计规律,从而发现其中的敏感/危险因子. 其经典的统计学手段主要有: ( ...
- 深度学习Dropout技术分析
深度学习Dropout技术分析 什么是Dropout? dropout是指在深度学习网络的训练过程中,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃.注意是暂时,对于随机梯度下降来说,由于是随机 ...
- 【深度学习】EfficientNetV2分析总结和flops的开源库
[深度学习]EfficientNetV2分析总结和flops的开源库 1 EfficientNetV1中存在的问题 2 EfficientNetV2中做出的贡献 3 NAS 搜索 4 Efficien ...
- 【深度学习】生动分析半监督学习与负相关学习算法
[深度学习]生动分析半监督学习与负相关学习算法 文章目录 1 半监督学习1.1 定义1.2 半监督深度学习1.3 GAN1.4 应用 2 深度负相关学习算法2.1 负相关2.2 通俗解释 1 半监督学 ...
- 深度学习在情感分析中的应用
然语言情感分析简介 情感分析无处不在,它是一种基于自然语言处理的分类技术.其主要解决的问题是给定一段话,判断这段话是正面的还是负面的.例如在亚马逊网站或者推特网站中,人们会发表评论,谈论某个商品.事件 ...
- Densenet论文解读 深度学习领域论文分析博主
深度学习领域论文分析博主 博客链接: https://my.csdn.net/u014380165 其中一篇文章: DenseNet算法详解: https://blog.csdn.net/u01438 ...
- 深度学习笔记——情感分析
很早之前就想写一篇关于用深度学习做情感分析的文章,一直拖到现在,拖延症啊.... 什么是情感分析? 情感分析(Sentiment analysis)是自然语言处理(NLP)领域的一个任务,又称倾向性分 ...
- 【案例实践】Python多元线性回归、机器学习、深度学习在近红外光谱分析中的实践应用
查看原文>>>基于Python多元线性回归.机器学习.深度学习在近红外光谱分析中的实践应用 [专家]: 郁磊副教授 主要从事MATLAB 编程.机器学习与数据挖掘.数据可视化和软件开 ...
- 深度学习 情感分析_使用深度学习进行情感分析
深度学习 情感分析 介绍 (Introduction) The growth of the internet due to social networks such as Facebook, Twit ...
- 深度学习与视频分析简介
文章大纲 视频分析简介 视频理解 方法与技术 技术优势 重要技术罗列 目标检测 帧差法 使用深度学习进行视频分析 视频分析中的目标检测 架构 系统逻辑架构 典型应用场景 5G 工序检测 工业质量控制 ...
最新文章
- 漫画 | 大数据风控从业者的一天
- QPS、TPS、RT、并发数、吞吐量理解和性能优化深入思考
- minist读取一张图片
- Android开发之旅:组件生命周期(二)
- C++中函数参数的默认值
- 从华为“鸿蒙”备胎看IT项目建设
- python commands模块在python3.x被subprocess取代
- Swift5 利用元祖 返回多个 类型的函数,取出
- 在.Net项目中使用Redis作为缓存服务
- 深入理解Java泛型
- 解决Scrapy使用pipline保存到数据库后返回None
- 人工神经网络_人工神经网络实践
- protobuf android ndk,直接在Android NDK端使用tensorflow(不使用JAVA api)
- python 读取excel太慢_实用技巧——Python实现从Excel读取数据并绘制成图像
- [外挂4] 用CE查找棋盘基址
- Allegro导出STP文件
- 海湾标准汉字码表查询_JBQGGST5000标准汉字码表
- DevExpress WinForm控件入门指南——数据管理控件
- 龙芯2k1000-pmon(5)- pmon无法修改环境变量的问题
- Olympic Class Ships【奥林匹克级邮轮】
热门文章
- GIS数据网站分享(长期更新)
- 两台计算机直接相连教程,两台电脑怎么连接一起_如何让两台电脑相连-win7之家...
- 一行代码视频下载,so easy!
- txt文件内容导入mysql数据库中_将txt文件导入mysql数据库
- 计算机无法找到输出设备,电脑没声音找不到输出设备怎么办
- java面试(1)如何防止恶意攻击短信验证码接口
- hdoj-2567 寻梦
- html基本标记练习钱塘湖春行,《钱塘湖春行》练习题
- 辽宁计算机专业大学排名及分数线,辽宁一本大学排名及分数线2021
- ad域下发策略_AD域修改组策略