本文转自:使用深度学习进行生存分析

相关资源

原论文地址:here

论文中使用的深度生存分析库:DeepSurv,是基于Theano 和 Lasagne库实现的,支持训练网络模型,预测等功能。

考虑到DeepSurv库中存在着一些错误以及未实现的功能,博主使用目前主流的深度学习框架Tensorflow实现了深度生存分析库:TFDeepSurv。欢迎有兴趣的同学Star和Fork,指出错误,相互交流!

TFDeepSurv简介:基于tensorflow的深度生存分析框架,经过模拟数据和真实数据的测试。支持生存分析数据事件时间出现ties的建模,自定义神经网络结构及参数,可视化训练过程,输入训练数据特征重要性分析,病人生存函数的估计。还有支持使用科学的贝叶斯超参数优化方法来调整网络参数。

博主有空会给出TFDeepSurv各个功能实现参考的源论文!

前言

本文主要的目的为了介绍深度学习是如何运用到生存分析中的,包括其基本原理。然后介绍目前实现了利用深度学习进行生存分析的开源软件包 DeepSurv,它实现了生存分析模型,使用Deep Neural Networks来训练学习参数,并且还实现了风险人群的划分。

还是一样的,强烈建议你去读一下原论文DeepSurv: Personalized Treatment Recommender System Using A Cox Proportional Hazards Deep Neural Network.,相信你会收获很大,至少比看我的好一万倍。写这篇文章的动力有几点,一是不想自己学过的知识什么很快就忘了,感觉记录一下比较重要(博主比较蠢),当作是看论文的笔记吧;二是看完文章之后,觉得我们平时还是要多思考,论文里的思想其实也不是完全原创的,神经网络不是,生存分析cox比例风险模型1972年就有了,但是别人就能洞察到使用深度学习的思想去学习COX模型中需要估计的参数,个人觉得这是一个有科学素养的人才能做到的吧;三是博主在为了PR收集生存分析资料的时候,深感不易,这方面的中文的介绍很少,所以为了方便大家的交流讨论,还是写一下吧。

博主知识水平有限,不吝赐教!欢迎提出错误!

问题来源

假设你已经知道了生存分析主要是在做哪些工作。我们都知道在进行生存分析的时候,有这么几种方法:

  • 参数法:当生存时间符合某一个已知分布时,知道了分布函数,那么剩下的就是求解该分布的参数了。
  • 非参数法:用KM估计去求生存函数,作生存曲线,这里面不涉及任何参数,主要思想就是频率代替概率。
  • 半参数法:也就是使用COX比例风险模型来求生存函数,这个也是本文的重点。

关于COX比例风险模型是怎么提出的,这个是1972年前辈的智慧,本文不打算介绍,这里给出一个链接:hazard-curve,可以帮助你快速了解生存分析和生存函数以及风险曲线的数学定义,然后你就可以去看COX比例风险模型是怎么提出来的了。确保自己懂了COX比例风险模型的原理,可以问自己几个问题:比例两个字是体现在那个地方?为什么风险函数会是h(t)=h0(t)⋅eθ⋅xh(t) = h_0(t)\cdot e^{\theta \cdot x}h(t)=h0​(t)⋅eθ⋅x 这种形式?

COX比例风险模型,直接给出了风险函数的数学表达式(假设你已经学会了懂了其背后的数学原理):
h(t)=h0(t)⋅eθ⋅xh(t) = h_0(t)\cdot e^{\theta \cdot x}h(t)=h0​(t)⋅eθ⋅x
其中,θ=(θ1,…,θm)θ=(θ_1,…,θ_m)θ=(θ1​,…,θm​)是线性模型的系数或未知参数,h0(t)h_0 (t)h0​(t)是基准风险函数。eθxe^{θx}eθx描述了患者观察到回归变量xxx时的死亡风险比例。对∀i∈N,θ_i>0,表示该协变量是危险因素,越大使生存时间越短。 ∀i∈N,θ_i < 0表示该协变量是保护因素,越大使得生存时间越长。

现在需要去求取参数θ\thetaθ,其思想就是偏似然估计法。假定在某死亡时间没有重复事件发生,设t1<t2<⋯<tkt_1<t_2<⋯<t_kt1​<t2​<⋯<tk​ 表示在观察数据中有kkk个不同的死亡事件。设xix_ixi​的观察协变量。设R(ti)R(t_i)R(ti​)时间仍然处于观察研究的个体集合。则风险函数h(t)h(t)h(t)的参数估计可以用以下偏似然概率估计方法:
pl(θ)=∏i=1keθxi∑j∈R(ti)eθxjpl(\theta) = \prod_{i=1}^{k}\frac{e^{\theta x_i}}{\sum_{j \in R(t_i)}e^{\theta x_j}}pl(θ)=i=1∏k​∑j∈R(ti​)​eθxj​eθxi​​

其中qi=eθxi∑j∈R(ti)eθxjq_i = \frac{e^{\theta x_i}}{\sum_{j \in R(t_i)}e^{\theta x_j}}qi​=∑j∈R(ti​)​eθxj​eθxi​​个死亡个体,其死亡条件概率。其实通俗一点的解释就是:我已经观察到时间tit_iti​了,现在有一群人,我可以利用风险公式h(ti)h(t_i)h(ti​) 求出这群人每一个个体的死亡风险,其中有一个人恰好在tit_iti​时刻发生了死亡事件,那么这个人的死亡条件概率就写为:
qi=hi(ti)∑j∈R(ti)hj(ti)=eθxi∑j∈R(ti)eθxjq_i = \frac{h_i(t_i)}{\sum_{j \in R(t_i)}h_j(t_i)} = \frac{e^{\theta x_i}}{\sum_{j \in R(t_i)}e^{\theta x_j}}qi​=∑j∈R(ti​)​hj​(ti​)hi​(ti​)​=∑j∈R(ti​)​eθxj​eθxi​​

现在就是利用偏似然估计的思想,将所有死亡时刻t1,t2,...,tkt_1,t_2,...,t_kt1​,t2​,...,tk​的死亡条件概率相乘,求取是这个乘积最大的参数值θ\thetaθ,把它作为估计量。

注意COX模型给出的前提:假设协变量的总影响可以表示为它们的线性组合。例如,我评价一个人的颜值vvv,你告诉我颜值可以这么计算v=2x1+9x2+1.3x3v=2x_1 + 9x_2+1.3x_3v=2x1​+9x2​+1.3x3​,x1,x2,x3x_1,x_2,x_3x1​,x2​,x3​表示眼睛大小,脸型,鼻子高度(当然,这里是打个比方QAQ)。事实上,很多情况下,协变量的线性组合不能准确衡量它们对某个目标值的影响! 关于这点例子很多(例XOR问题),就不一一介绍了。

问题根源就是在θ⋅x\theta \cdot xθ⋅x,我们把它记为rrr。那么我们可不可以把它表示为非线性组合呢?但是好像它的数学表达式公式不太好给出,无论我们怎么表示rrr,其目标都是使pl(θ)pl(\theta)pl(θ)最小。这个时候,神经网络的作用就显现出来了,它对于表示一组协变量的非线性组合简直太擅长了!假设网络的输入为一组协变量x=(x1,x2,...,xn)x=(x_1,x_2,...,x_n)x=(x1​,x2​,...,xn​),那么网络的输出表示为r^w,b\hat r_{w,b}r^w,b​为神经网络的参数。然后,损失函数就很显而易见了:
L=−log(pl(r^w,b))=−log(∏i=1ker^w,bi∑j∈R(ti)er^w,bj)L = -log(pl(\hat r_{w,b})) = -log(\prod_{i=1}^{k}\frac{e^{\hat r_{w,b}^i}}{\sum_{j \in R(t_i)}e^{\hat r_{w,b}^j}})L=−log(pl(r^w,b​))=−log(i=1∏k​∑j∈R(ti​)​er^w,bj​er^w,bi​​)
将h(x)h(x)h(x)表示为网络的输出,那么上式可以化简为:
L=−∑i=1k[hi(x)−log(∑j∈R(ti)ehj(x))]L = -\sum_{i=1}^{k}[h_i(x)-log(\sum_{j\in R(t_i)}e^{h_j(x)})]L=−i=1∑k​[hi​(x)−log(j∈R(ti​)∑​ehj​(x))]
剩下的工作就交给神经网络去训练样本学习到这样的非线性组合&lt;r^w,b=f(x,θ)&lt;\hat r_{w,b}=f(x,\theta)<r^w,b​=f(x,θ)。其实思路还是很好懂的嘛。

DeepSurv网络框架实现

DeepSurv 的工作就是实现了上面介绍的所有内容(最重要的是损失函数),还实现了一些其他的功能(比如划分风险人群)。下面介绍一下这个框架的实现。

这里是DeepSurv类下面定义的方法:

class DeepSurv:def __init__()# 计算Loss function值def _negative_log_likelihood()# 得到当前网络的loss值同时更新网络参数def _get_loss_updates()# 得到可调用的函数:训练集上,网络进行一次正向和反向传播#                验证集上,一遍正向传播,计算Loss function值def _get_train_valid_fn()# 计算评估指标:C Indexdef get_concordance_index()def _standardize_x()def prepare_data()def train()def to_json()def save_model()def save_weights()def load_weights()# 得到网络的输出值def risk()def predict_risk()# 划分风险人群def recommend_treatment()def plot_risk_surface()

初始化函数:初始化网络结构,并且记录一些参数

def __init__(self, n_in,learning_rate, hidden_layers_sizes = None,lr_decay = 0.0, momentum = 0.9,L2_reg = 0.0, L1_reg = 0.0,activation = "rectify",dropout = None,batch_norm = False,standardize = False,):

按照给定hidden_layers_sizes的搭建指定的网络结构:
输入层:network = lasagne.layers.InputLayer(shape=(None,n_in),input_var = self.X)
隐藏层:network = lasagne.layers.DenseLayer(network, num_units = n_layer, nonlinearity activation_fn, W = W_init)(参数决定该层是否dropout或者BatchNorm)
输出层:network = lasagne.layers.DenseLayer(network, num_units = 1, nonlinearity = lasagne.nonlinearities.linear, W = lasagne.init.GlorotUniform())

训练函数:在给定的训练数据上进行训练,并且在验证集上进行评估

def train(self,train_data, valid_data= None,n_epochs = 500,validation_frequency = 250,patience = 2000, improvement_threshold = 0.99999, patience_increase = 2,logger = None,update_fn = lasagne.updates.nesterov_momentum,verbose = True,**kwargs):

训练函数里的内容就是通用的一套了:

  • 准备好训练数据
  • 每个epoch迭代训练网络
  • 计算Loss,反向传播更新网络参数

具体地,源代码里还有很多细节的地方,自己亲身学习一下还不错啊!

原博客作者另外还写了一篇实战的总结:【论文笔记】Deep Survival: A Deep Cox Proportional Hazards Network ,值得借鉴。

使用深度学习进行生存分析相关推荐

  1. 利用深度学习进行生存分析——DeepSuv模型小结

    生存分析是一种典型的医疗时间事件(time-event)分析场景,其主要分析序列研究中事件(如复发.死亡.治愈等)随着时间变化的统计规律,从而发现其中的敏感/危险因子. 其经典的统计学手段主要有: ( ...

  2. 深度学习Dropout技术分析

    深度学习Dropout技术分析 什么是Dropout? dropout是指在深度学习网络的训练过程中,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃.注意是暂时,对于随机梯度下降来说,由于是随机 ...

  3. 【深度学习】EfficientNetV2分析总结和flops的开源库

    [深度学习]EfficientNetV2分析总结和flops的开源库 1 EfficientNetV1中存在的问题 2 EfficientNetV2中做出的贡献 3 NAS 搜索 4 Efficien ...

  4. 【深度学习】生动分析半监督学习与负相关学习算法

    [深度学习]生动分析半监督学习与负相关学习算法 文章目录 1 半监督学习1.1 定义1.2 半监督深度学习1.3 GAN1.4 应用 2 深度负相关学习算法2.1 负相关2.2 通俗解释 1 半监督学 ...

  5. 深度学习在情感分析中的应用

    然语言情感分析简介 情感分析无处不在,它是一种基于自然语言处理的分类技术.其主要解决的问题是给定一段话,判断这段话是正面的还是负面的.例如在亚马逊网站或者推特网站中,人们会发表评论,谈论某个商品.事件 ...

  6. Densenet论文解读 深度学习领域论文分析博主

    深度学习领域论文分析博主 博客链接: https://my.csdn.net/u014380165 其中一篇文章: DenseNet算法详解: https://blog.csdn.net/u01438 ...

  7. 深度学习笔记——情感分析

    很早之前就想写一篇关于用深度学习做情感分析的文章,一直拖到现在,拖延症啊.... 什么是情感分析? 情感分析(Sentiment analysis)是自然语言处理(NLP)领域的一个任务,又称倾向性分 ...

  8. 【案例实践】Python多元线性回归、机器学习、深度学习在近红外光谱分析中的实践应用

    查看原文>>>基于Python多元线性回归.机器学习.深度学习在近红外光谱分析中的实践应用 [专家]: 郁磊副教授 主要从事MATLAB 编程.机器学习与数据挖掘.数据可视化和软件开 ...

  9. 深度学习 情感分析_使用深度学习进行情感分析

    深度学习 情感分析 介绍 (Introduction) The growth of the internet due to social networks such as Facebook, Twit ...

  10. 深度学习与视频分析简介

    文章大纲 视频分析简介 视频理解 方法与技术 技术优势 重要技术罗列 目标检测 帧差法 使用深度学习进行视频分析 视频分析中的目标检测 架构 系统逻辑架构 典型应用场景 5G 工序检测 工业质量控制 ...

最新文章

  1. 漫画 | 大数据风控从业者的一天
  2. QPS、TPS、RT、并发数、吞吐量理解和性能优化深入思考
  3. minist读取一张图片
  4. Android开发之旅:组件生命周期(二)
  5. C++中函数参数的默认值
  6. 从华为“鸿蒙”备胎看IT项目建设
  7. python commands模块在python3.x被subprocess取代
  8. Swift5 利用元祖 返回多个 类型的函数,取出
  9. 在.Net项目中使用Redis作为缓存服务
  10. 深入理解Java泛型
  11. 解决Scrapy使用pipline保存到数据库后返回None
  12. 人工神经网络_人工神经网络实践
  13. protobuf android ndk,直接在Android NDK端使用tensorflow(不使用JAVA api)
  14. python 读取excel太慢_实用技巧——Python实现从Excel读取数据并绘制成图像
  15. [外挂4] 用CE查找棋盘基址
  16. Allegro导出STP文件
  17. 海湾标准汉字码表查询_JBQGGST5000标准汉字码表
  18. DevExpress WinForm控件入门指南——数据管理控件
  19. 龙芯2k1000-pmon(5)- pmon无法修改环境变量的问题
  20. Olympic Class Ships【奥林匹克级邮轮】

热门文章

  1. GIS数据网站分享(长期更新)
  2. 两台计算机直接相连教程,两台电脑怎么连接一起_如何让两台电脑相连-win7之家...
  3. 一行代码视频下载,so easy!
  4. txt文件内容导入mysql数据库中_将txt文件导入mysql数据库
  5. 计算机无法找到输出设备,电脑没声音找不到输出设备怎么办
  6. java面试(1)如何防止恶意攻击短信验证码接口
  7. hdoj-2567 寻梦
  8. html基本标记练习钱塘湖春行,《钱塘湖春行》练习题
  9. 辽宁计算机专业大学排名及分数线,辽宁一本大学排名及分数线2021
  10. ad域下发策略_AD域修改组策略